diff --git a/docs/guides/checkpointing_solutions/convert_checkpoint.md b/docs/guides/checkpointing_solutions/convert_checkpoint.md index 72b8709529..d28ef7751e 100644 --- a/docs/guides/checkpointing_solutions/convert_checkpoint.md +++ b/docs/guides/checkpointing_solutions/convert_checkpoint.md @@ -56,10 +56,8 @@ export HF_TOKEN= # your token to access gated HF repo # -- MaxText configuration -- export MODEL_CHECKPOINT_DIRECTORY= # e.g., gs://my-bucket/my-checkpoint-directory - # -- storage and format options -export USE_ZARR3= # Set to True to use zarr3 format (recommended for McJAX); set to False for Pathways. -export USE_OCDBT= # Set to True to use OCDBT format (recommended for McJAX); set to False for Pathways. +export USE_PATHWAYS=0 # Set to 1 for Pathways, 0 for McJAX. export LAZY_LOAD_TENSORS= # True to use lazy load, False to use eager load. ``` @@ -78,21 +76,18 @@ python3 -m maxtext.checkpoint_conversion.to_maxtext \ use_multimodal=false \ hardware=cpu \ skip_jax_distributed_system=true \ - checkpoint_storage_use_zarr3=${USE_ZARR3?} \ - checkpoint_storage_use_ocdbt=${USE_OCDBT?} \ + checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS)) \ + checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) \ --lazy_load_tensors=${LAZY_LOAD_TENSORS?} ``` -**Key arguments:** - - `model_name`: The model identifier, which should be defined in `src/maxtext/configs/types.py`. - `scan_layers`: Indicates if the output checkpoint is [scanned](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/reference/core_concepts/checkpoints.md) (scan_layers=true) or unscanned (scan_layers=false). - `use_multimodal`: Indicates if multimodality is used, important for Gemma3. - `hf_access_token`: Your Hugging Face token. - `base_output_directory`: The path where the converted Orbax checkpoint will be stored; it can be Google Cloud Storage (GCS) or local. If not set, the default output directory is `Maxtext/tmp`. - `hardware=cpu`: Run the conversion script on a CPU machine. -- `checkpoint_storage_use_zarr3`: Set to True to use zarr3 format (recommended for McJAX); set to False for Pathways. -- `checkpoint_storage_use_ocdbt`: Set to True to use OCDBT format (recommended for McJAX); set to False for Pathways. +- `checkpoint_storage_use_zarr3` and `checkpoint_storage_use_ocdbt`: Set to True for McJAX (default, `USE_PATHWAYS=0`); set to False for Pathways (`USE_PATHWAYS=1`). Both are controlled by the `$((1 - USE_PATHWAYS))` calculation in the example above. - `--lazy_load_tensors` (optional): If `true`, loads Hugging Face weights on-demand to minimize RAM usage. When memory is constrained, it is recommended to use the `--lazy_load_tensors=true` flag to reduce memory usage during conversion. For example, converting a Llama3.1-70B model with `--lazy_load_tensors=true` uses around 200GB of RAM and completes in ~10 minutes. - `--hf_model_path` (optional): Specifies a local or remote directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/utils.py#L59-L91) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek. diff --git a/docs/guides/checkpointing_solutions/gcs_checkpointing.md b/docs/guides/checkpointing_solutions/gcs_checkpointing.md index 1ba471492b..9fa4e7192a 100644 --- a/docs/guides/checkpointing_solutions/gcs_checkpointing.md +++ b/docs/guides/checkpointing_solutions/gcs_checkpointing.md @@ -11,39 +11,39 @@ bucket.
The system follows a specific order when deciding which checkpoint to load at startup. The first valid condition met is the one executed: -1. **Resume Current Run**: If a checkpoint already exists for the current - `run_name`, the system loads the latest fully-saved checkpoint. This is the - default behavior to ensure minimal state loss when resuming after an - interruption. -2. **Load from Specific Path**: The system checks for a user-specified path. - * If `load_parameters_path` is set, we load a parameter only checkpoint from that path.. - * If `load_full_state_path` is set, we load a full state checkpoint from that path. - * **Note**: These two options are mutually exclusive and will cause an error if both are set. -3. **Initialize from Scratch**: We don't load a checkpoint and initialize state instead. +1. **Resume Current Run**: If a checkpoint already exists for the current + `run_name`, the system loads the latest fully-saved checkpoint. This is the + default behavior to ensure minimal state loss when resuming after an + interruption. +2. **Load from Specific Path**: The system checks for a user-specified path. + - If `load_parameters_path` is set, we load a parameter-only checkpoint from that path. + - If `load_full_state_path` is set, we load a full-state checkpoint from that path. + - **Note**: These two options are mutually exclusive and will cause an error if both are set. +3. **Initialize from Scratch**: We don't load a checkpoint and instead initialize state from scratch. ### MaxText configuration -| Flag | Description | Type | Default | -| :--- | :--- | :--- | :--- | -| `enable_checkpointing` | A master switch to enable (`True`) or disable (`False`) saving checkpoints during the training run. | `boolean` | `False` | -| `async_checkpointing` | When set to (`True`), this flag makes checkpoint saving asynchronous. The training step is only blocked for the minimal time needed to capture the model's state, and the actual writing to storage happens in a background thread. This is highly recommended for performance. It's enabled by default. | `boolean` | `True` | -| `checkpoint_period` | The interval, in training steps, for how often a checkpoint is saved. | `integer` | `10000` | -| `enable_single_replica_ckpt_restoring` | If `True`, one replica reads the checkpoint from storage and then broadcasts it to all other replicas. This can significantly speed up restoration on multi-host systems by reducing redundant reads from storage.
**Note**: This feature is only compatible with training jobs that utilize a Distributed Data Parallel (DDP) strategy. | `boolean` | `False` | -| `load_parameters_path` | Specifies a path to a checkpoint directory to load a parameter only checkpoint.
**Example**: `"gs://my-bucket/my-previous-run/checkpoints/items/1000"` | `string` | `""` (disabled) | -| `load_full_state_path` | Specifies a path to a checkpoint directory to load a full checkpoint including optimizer state and step count from a specific directory.
**Example**: `"gs://my-bucket/my-interrupted-run/checkpoints/items/500"` | `string` | `""` (disabled) | -| `lora_input_adapters_path` | Specifies a parent directory containing LoRA (Low-Rank Adaptation) adapters. | `string` | `""` (disabled) | -| `force_unroll` | If `True`, unrolls the loop when generating a parameter-only checkpoint. | `boolean` | `False` | +| Flag | Description | Type | Default | +| :------------------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | :-------- | :-------------- | +| `enable_checkpointing` | A master switch to enable (`True`) or disable (`False`) saving checkpoints during the training run. | `boolean` | `False` | +| `async_checkpointing` | When set to `True`, this flag makes checkpoint saving asynchronous. The training step is only blocked for the minimal time needed to capture the model's state, and the actual writing to storage happens in a background thread. This is highly recommended for performance. It's enabled by default. | `boolean` | `True` | +| `checkpoint_period` | The interval, in training steps, for how often a checkpoint is saved. | `integer` | `10000` | +| `enable_single_replica_ckpt_restoring` | If `True`, one replica reads the checkpoint from storage and then broadcasts it to all other replicas. This can significantly speed up restoration on multi-host systems by reducing redundant reads from storage.
**Note**: This feature is only compatible with training jobs that utilize a Distributed Data Parallel (DDP) strategy. | `boolean` | `False` | +| `load_parameters_path` | Specifies a path to a checkpoint directory to load a parameter-only checkpoint.
**Example**: `"gs://my-bucket/my-previous-run/checkpoints/items/1000"` | `string` | `""` (disabled) | +| `load_full_state_path` | Specifies a path to a checkpoint directory from which to load a full checkpoint, including optimizer state and step count.
**Example**: `"gs://my-bucket/my-interrupted-run/checkpoints/items/500"` | `string` | `""` (disabled) | +| `lora_input_adapters_path` | Specifies a parent directory containing LoRA (Low-Rank Adaptation) adapters. | `string` | `""` (disabled) | +| `force_unroll` | If `True`, unrolls the loop when generating a parameter-only checkpoint. | `boolean` | `False` | ## Storage and format configuration These settings control the underlying storage mechanism ([Orbax](https://orbax.readthedocs.io)) for performance and compatibility. -| Flag | Description | Type | Default | -| :--- | :--- | :--- | :--- | -| `checkpoint_storage_target_data_file_size_bytes` | Sets a target file size for Orbax to chunk large arrays into smaller physical files. This can dramatically speed up loading over a network and in distributed environments. | `integer` | `2147483648` (2 GB) | -| `checkpoint_storage_use_ocdbt` | If `True`, uses the TensorStore **OCDBT** (Optionally-Cooperative Distributed B+ Tree)) key-value store as the underlying storage format for checkpointing. | `boolean` | `True` | -| `checkpoint_storage_use_zarr3` | If `True`, uses the Zarr v3 storage format within Orbax, which is optimized for chunked, compressed, N-dimensional arrays. | `boolean` | `True` | -| `checkpoint_storage_concurrent_gb` | Controls the concurrent I/O limit in gigabytes for the checkpointer. Larger models may require increasing this value to avoid I/O bottlenecks. | `integer` | `96` | -| `enable_orbax_v1` | A boolean flag to explicitly enable features and behaviors from Orbax version 1. | `boolean` | `False` | -| `source_checkpoint_layout` | Specifies the format of the checkpoint being **loaded**. This tells the system how to interpret the files at the source path.
**Options**: `"orbax"`, `"safetensors"` | `string` | `"orbax"` | -| `checkpoint_conversion_fn` | A user-defined function to process a loaded checkpoint dictionary into a format that the model can understand. This is essential for loading checkpoints from different frameworks or formats (e.g., converting keys from a Hugging Face SafeTensors file). | `function` or `None` | `None` | +| Flag | Description | Type | Default | +| :----------------------------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :------------------- | :------------------ | +| `checkpoint_storage_target_data_file_size_bytes` | Sets a target file size for Orbax to chunk large arrays into smaller physical files. This can dramatically speed up loading over a network and in distributed environments. | `integer` | `2147483648` (2 GB) | +| `checkpoint_storage_use_ocdbt` | If `True`, uses the TensorStore **OCDBT** (Optionally-Cooperative Distributed B+ Tree) key-value store as the underlying storage format for checkpointing. Set to `False` (`0`) for Pathways. | `boolean` | `True` | +| `checkpoint_storage_use_zarr3` | If `True`, uses the Zarr v3 storage format within Orbax, which is optimized for chunked, compressed, N-dimensional arrays. Set to `False` (`0`) for Pathways. | `boolean` | `True` | +| `checkpoint_storage_concurrent_gb` | Controls the concurrent I/O limit in gigabytes for the checkpointer. Larger models may require increasing this value to avoid I/O bottlenecks. | `integer` | `96` | +| `enable_orbax_v1` | A boolean flag to explicitly enable features and behaviors from Orbax version 1. | `boolean` | `False` | +| `source_checkpoint_layout` | Specifies the format of the checkpoint being **loaded**. This tells the system how to interpret the files at the source path.
**Options**: `"orbax"`, `"safetensors"` | `string` | `"orbax"` | +| `checkpoint_conversion_fn` | A user-defined function to process a loaded checkpoint dictionary into a format that the model can understand. This is essential for loading checkpoints from different frameworks or formats (e.g., converting keys from a Hugging Face SafeTensors file). | `function` or `None` | `None` | diff --git a/docs/tutorials/posttraining/sft_on_multi_host.md b/docs/tutorials/posttraining/sft_on_multi_host.md index 91dd5d856c..243ae56127 100644 --- a/docs/tutorials/posttraining/sft_on_multi_host.md +++ b/docs/tutorials/posttraining/sft_on_multi_host.md @@ -116,8 +116,11 @@ export MODEL_CHECKPOINT_PATH= # e.g., gs://my-b **Note:** Make sure that `MODEL_CHECKPOINT_PATH` has the checkpoints created using the correct storage flags: -- **For SFT with McJAX:** `checkpoint_storage_use_zarr3=True` and `checkpoint_storage_use_ocdbt=True`. -- **For SFT with Pathways:** `checkpoint_storage_use_zarr3=False` and `checkpoint_storage_use_ocdbt=False`. +```bash +export USE_PATHWAYS=0 # Set to 1 for Pathways, 0 for McJAX. +checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS)) # resolves to 1 (True) for McJAX, 0 (False) for Pathways +checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) +``` ### Option 2: Converting a Hugging Face checkpoint @@ -150,6 +153,8 @@ Once the fine-tuning is completed, you can access your model checkpoints at `$OU ### 6.2. SFT with Pathways ```bash +export USE_PATHWAYS=1 + xpk workload create-pathways \ --cluster=${CLUSTER_NAME?} \ --project=${PROJECT?} \ @@ -158,7 +163,7 @@ xpk workload create-pathways \ --workload=${WORKLOAD_NAME?} \ --tpu-type=${TPU_TYPE?} \ --num-slices=${TPU_SLICE?} \ ---command="JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 python3 -m maxtext.trainers.post_train.sft.train_sft run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL_NAME?} load_parameters_path=${MODEL_CHECKPOINT_PATH?} hf_access_token=${HF_TOKEN?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane checkpoint_storage_use_zarr3=False checkpoint_storage_use_ocdbt=False enable_single_controller=True" +--command="JAX_PLATFORMS=proxy JAX_BACKEND_TARGET=grpc://127.0.0.1:29000 ENABLE_PATHWAYS_PERSISTENCE=1 python3 -m maxtext.trainers.post_train.sft.train_sft run_name=${WORKLOAD_NAME?} base_output_directory=${OUTPUT_PATH?} model_name=${MODEL_NAME?} load_parameters_path=${MODEL_CHECKPOINT_PATH?} hf_access_token=${HF_TOKEN?} per_device_batch_size=1 steps=${STEPS?} profiler=xplane checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS)) checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) enable_single_controller=True" ``` Once the fine-tuning is completed, you can access your model checkpoints at `$OUTPUT_PATH/$WORKLOAD_NAME/checkpoints`. diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py index 0e8668d612..1e350ca798 100644 --- a/src/maxtext/trainers/post_train/rl/train_rl.py +++ b/src/maxtext/trainers/post_train/rl/train_rl.py @@ -85,13 +85,14 @@ def get_maxtext_model(config, devices=None): Load MaxText model with Tunix adapter. # Note: pass the path to your scanned checkpoint for 'load_parameters_path'.
# To create a scanned checkpoint, you can use /maxtext/src/MaxText/checkpoint_conversion/to_maxtext.py and if - # using Pathways, please set `checkpoint_storage_use_ocdbt=False checkpoint_storage_use_zarr3=False` + # using Pathways, set `USE_PATHWAYS=1` so that `$((1 - USE_PATHWAYS))` resolves the storage flags below to 0 (False): + # export USE_PATHWAYS=1 # python src/MaxText/checkpoint_conversion/to_maxtext.py \ # --model_name="gemma2-2b" \ # --base_output_directory="/path/to/your/output/directory" \ # --scan_layers=True \ - # --checkpoint_storage_use_ocdbt=False\ - # checkpoint_storage_use_zarr3=False + # --checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) \ + # --checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS)) # Please ensure that you pass the full path ending in `/0/items` for load_parameters_path to train_rl.py i.e., # load_parameters_path=/path/to/your/output/directory/0/items """
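For a quick sanity check of the `USE_PATHWAYS` convention introduced above, here is a minimal, illustrative shell sketch; the `zarr3` and `ocdbt` variable names are placeholders for this example only, not MaxText flags:

```bash
# Illustrative only: how USE_PATHWAYS resolves the two Orbax storage flags.
export USE_PATHWAYS=0                 # 0 = McJAX (default), 1 = Pathways
zarr3=$((1 - USE_PATHWAYS))           # 1 (True) for McJAX, 0 (False) for Pathways
ocdbt=$((1 - USE_PATHWAYS))
echo "checkpoint_storage_use_zarr3=${zarr3} checkpoint_storage_use_ocdbt=${ocdbt}"
```

With `USE_PATHWAYS=1`, both values resolve to `0` (False), matching the Pathways requirement described in the documentation above.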