7 | 7 |
8 | 8 | import os |
9 | 9 | import shutil |
| 10 | +import time |
10 | 11 | from pathlib import Path |
| 12 | +from unittest.mock import patch |
11 | 13 |
12 | 14 | import pytest |
13 | 15 | import torch |
14 | 16 | import torch.distributed as dist |
15 | 17 |
| 18 | +from kempnerforge.checkpoint import manager as mgr_mod |
16 | 19 | from kempnerforge.checkpoint.manager import CheckpointManager |
17 | 20 | from kempnerforge.config.schema import CheckpointConfig, ModelConfig |
18 | 21 | from kempnerforge.distributed.parallel import apply_fsdp2 |
@@ -122,3 +125,108 @@ def test_latest_symlink(self, distributed_env, shared_tmp_dir): |
122 | 125 | latest = Path(ckpt_dir) / "latest" |
123 | 126 | assert latest.exists() |
124 | 127 | assert latest.resolve().name == "step_20" |
| 128 | + |
| 129 | + |
| 130 | +class TestCheckpointSaveBarrier: |
| 131 | + """save() must synchronize all ranks on the rank-0 metadata writes.""" |
| 132 | + |
| 133 | + def test_save_waits_for_rank0_writes(self, distributed_env, shared_tmp_dir): |
| 134 | + """Non-rank-0 must not return from save() before rank 0 finishes. |
| 135 | +
| 136 | + Without the end-of-save barrier, non-rank-0 returns immediately |
| 137 | + after the async DCP dispatch, while rank 0 is still writing |
| 138 | + train_state.pt and the latest symlink. We force the race to |
| 139 | + be measurable by slowing rank-0's torch.save. |
| 140 | + """ |
| 141 | + mesh = distributed_env |
| 142 | + ckpt_dir = shared_tmp_dir |
| 143 | + rank = dist.get_rank() |
| 144 | + |
| 145 | + model = Transformer(SMALL_CONFIG).cuda() |
| 146 | + apply_fsdp2(model, mesh) |
| 147 | + from kempnerforge.config.schema import OptimizerConfig |
| 148 | + |
| 149 | + opt = build_optimizer(model, OptimizerConfig(lr=1e-3, fused=False)) |
| 150 | + cfg = CheckpointConfig(dir=ckpt_dir, keep_last_n=2) |
| 151 | + mgr = CheckpointManager(cfg, model, opt) |
| 152 | + |
| 153 | + real_torch_save = torch.save |
| 154 | + sleep_sec = 0.5 |
| 155 | + |
| 156 | + def slow_on_rank0(*args, **kwargs): |
| 157 | + if rank == 0: |
| 158 | + time.sleep(sleep_sec) |
| 159 | + return real_torch_save(*args, **kwargs) |
| 160 | + |
| 161 | +        # Synchronize before timing so all ranks start save() at roughly |
| 162 | +        # the same instant; this isolates the delay we want to measure. |
| 163 | + dist.barrier() |
| 164 | + t0 = time.perf_counter() |
| 165 | + with patch.object(mgr_mod.torch, "save", side_effect=slow_on_rank0): |
| 166 | + mgr.save(step=1, tokens_seen=100) |
| 167 | + elapsed = time.perf_counter() - t0 |
| 168 | + |
| 169 | + # Every rank must have waited for rank-0's slow write. |
| 170 | + assert elapsed >= 0.4 * sleep_sec, ( |
| 171 | + f"rank {rank}: save() returned after {elapsed:.3f}s — the " |
| 172 | + f"end-of-save barrier is missing. Expected >= {0.4 * sleep_sec:.3f}s." |
| 173 | + ) |
| 174 | + |
| 175 | + # And every rank observes rank-0's writes afterwards. |
| 176 | + step_dir = Path(ckpt_dir) / "step_1" |
| 177 | + assert (step_dir / "train_state.pt").exists(), ( |
| 178 | + f"rank {rank}: train_state.pt not visible after save()" |
| 179 | + ) |
| 180 | + assert (step_dir / "metadata.json").exists(), ( |
| 181 | + f"rank {rank}: metadata.json not visible after save()" |
| 182 | + ) |
| 183 | + |
| 184 | + |
| 185 | +class TestCheckpointLoadDivergentExistence: |
| 186 | + """load() must not hang if ranks disagree about train_state.pt existence. |
| 187 | +
| 188 | + Simulates attribute-cache skew (NFS/Lustre) by patching Path.exists |
| 189 | + so non-rank-0 sees a missing file. Rank-0's answer must be authoritative |
| 190 | + via broadcast; otherwise only some ranks enter the torch.load branch |
| 191 | + and the subsequent broadcast_object_list deadlocks. |
| 192 | + """ |
| 193 | + |
| 194 | + def test_load_does_not_hang_on_divergent_exists(self, distributed_env, shared_tmp_dir): |
| 195 | + mesh = distributed_env |
| 196 | + ckpt_dir = shared_tmp_dir |
| 197 | + rank = dist.get_rank() |
| 198 | + |
| 199 | + model = Transformer(SMALL_CONFIG).cuda() |
| 200 | + apply_fsdp2(model, mesh) |
| 201 | + from kempnerforge.config.schema import OptimizerConfig |
| 202 | + |
| 203 | + opt = build_optimizer(model, OptimizerConfig(lr=1e-3, fused=False)) |
| 204 | + cfg = CheckpointConfig(dir=ckpt_dir, keep_last_n=2) |
| 205 | + mgr = CheckpointManager(cfg, model, opt) |
| 206 | + |
| 207 | + # Save so there's something to load. |
| 208 | + mgr.save(step=1, tokens_seen=100) |
| 209 | + dist.barrier() |
| 210 | + |
| 211 | +        # Patch exists() so non-rank-0 sees the file as missing. Without |
| 212 | +        # the authoritative broadcast, non-rank-0 skips torch.load while |
| 213 | +        # rank 0 enters it and calls broadcast_object_list, which deadlocks. |
| 214 | + real_exists = Path.exists |
| 215 | + |
| 216 | + def skewed_exists(self): |
| 217 | + if rank != 0 and self.name == "train_state.pt": |
| 218 | + return False |
| 219 | + return real_exists(self) |
| 220 | + |
| 221 | +        # Wrapping this call in an explicit timeout would be ideal; for |
| 222 | +        # simplicity we rely on the process-group default timeout to surface |
| 223 | +        # a deadlock as a RuntimeError rather than blocking the test runner |
| 224 | +        # forever. With the fix, load() completes promptly; without it, the |
| 225 | +        # test hangs until the PG timeout fires. |
| 226 | + with patch.object(Path, "exists", skewed_exists): |
| 227 | + step, tokens_seen, _ = mgr.load() |
| 228 | + |
| 229 | + assert step == 1, f"rank {rank}: expected step=1 after load, got {step}" |
| 230 | + assert tokens_seen == 100, ( |
| 231 | + f"rank {rank}: expected tokens_seen=100 after load, got {tokens_seen}" |
| 232 | + ) |
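
For context, CheckpointManager itself is not part of this diff. The sketch below is a minimal, hypothetical illustration of the save/load shape the two new tests assume: an end-of-save `dist.barrier()` after the rank-0 metadata writes, and a rank-0-authoritative `Path.exists()` result broadcast before `torch.load`. The helper names, the use of `dcp.async_save`, and the metadata layout are assumptions for illustration only, not kempnerforge's actual API.

```python
# Minimal sketch only; not the real CheckpointManager. It shows the two
# synchronization points the tests above exercise.
from pathlib import Path

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp


def save_checkpoint(state_dict: dict, ckpt_dir: Path, step: int, tokens_seen: int) -> None:
    step_dir = ckpt_dir / f"step_{step}"
    # Every rank dispatches its shards; async_save returns a future before the
    # writes finish (the real manager presumably tracks that future; elided).
    dcp.async_save(state_dict, checkpoint_id=str(step_dir))
    if dist.get_rank() == 0:
        # Rank-0-only metadata: this is the window TestCheckpointSaveBarrier
        # widens by patching torch.save to sleep.
        step_dir.mkdir(parents=True, exist_ok=True)
        torch.save({"step": step, "tokens_seen": tokens_seen}, step_dir / "train_state.pt")
        # (metadata.json and checkpoint pruning would be handled here as well.)
        latest = ckpt_dir / "latest"
        latest.unlink(missing_ok=True)
        latest.symlink_to(step_dir.name)
    # End-of-save barrier: without it, non-rank-0 returns while rank 0 is still
    # writing and can race ahead to read files that do not exist yet.
    dist.barrier()


def load_train_state(step_dir: Path) -> dict | None:
    path = step_dir / "train_state.pt"
    # Rank 0's filesystem view is authoritative. Broadcasting its exists() result
    # keeps every rank on the same branch even under NFS/Lustre attribute-cache
    # skew, which is what TestCheckpointLoadDivergentExistence simulates.
    exists = [path.exists()] if dist.get_rank() == 0 else [None]
    dist.broadcast_object_list(exists, src=0)
    if not exists[0]:
        return None
    # Only rank 0 reads the file; the payload is broadcast so all ranks agree.
    payload = [torch.load(path, map_location="cpu")] if dist.get_rank() == 0 else [None]
    dist.broadcast_object_list(payload, src=0)
    return payload[0]
```

A manager shaped like this satisfies both tests; dropping either the final barrier or the existence broadcast reproduces the race and the deadlock they are written to catch.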