[feat] support keep_checkpoint_max with async checkpoint pruning#528
Conversation
Make shutil.rmtree fsspec-transparent like the other patched stdlib FS APIs: on a remote (e.g. OSS) path it routes to fs.rm(recursive=True), otherwise it delegates to the original rmtree. ignore_errors is honored on both branches. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add TrainConfig.keep_checkpoint_max (default 0 = keep all) so long runs can cap the number of retained model.ckpt-<step> directories. Documented in usage/train.md. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
5b9b17a to
7099222
Compare
| check_all_workers_data_status=check_all_workers_data_status, | ||
| ) | ||
| model.train() | ||
| ckpt_manager.close() |
There was a problem hiding this comment.
close() only runs if the training loop completes normally. If any step between the first save() (which starts the daemon worker) and here raises, close() is skipped: the final coalesced prune pass is abandoned and the worker thread is leaked (harmless across process exit since it's a daemon, but it breaks the documented "on-disk state settled before export reads model_dir" contract if anything downstream runs in the same process). Consider wrapping the loop body so close() runs in a finally.
| protected = set(ckpt_metas[-self._keep_checkpoint_max :]) | ||
| if ( | ||
| self._export_config is not None | ||
| and self._export_config.exporter_type == "best" |
There was a problem hiding this comment.
All FS work here (glob, the eval-file read inside best_checkpoint, and fs.rm) runs on the prune worker thread while the main thread is concurrently doing the next save() — and both go through the same process-global cached fsspec filesystem (filesystem_util._CACHED_FSSPEC_FILESYSTEMS[protocol]). Concurrent calls into one fsspec backend instance from two threads aren't guaranteed safe for all backends (e.g. ossfs runs async ops on a single per-instance loop). The target paths are disjoint (worker only touches older, already-protected dirs), so this is about the shared client object, not the paths. Worth confirming the OSS backend is thread-safe under concurrent access, or giving the worker its own filesystem handle.
| // dense gradient clipping config | ||
| optional GradClipping grad_clipping = 19; | ||
| // maximum number of recent checkpoints to keep; 0 keeps all. | ||
| optional uint32 keep_checkpoint_max = 20 [default = 0]; |
There was a problem hiding this comment.
This comment reads as a hard cap, but when export_config.exporter_type == "best" the best checkpoint is retained in addition to the N recent ones, so the effective count can be N+1. The class docstring and train.md both note this; the proto comment is the only place that omits it.
| optional uint32 keep_checkpoint_max = 20 [default = 0]; | |
| // max number of recent checkpoints to keep; 0 keeps all. When | |
| // export_config.exporter_type is "best", the best checkpoint is also retained. | |
| optional uint32 keep_checkpoint_max = 20 [default = 0]; |
Review summarySolid, well-designed change. The async pruning design is the right one: A few items worth considering (inline comments on the specifics):
Test gaps (non-blocking, but the highest-value follow-ups):
Minor, pre-existing (while you're here): |
CheckpointManager saves model.ckpt-<step> dirs and prunes old ones to keep the most recent keep_checkpoint_max. Pruning never blocks the caller: prune() only enqueues a coalesced request and a single daemon prune worker performs all filesystem work (glob listing, best-checkpoint lookup, recursive delete via shutil.rmtree), which matters for slow OSS-backed model_dirs. Only rank 0 deletes. When exporter_type == "best", the current best-by-metric checkpoint is retained even if outside the recent window. close() drains the queue at the end of training for a deterministic flush; a weakref.finalize safety net (registered when the worker starts) runs the same drain at interpreter exit if close() is skipped (e.g. training raised), so the worker is never leaked and pending deletions still complete. The class also exposes load/discovery methods that delegate to the existing free functions. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Construct a CheckpointManager rooted at model_dir and use it for both save and load across train/eval/export/predict. Training saves now go through ckpt_manager.save() (which triggers async prune) and close() drains pending deletions before return; restore and latest/best discovery delegate to the manager. fine_tune_checkpoint discovery stays a free-function call since it resolves an external dir outside model_dir. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
7099222 to
d2c90ff
Compare
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
What
Adds a
keep_checkpoint_maxtraining option that caps the number of retainedmodel.ckpt-<step>directories. On long runs checkpoints currently accumulate without bound and can fill the disk; this lets users keep only the most recent N.TrainConfig.keep_checkpoint_max(default0= keep all → fully backward compatible; pruning is opt-in).glob, best-checkpoint lookup, recursive delete), so the loop never blocks onmodel_dirI/O — important for slow OSS-backed dirs.export_config.exporter_type == "best", the current best-by-metric checkpoint is always retained even if it falls outside the recent-N window, so best-exporter export never breaks (effective retained count can be N+1).Design
Checkpoint IO is consolidated into a new
CheckpointManager(incheckpoint_util.py) rooted atmodel_dir, used for both save and load across train/eval/export/predict:save()writes the checkpoint then requests an async prune;close()(called at the end of training) drains pending deletions so the on-disk state is settled before export readsmodel_dir._prune_pendingcoalescing flag (at most one in-flight + one pending pass). Deletions are filesystem-agnostic (localshutil.rmtree/ fsspecfs.rm(recursive=True)).globofmodel_diris the source of truth every pass (no in-memory registry), so it stays correct under--continue_train.latest_checkpoint,best_checkpoint,restore,restore_dataloader_state) delegate to the existing free functions, which remain canonical.fine_tune_checkpointdiscovery stays a free-function call since it resolves a dir outsidemodel_dir.A partially-written trailing line in the eval-result file (concurrently appended by eval) can make a prune pass raise; this is caught and logged, and the next pass self-heals — no crash, no spurious deletions.
Testing
checkpoint_util_test.py: parameterized prune cases (keep-all / recent-N / protect-best / keep≥count), non-rank-0 no-op, coalesced-idempotent prune, and discovery-delegation. All green;pre-commit run -aclean on changed files.