
[feat] support keep_checkpoint_max with async checkpoint pruning #528

Merged
tiankongdeguiji merged 6 commits into alibaba:master from tiankongdeguiji:feat/keep-checkpoint-max
May 27, 2026


@tiankongdeguiji

Collaborator

What

Adds a keep_checkpoint_max training option that caps the number of retained model.ckpt-<step> directories. On long runs checkpoints currently accumulate without bound and can fill the disk; this lets users keep only the most recent N.

  • TrainConfig.keep_checkpoint_max (default 0 = keep all → fully backward compatible; pruning is opt-in).
  • Pruning is fully asynchronous: the training thread only enqueues a coalesced request; a single daemon worker does all filesystem work (directory glob, best-checkpoint lookup, recursive delete), so the loop never blocks on model_dir I/O — important for slow OSS-backed dirs.
  • When export_config.exporter_type == "best", the current best-by-metric checkpoint is always retained even if it falls outside the recent-N window, so best-exporter export never breaks (effective retained count can be N+1).
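The retention rule described above (keep the most recent N, plus the best checkpoint when one is supplied) can be sketched as a pure function. This is only an illustration under stated assumptions: `checkpoint_step` and `select_checkpoints_to_delete` are hypothetical names, not functions from this PR, and the best checkpoint is passed in by the caller rather than looked up from eval results.

```python
import re
from typing import Iterable, List, Optional

# Matches the trailing "model.ckpt-<step>" component of a checkpoint path.
_CKPT_RE = re.compile(r"model\.ckpt-(\d+)$")


def checkpoint_step(path: str) -> int:
    """Extract <step> from a path ending in model.ckpt-<step> (hypothetical helper)."""
    match = _CKPT_RE.search(path.rstrip("/"))
    if match is None:
        raise ValueError(f"not a checkpoint dir: {path}")
    return int(match.group(1))


def select_checkpoints_to_delete(
    ckpt_dirs: Iterable[str],
    keep_checkpoint_max: int,
    best_ckpt: Optional[str] = None,
) -> List[str]:
    """Return the dirs a prune pass would delete, oldest first."""
    if keep_checkpoint_max <= 0:
        return []  # 0 means keep everything, so pruning stays opt-in
    ordered = sorted(ckpt_dirs, key=checkpoint_step)
    protected = set(ordered[-keep_checkpoint_max:])
    if best_ckpt is not None:
        protected.add(best_ckpt)  # can push the retained count to N+1
    return [d for d in ordered if d not in protected]
```

With four checkpoints at steps 100 to 400 and `keep_checkpoint_max=2`, the sketch selects steps 100 and 200 for deletion; if step 100 is also the best checkpoint, only step 200 is selected and three directories remain.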

Design

Checkpoint IO is consolidated into a new CheckpointManager (in checkpoint_util.py) rooted at model_dir, used for both save and load across train/eval/export/predict:

  • save() writes the checkpoint then requests an async prune; close() (called at the end of training) drains pending deletions so the on-disk state is settled before export reads model_dir.
  • Only rank 0 deletes. The queue is bounded by a _prune_pending coalescing flag (at most one in-flight + one pending pass). Deletions are filesystem-agnostic (local shutil.rmtree / fsspec fs.rm(recursive=True)).
  • glob of model_dir is the source of truth every pass (no in-memory registry), so it stays correct under --continue_train.
  • Load/discovery methods (latest_checkpoint, best_checkpoint, restore, restore_dataloader_state) delegate to the existing free functions, which remain canonical. fine_tune_checkpoint discovery stays a free-function call since it resolves a dir outside model_dir.
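A minimal sketch of the coalescing pattern described in the bullets above, assuming a plain `queue.Queue`, a lock-guarded pending flag, and a stop sentinel. `CoalescingPruner` and its method names are hypothetical; the real `CheckpointManager` may be structured differently.

```python
import queue
import threading
from typing import Callable, Optional

_STOP = object()  # sentinel telling the worker to exit


class CoalescingPruner:
    """Runs prune passes on one daemon thread; extra requests collapse into one."""

    def __init__(self, prune_pass: Callable[[], None]) -> None:
        self._prune_pass = prune_pass
        self._queue = queue.Queue()
        self._lock = threading.Lock()
        self._pending = False  # True while a request sits in the queue
        self._worker: Optional[threading.Thread] = None

    def request_prune(self) -> None:
        """Called from the training thread; never touches the filesystem."""
        with self._lock:
            if self._pending:
                return  # a queued pass will already see the latest on-disk state
            self._pending = True
            if self._worker is None:
                self._worker = threading.Thread(target=self._run, daemon=True)
                self._worker.start()
        self._queue.put(True)

    def _run(self) -> None:
        while True:
            item = self._queue.get()
            if item is _STOP:
                return
            with self._lock:
                # Cleared before the pass starts, so a request that arrives
                # mid-pass queues exactly one follow-up pass.
                self._pending = False
            try:
                self._prune_pass()
            except Exception as exc:  # a failed pass must not kill the worker
                print(f"prune pass failed; the next request retries: {exc}")

    def close(self) -> None:
        """Drain queued work (FIFO), then stop the worker."""
        with self._lock:
            worker, self._worker = self._worker, None
        if worker is not None:
            self._queue.put(_STOP)
            worker.join()
```

Because the flag is cleared when a request is dequeued rather than when the pass finishes, the queue holds at most one request while one pass is in flight, which matches the bound stated above.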

A partially-written trailing line in the eval-result file (concurrently appended by eval) can make a prune pass raise; this is caught and logged, and the next pass self-heals — no crash, no spurious deletions.
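The fail-safe ordering can be illustrated as follows: everything that can raise is resolved before the first deletion, so a torn trailing line aborts the pass with nothing removed. This is a sketch under assumptions the PR does not state: the eval-result file is modeled as one JSON object per line, higher metric values are treated as better, and `best_step_from_eval_lines` and `run_prune_pass` are hypothetical names.

```python
import json
import logging
from typing import Callable, Iterable, List

logger = logging.getLogger(__name__)


def best_step_from_eval_lines(lines: Iterable[str], metric: str) -> int:
    """Hypothetical parser: each line is a JSON object with 'step' and the metric."""
    results = [json.loads(line) for line in lines if line.strip()]
    return max(results, key=lambda r: r[metric])["step"]


def run_prune_pass(
    steps: List[int],
    keep: int,
    eval_lines: Iterable[str],
    metric: str,
    delete: Callable[[int], None],
) -> bool:
    """Return True if the pass completed, False if it was aborted."""
    if keep <= 0:
        return True  # keep-all: nothing to do
    try:
        # Resolve the protected set BEFORE deleting anything.
        best = best_step_from_eval_lines(eval_lines, metric)
        protected = set(sorted(steps)[-keep:]) | {best}
    except Exception:
        logger.exception("prune pass aborted; nothing deleted")
        return False
    for step in sorted(steps):
        if step not in protected:
            delete(step)
    return True
```

Once the concurrent writer finishes the line, a later pass parses the file cleanly and performs the deletions the aborted pass skipped.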

Testing

checkpoint_util_test.py: parameterized prune cases (keep-all / recent-N / protect-best / keep≥count), non-rank-0 no-op, coalesced-idempotent prune, and discovery-delegation. All green; pre-commit run -a clean on changed files.

@tiankongdeguiji tiankongdeguiji added the claude-review Let Claude Review label May 26, 2026
@github-actions github-actions Bot removed the claude-review Let Claude Review label May 26, 2026
tiankongdeguiji and others added 2 commits May 26, 2026 17:05
Make shutil.rmtree fsspec-transparent like the other patched stdlib FS APIs:
on a remote (e.g. OSS) path it routes to fs.rm(recursive=True), otherwise it
delegates to the original rmtree. ignore_errors is honored on both branches.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
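For illustration, the routing this commit describes might look like the sketch below. It is not the patched `shutil.rmtree` from the repository: `rmtree_any`, `is_remote_path`, and the `remote_fs` injection parameter are hypothetical, and the "any scheme other than file:// is remote" rule is an assumption. The only fsspec calls used are `fsspec.core.url_to_fs` and `fs.rm(path, recursive=True)`.

```python
import shutil
from typing import Any, Optional


def is_remote_path(path: str) -> bool:
    """Assumption: any '<scheme>://' path other than file:// is remote."""
    return "://" in path and not path.startswith("file://")


def rmtree_any(
    path: str,
    ignore_errors: bool = False,
    remote_fs: Optional[Any] = None,
) -> None:
    """Recursively delete a local or remote directory tree."""
    if not is_remote_path(path):
        shutil.rmtree(path, ignore_errors=ignore_errors)
        return
    if remote_fs is None:
        # Imported lazily so local-only use needs no extra dependency.
        from fsspec.core import url_to_fs

        remote_fs, path = url_to_fs(path)
    try:
        remote_fs.rm(path, recursive=True)
    except Exception:
        # Mirror shutil.rmtree's ignore_errors contract on the remote branch.
        if not ignore_errors:
            raise
```

Passing `remote_fs` explicitly is only a convenience for testing with a stand-in object; when it is omitted, the filesystem is resolved from the URL.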
Add TrainConfig.keep_checkpoint_max (default 0 = keep all) so long runs can
cap the number of retained model.ckpt-<step> directories. Documented in
usage/train.md.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@tiankongdeguiji tiankongdeguiji force-pushed the feat/keep-checkpoint-max branch from 5b9b17a to 7099222 Compare May 26, 2026 09:06
Comment thread tzrec/main.py
check_all_workers_data_status=check_all_workers_data_status,
)
model.train()
ckpt_manager.close()


close() only runs if the training loop completes normally. If any step between the first save() (which starts the daemon worker) and here raises, close() is skipped: the final coalesced prune pass is abandoned and the worker thread is leaked (harmless across process exit since it's a daemon, but it breaks the documented "on-disk state settled before export reads model_dir" contract if anything downstream runs in the same process). Consider wrapping the loop body so close() runs in a finally.
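A minimal sketch of the suggested shape, assuming only that the manager exposes `close()`. `StubCheckpointManager` and `run_training` are hypothetical stand-ins, not code from `tzrec/main.py`.

```python
from typing import Callable


class StubCheckpointManager:
    """Hypothetical stand-in that only records whether close() ran."""

    def __init__(self) -> None:
        self.closed = False

    def close(self) -> None:
        self.closed = True


def run_training(train_loop: Callable[[], None], ckpt_manager) -> None:
    """Run the loop and always drain pending prunes, even if a step raises."""
    try:
        train_loop()
    finally:
        ckpt_manager.close()
```

The exception still propagates to the caller; the `finally` only guarantees that the drain happens first.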

protected = set(ckpt_metas[-self._keep_checkpoint_max :])
if (
self._export_config is not None
and self._export_config.exporter_type == "best"


All FS work here (glob, the eval-file read inside best_checkpoint, and fs.rm) runs on the prune worker thread while the main thread is concurrently doing the next save() — and both go through the same process-global cached fsspec filesystem (filesystem_util._CACHED_FSSPEC_FILESYSTEMS[protocol]). Concurrent calls into one fsspec backend instance from two threads aren't guaranteed safe for all backends (e.g. ossfs runs async ops on a single per-instance loop). The target paths are disjoint (worker only touches older, already-protected dirs), so this is about the shared client object, not the paths. Worth confirming the OSS backend is thread-safe under concurrent access, or giving the worker its own filesystem handle.

Comment thread tzrec/protos/train.proto
// dense gradient clipping config
optional GradClipping grad_clipping = 19;
// maximum number of recent checkpoints to keep; 0 keeps all.
optional uint32 keep_checkpoint_max = 20 [default = 0];


This comment reads as a hard cap, but when export_config.exporter_type == "best" the best checkpoint is retained in addition to the N recent ones, so the effective count can be N+1. The class docstring and train.md both note this; the proto comment is the only place that omits it.

Suggested change
- optional uint32 keep_checkpoint_max = 20 [default = 0];
+ // max number of recent checkpoints to keep; 0 keeps all. When
+ // export_config.exporter_type is "best", the best checkpoint is also retained.
+ optional uint32 keep_checkpoint_max = 20 [default = 0];

@github-actions


Review summary

Solid, well-designed change. The async pruning design is the right one: save() does the synchronous collective write and only enqueues a coalesced prune request, while a single rank-0 daemon worker does all the slow model_dir I/O (glob, eval-file read, recursive delete) off the training thread. The queue is provably bounded by the _prune_pending flag, glob is kept as the source of truth (no stale registry under --continue_train), and close() uses a FIFO stop-sentinel so disk state settles before export reads model_dir. The rank-0-only prune introduces no collective mismatch — prune()/close() contain no collective ops and run after the per-step collectives complete on every rank. The self-healing-on-corrupt-eval-line behavior is genuinely fail-safe: best_checkpoint() is called before the deletion loop, so any parse error aborts the whole pass with zero deletions and the next pass recovers. Docstrings are accurate and the test parameterization covers the retention policy well.

A few items worth considering (inline comments on the specifics):

  • close() not exception-safe (main.py:530) — skipped if training raises, leaking the worker / abandoning the final prune.
  • Shared fsspec filesystem across threads (checkpoint_util.py:407) — worth confirming the OSS backend is safe under concurrent worker+main access.
  • Proto comment (train.proto:74) omits the N+1 (recent-N + best) case noted everywhere else.

Test gaps (non-blocking, but the highest-value follow-ups):

  • save() itself is never exercised end-to-end — tests call prune() on hand-made dirs, so the save→prune wiring and the dataloader_state branch are unverified. A subprocess test mirroring the existing _save_restore_worker pattern would close this.
  • The _remove_checkpoint fsspec branch and the corrupt-eval-line self-heal (a PR-advertised behavior) have no coverage; both are cheap to add with mocks.

Minor, pre-existing (while you're here): best_checkpoint() line 229 uses metric.values()[0], which raises TypeError in Py3 (dict_values isn't subscriptable). Not introduced by this PR, but _run_prune now relies on best_checkpoint, so a best-exporter run without an explicit best_exporter_metric would silently abort every prune pass. Suggest list(metric.values())[0].
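The Python 3 behavior can be reproduced in isolation. The `metric` dict below is made-up sample data, and `first_metric_value` is a hypothetical helper showing the suggested fix.

```python
from typing import Dict


def first_metric_value(metric: Dict[str, float]) -> float:
    """Return the first value in insertion order, as the suggested fix does."""
    return list(metric.values())[0]


metric = {"auc": 0.91, "loss": 0.35}  # sample data for illustration only

try:
    metric.values()[0]  # dict_values does not support indexing in Python 3
    error_message = ""
except TypeError as exc:
    error_message = str(exc)

first = first_metric_value(metric)
```

`next(iter(metric.values()))` returns the same element without building an intermediate list.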

tiankongdeguiji and others added 2 commits May 26, 2026 17:26
CheckpointManager saves model.ckpt-<step> dirs and prunes old ones to keep
the most recent keep_checkpoint_max. Pruning never blocks the caller: prune()
only enqueues a coalesced request and a single daemon prune worker performs all
filesystem work (glob listing, best-checkpoint lookup, recursive delete via
shutil.rmtree), which matters for slow OSS-backed model_dirs. Only rank 0
deletes. When exporter_type == "best", the current best-by-metric checkpoint is
retained even if outside the recent window.

close() drains the queue at the end of training for a deterministic flush; a
weakref.finalize safety net (registered when the worker starts) runs the same
drain at interpreter exit if close() is skipped (e.g. training raised), so the
worker is never leaked and pending deletions still complete. The class also
exposes load/discovery methods that delegate to the existing free functions.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
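The `weakref.finalize` safety net described in this commit message can be sketched as below. The class and function names are hypothetical, the drain is reduced to setting a flag, and for brevity the finalizer is registered in the constructor rather than when the worker starts.

```python
import threading
import weakref


class _WorkerState:
    """Holds what the drain needs, so the finalizer never references the manager."""

    def __init__(self) -> None:
        self.lock = threading.Lock()
        self.drained = False


def _drain(state: _WorkerState) -> None:
    """Idempotent drain; a real one would enqueue a stop sentinel and join."""
    with state.lock:
        if state.drained:
            return
        state.drained = True


class ManagerSketch:
    def __init__(self) -> None:
        self._state = _WorkerState()
        # The callback and its arguments must not reference self; otherwise
        # the manager could never be garbage collected.
        self._finalizer = weakref.finalize(self, _drain, self._state)

    def close(self) -> None:
        # Calling a finalizer runs the callback at most once; later calls,
        # garbage collection, and interpreter exit then do nothing.
        self._finalizer()
```

Finalizers have `atexit` set to true by default, so if `close()` is never called the same drain still runs when the object is collected or at interpreter exit.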
Construct a CheckpointManager rooted at model_dir and use it for both save and
load across train/eval/export/predict. Training saves now go through
ckpt_manager.save() (which triggers async prune) and close() drains pending
deletions before return; restore and latest/best discovery delegate to the
manager. fine_tune_checkpoint discovery stays a free-function call since it
resolves an external dir outside model_dir.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@tiankongdeguiji tiankongdeguiji force-pushed the feat/keep-checkpoint-max branch from 7099222 to d2c90ff Compare May 26, 2026 09:26
tiankongdeguiji and others added 2 commits May 26, 2026 21:02
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@tiankongdeguiji tiankongdeguiji merged commit f8ac3b3 into alibaba:master May 27, 2026
6 checks passed

2 participants