Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `docs/checkpointing/dcp-model.md`: updated the save/load snippets and the "shape to fill" explanation to the DCP-aware helpers.
- Tests (fail on the pre-fix code, pass after): `tests/integration/test_checkpoint_roundtrip.py::test_manager_restores_optimizer_moments_single_gpu` and `tests/distributed/test_checkpoint.py::test_resume_restores_optimizer_moments` assert `exp_avg` / `exp_avg_sq` are restored bit-exactly into a *fresh* optimizer (single-GPU + distributed); `tests/e2e/test_training_e2e.py::test_resume_determinism_single_gpu` / `test_resume_determinism_2gpu_fsdp` assert end-to-end bit-exact loss across an interrupt-and-resume on a learnable dataset.
- **On-disk format note:** optimizer state is now keyed by parameter fully-qualified name rather than positional index. Checkpoints written before this fix will not restore optimizer state on resume (training continues with a fresh optimizer); model state is unaffected.
- **Always save a final checkpoint on clean completion.** A run that finished cleanly previously wrote a final checkpoint only when `max_steps` happened to land on the save schedule; otherwise the fully-trained model (including the entire WSD decay tail) was never persisted and `latest` resolved to the last scheduled step. The training-loop teardown now writes one unconditional checkpoint at `max_steps` on normal completion when that step is off-schedule, matching the emergency-checkpoint guarantee on the preemption path. Purely additive: the schedule/retention knobs (`interval`, `keep_last_n`, `dyn_ckpt_window`) are untouched.
- `scripts/train.py`: a `completed_normally` marker set on the final iteration, plus an off-schedule final `ckpt_mgr.save(...)` (and `on_checkpoint_save` hook) before `ckpt_mgr.wait()`.
- Tests: `tests/e2e/test_training_e2e.py` — a clean off-schedule completion writes `step_{max_steps}` and points `latest` at it.

## [0.1.0] — 2026-04-16

Expand Down
10 changes: 10 additions & 0 deletions docs/training/training-loop.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,12 +185,22 @@ After the loop:

```python
prof.stop()
# Clean off-schedule finish: persist the fully-trained final step
if completed_normally and not config.checkpoint.should_save(step):
ckpt_mgr.save(step, ...) # final checkpoint; `latest` committed after wait
ckpt_mgr.wait() # drain last async save
Comment thread
Copilot marked this conversation as resolved.
hook_runner.on_train_end(step, tokens_seen)
tracker.close()
destroy_distributed()
```

On a clean finish, an unconditional checkpoint is written at `max_steps`
when that step is not already on the save schedule — so a completed run's
fully-trained model (including the WSD decay tail) is always recoverable
and `latest` points at it. This mirrors the emergency checkpoint the
preemption path writes on shutdown. The `should_save` guard avoids a
duplicate when `max_steps` already coincided with the schedule.

`ckpt_mgr.wait()` is load-bearing — without it, a rank can exit before
its async DCP write completes, corrupting the checkpoint for
everyone else on the same save. See
Expand Down
20 changes: 20 additions & 0 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,8 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
)
hook_runner.on_checkpoint_save(0, config.checkpoint.dir)

completed_normally = False

while step < tc.max_steps:
# Refresh data iterator at start / epoch boundary
if dataloader is not None and data_iter is None:
Expand Down Expand Up @@ -1039,11 +1041,29 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
shutdown_handler.finish()
break

# Clean-completion marker for the unconditional final save after the
# loop. Only reached when training completes without any errors, e.g.,
# no NaN/NCCL/shutdown breaks. If a run encounters a NaN, the last step
# is intentionally *not* saved because the actual model state would be
# `max_steps - 1`, not `max_steps`.
if step >= tc.max_steps:
completed_normally = True
Comment thread
camilobrownpinilla marked this conversation as resolved.

if prof is not None:
prof.stop()
if rank == 0:
print_profiler_summary(prof, trace_dir=config.profiling.trace_dir)

if completed_normally and not config.checkpoint.should_save(step):
ckpt_mgr.save(
step=step,
tokens_seen=tokens_seen,
scheduler=scheduler,
dataloader=dataloader,
extra=ckpt_extra,
)
hook_runner.on_checkpoint_save(step, config.checkpoint.dir)

# Flush any pending async checkpoint before tearing down process group
ckpt_mgr.wait()

Expand Down
31 changes: 31 additions & 0 deletions tests/e2e/test_training_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,37 @@ def test_sigterm_triggers_emergency_checkpoint(tmp_path):
assert "Checkpoint saved" in output, f"Emergency checkpoint was not saved:\n{output[-2000:]}"


@pytest.mark.e2e
def test_final_checkpoint_on_clean_completion(tmp_path):
"""A clean finish at an off-schedule step must still persist a final
checkpoint, with `latest` pointing at it."""
ckpt_dir = tmp_path / "final_ckpt"

# max_steps=10 with interval=1000 => no scheduled save lands (debug.toml
# configures no dyn_ckpt_window), so the only checkpoint must be the
# unconditional final save at step 10.
result = _run_training(
[
DEBUG_CONFIG,
"--train.max_steps=10",
"--metrics.log_interval=5",
f"--checkpoint.dir={ckpt_dir}",
"--checkpoint.interval=1000", # no scheduled checkpoint
],
nproc=1,
)
_assert_training_complete(result, expected_steps=10)

output = result.stdout + result.stderr
final = ckpt_dir / "step_10"
latest = ckpt_dir / "latest"
assert final.is_dir(), f"final checkpoint step_10 was not written:\n{output[-2000:]}"
assert latest.exists(), f"`latest` pointer is missing:\n{output[-2000:]}"
assert latest.resolve().name == "step_10", (
f"`latest` should resolve to step_10, got {latest.resolve().name!r}"
)


# ============================================================================
# MoE Training
# ============================================================================
Expand Down
Loading