
Commit c566c17

Add checkpoint deletion options to configuration and checkpoint manager
PiperOrigin-RevId: 888785304
1 parent c30ada0 commit c566c17

5 files changed

Lines changed: 52 additions & 32 deletions


docs/guides/checkpointing_solutions/gcs_checkpointing.md

Lines changed: 40 additions & 32 deletions
@@ -9,41 +9,49 @@ bucket.
 
 ## Checkpoint loading priority
 
-The system follows a specific order when deciding which checkpoint to load at startup. The first valid condition met is the one executed:
-
-1. **Resume Current Run**: If a checkpoint already exists for the current
-   `run_name`, the system loads the latest fully-saved checkpoint. This is the
-   default behavior to ensure minimal state loss when resuming after an
-   interruption.
-2. **Load from Specific Path**: The system checks for a user-specified path.
-   - If `load_parameters_path` is set, we load a parameter-only checkpoint from that path.
-   - If `load_full_state_path` is set, we load a full-state checkpoint from that path.
-   - **Note**: These two options are mutually exclusive and will cause an error if both are set.
-3. **Initialize from Scratch**: We don't load a checkpoint and initialize state instead.
+The system follows a specific order when deciding which checkpoint to load at
+startup. The first valid condition met is the one executed:
+
+1. **Resume Current Run**: If a checkpoint already exists for the current
+   `run_name`, the system loads the latest fully-saved checkpoint. This is the
+   default behavior to ensure minimal state loss when resuming after an
+   interruption.
+2. **Load from Specific Path**: The system checks for a user-specified path.
+   - If `load_parameters_path` is set, we load a parameter-only checkpoint
+     from that path.
+   - If `load_full_state_path` is set, we load a full-state checkpoint from
+     that path.
+   - **Note**: These two options are mutually exclusive and will cause an
+     error if both are set.
+3. **Initialize from Scratch**: We don't load a checkpoint and initialize state
+   instead.
 
 ### MaxText configuration
 
-| Flag | Description | Type | Default |
-| :--- | :--- | :--- | :--- |
-| `enable_checkpointing` | A master switch to enable (`True`) or disable (`False`) saving checkpoints during the training run. | `boolean` | `False` |
-| `async_checkpointing` | When set to `True`, checkpoint saving is asynchronous: the training step is blocked only for the minimal time needed to capture the model's state, and the actual write to storage happens in a background thread. Highly recommended for performance. | `boolean` | `True` |
-| `checkpoint_period` | The interval, in training steps, for how often a checkpoint is saved. | `integer` | `10000` |
-| `enable_single_replica_ckpt_restoring` | If `True`, one replica reads the checkpoint from storage and then broadcasts it to all other replicas. This can significantly speed up restoration on multi-host systems by reducing redundant reads from storage.<br>**Note**: This feature is only compatible with training jobs that use a Distributed Data Parallel (DDP) strategy. | `boolean` | `False` |
-| `load_parameters_path` | Path to a checkpoint directory from which to load a parameter-only checkpoint.<br>**Example**: `"gs://my-bucket/my-previous-run/checkpoints/items/1000"` | `string` | `""` (disabled) |
-| `load_full_state_path` | Path to a checkpoint directory from which to load a full checkpoint, including optimizer state and step count.<br>**Example**: `"gs://my-bucket/my-interrupted-run/checkpoints/items/500"` | `string` | `""` (disabled) |
-| `lora_input_adapters_path` | Parent directory containing LoRA (Low-Rank Adaptation) adapters. | `string` | `""` (disabled) |
-| `force_unroll` | If `True`, unrolls the loop when generating a parameter-only checkpoint. | `boolean` | `False` |
+Flag | Description | Type | Default
+:--- | :--- | :--- | :---
+`enable_checkpointing` | A master switch to enable (`True`) or disable (`False`) saving checkpoints during the training run. | `boolean` | `False`
+`async_checkpointing` | When set to `True`, checkpoint saving is asynchronous: the training step is blocked only for the minimal time needed to capture the model's state, and the actual write to storage happens in a background thread. Highly recommended for performance. | `boolean` | `True`
+`checkpoint_period` | The interval, in training steps, for how often a checkpoint is saved. | `integer` | `10000`
+`enable_single_replica_ckpt_restoring` | If `True`, one replica reads the checkpoint from storage and then broadcasts it to all other replicas. This can significantly speed up restoration on multi-host systems by reducing redundant reads from storage.<br>**Note**: This feature is only compatible with training jobs that use a Distributed Data Parallel (DDP) strategy. | `boolean` | `False`
+`checkpoint_todelete_subdir` | Subdirectory to move checkpoints to before deletion. For example: `".todelete"` | `string` | `""`
+`checkpoint_todelete_full_path` | Full path to move checkpoints to before deletion. | `string` | `""`
+`load_parameters_path` | Path to a checkpoint directory from which to load a parameter-only checkpoint.<br>**Example**: `"gs://my-bucket/my-previous-run/checkpoints/items/1000"` | `string` | `""` (disabled)
+`load_full_state_path` | Path to a checkpoint directory from which to load a full checkpoint, including optimizer state and step count.<br>**Example**: `"gs://my-bucket/my-interrupted-run/checkpoints/items/500"` | `string` | `""` (disabled)
+`lora_input_adapters_path` | Parent directory containing LoRA (Low-Rank Adaptation) adapters. | `string` | `""` (disabled)
+`force_unroll` | If `True`, unrolls the loop when generating a parameter-only checkpoint. | `boolean` | `False`
 
 ## Storage and format configuration
 
-These settings control the underlying storage mechanism ([Orbax](https://orbax.readthedocs.io)) for performance and compatibility.
-
-| Flag | Description | Type | Default |
-| :--- | :--- | :--- | :--- |
-| `checkpoint_storage_target_data_file_size_bytes` | Sets a target file size for Orbax to chunk large arrays into smaller physical files. This can dramatically speed up loading over a network and in distributed environments. | `integer` | `2147483648` (2 GB) |
-| `checkpoint_storage_use_ocdbt` | If `True`, uses the TensorStore **OCDBT** (Optionally-Cooperative Distributed B+ Tree) key-value store as the underlying storage format for checkpointing. Set to `0` for Pathways. | `boolean` | `True` |
-| `checkpoint_storage_use_zarr3` | If `True`, uses the Zarr v3 storage format within Orbax, which is optimized for chunked, compressed, N-dimensional arrays. Set to `0` for Pathways. | `boolean` | `True` |
-| `checkpoint_storage_concurrent_gb` | Controls the concurrent I/O limit in gigabytes for the checkpointer. Larger models may require increasing this value to avoid I/O bottlenecks. | `integer` | `96` |
-| `enable_orbax_v1` | A boolean flag to explicitly enable features and behaviors from Orbax version 1. | `boolean` | `False` |
-| `source_checkpoint_layout` | Specifies the format of the checkpoint being **loaded**. This tells the system how to interpret the files at the source path.<br>**Options**: `"orbax"`, `"safetensors"` | `string` | `"orbax"` |
-| `checkpoint_conversion_fn` | A user-defined function to process a loaded checkpoint dictionary into a format the model can understand. This is essential for loading checkpoints from different frameworks or formats (e.g., converting keys from a Hugging Face SafeTensors file). | `function` or `None` | `None` |
+These settings control the underlying storage mechanism
+([Orbax](https://orbax.readthedocs.io)) for performance and compatibility.
+
+Flag | Description | Type | Default
+:--- | :--- | :--- | :---
+`checkpoint_storage_target_data_file_size_bytes` | Sets a target file size for Orbax to chunk large arrays into smaller physical files. This can dramatically speed up loading over a network and in distributed environments. | `integer` | `2147483648` (2 GB)
+`checkpoint_storage_use_ocdbt` | If `True`, uses the TensorStore **OCDBT** (Optionally-Cooperative Distributed B+ Tree) key-value store as the underlying storage format for checkpointing. Set to `0` for Pathways. | `boolean` | `True`
+`checkpoint_storage_use_zarr3` | If `True`, uses the Zarr v3 storage format within Orbax, which is optimized for chunked, compressed, N-dimensional arrays. Set to `0` for Pathways. | `boolean` | `True`
+`checkpoint_storage_concurrent_gb` | Controls the concurrent I/O limit in gigabytes for the checkpointer. Larger models may require increasing this value to avoid I/O bottlenecks. | `integer` | `96`
+`enable_orbax_v1` | A boolean flag to explicitly enable features and behaviors from Orbax version 1. | `boolean` | `False`
+`source_checkpoint_layout` | Specifies the format of the checkpoint being **loaded**. This tells the system how to interpret the files at the source path.<br>**Options**: `"orbax"`, `"safetensors"` | `string` | `"orbax"`
+`checkpoint_conversion_fn` | A user-defined function to process a loaded checkpoint dictionary into a format the model can understand. This is essential for loading checkpoints from different frameworks or formats (e.g., converting keys from a Hugging Face SafeTensors file). | `function` or `None` | `None`
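
The loading-priority rules documented above can be summarized in a few lines of Python. This is an illustrative sketch only; `resolve_restore_source` and its arguments are hypothetical names for this example, not MaxText APIs:

```python
# Sketch of the checkpoint loading priority described in the doc above.
# resolve_restore_source and latest_step are hypothetical, for illustration.
def resolve_restore_source(latest_step, load_parameters_path, load_full_state_path):
  if load_parameters_path and load_full_state_path:
    raise ValueError("load_parameters_path and load_full_state_path are mutually exclusive")
  if latest_step is not None:   # 1. Resume the current run_name.
    return ("resume", latest_step)
  if load_parameters_path:      # 2a. Parameter-only checkpoint.
    return ("load_params", load_parameters_path)
  if load_full_state_path:      # 2b. Full state, incl. optimizer and step count.
    return ("load_full_state", load_full_state_path)
  return ("init_from_scratch", None)  # 3. No checkpoint found or requested.
```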

src/maxtext/common/checkpointing.py

Lines changed: 4 additions & 0 deletions
@@ -221,6 +221,8 @@ def create_orbax_checkpoint_manager(
     enable_single_controller: bool = False,
     colocated_python_checkpointing: bool = False,
     enable_single_replica_ckpt_restoring: bool = False,
+    todelete_subdir: str | None = None,
+    todelete_full_path: str | None = None,
 ):
   """Returns specified Orbax (async or not) CheckpointManager or None if checkpointing is disabled."""
   if not enable_checkpointing:
@@ -268,6 +270,8 @@ def create_orbax_checkpoint_manager(
           save_decision_policy=save_decision_policy,
           preservation_policy=preservation_policy,
           async_options=async_options,
+          todelete_subdir=todelete_subdir,
+          todelete_full_path=todelete_full_path,
       ),
       logger=orbax_logger,
   )
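
The two new arguments are forwarded into Orbax's checkpoint manager options. A minimal sketch of the effect, assuming an Orbax release whose `CheckpointManagerOptions` accepts both options (bucket paths below are placeholders):

```python
import orbax.checkpoint as ocp

# Instead of deleting stale checkpoints in place, Orbax renames them into a
# holding location, so actual removal can happen out of band (for example by
# a bucket lifecycle rule) rather than on the training job's critical path.
options = ocp.CheckpointManagerOptions(
    save_interval_steps=10_000,   # mirrors checkpoint_period
    todelete_subdir=".todelete",  # rename within the checkpoint directory...
    # todelete_full_path="gs://my-bucket/ckpt-trash",  # ...or move to a separate path
)
manager = ocp.CheckpointManager("gs://my-bucket/my-run/checkpoints", options=options)
```

Only one of the two destinations would normally be set, matching the pair of MaxText flags documented above.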

src/maxtext/configs/base.yml

Lines changed: 4 additions & 0 deletions
@@ -59,6 +59,10 @@ max_num_checkpoints_to_keep: None
 enable_continuous_checkpointing: False
 # enables one replica to read the ckpt then broadcast to the rest
 enable_single_replica_ckpt_restoring: False
+# Subdirectory to move checkpoints to before deletion. For example: ".todelete"
+checkpoint_todelete_subdir: ""
+# Full path to move checkpoints to before deletion.
+checkpoint_todelete_full_path: ""
 
 force_unroll: False # during generate_param_only_checkpoint should we unroll the loop?
 
src/maxtext/configs/types.py

Lines changed: 2 additions & 0 deletions
@@ -310,6 +310,8 @@ class Checkpointing(BaseModel):
   enable_single_replica_ckpt_restoring: bool = Field(
       False, description="One replica reads and broadcasts the checkpoint."
   )
+  checkpoint_todelete_subdir: str = Field("", description="Subdirectory to move checkpoints to before deletion.")
+  checkpoint_todelete_full_path: str = Field("", description="Full path to move checkpoints to before deletion.")
   force_unroll: bool = Field(
       False,
       description="During param-only checkpoint generation, whether to unroll the loop.",
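
Because `Checkpointing` is a pydantic model, the new fields get the usual type validation and defaulting. A trimmed-down sketch (the real class has many more fields than shown here):

```python
from pydantic import BaseModel, Field

class Checkpointing(BaseModel):
  checkpoint_todelete_subdir: str = Field("", description="Subdirectory to move checkpoints to before deletion.")
  checkpoint_todelete_full_path: str = Field("", description="Full path to move checkpoints to before deletion.")

cfg = Checkpointing(checkpoint_todelete_subdir=".todelete")
print(cfg.checkpoint_todelete_subdir)  # .todelete
```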

src/maxtext/utils/train_utils.py

Lines changed: 2 additions & 0 deletions
@@ -82,6 +82,8 @@ def create_training_tools(config, model, mesh):
       config.enable_single_controller,
       config.colocated_python_checkpointing,
       config.enable_single_replica_ckpt_restoring,
+      config.checkpoint_todelete_subdir,
+      config.checkpoint_todelete_full_path,
   )
 
   return init_rng, checkpoint_manager, learning_rate_schedule, tx
