Commit 7ba26d6

Merge pull request #50 from KempnerInstitute/checkpoint-save-load-race
Fix race in CheckpointManager save/load collectives
2 parents b5dde50 + c36863b commit 7ba26d6

2 files changed: 133 additions & 3 deletions

kempnerforge/checkpoint/manager.py

Lines changed: 25 additions & 3 deletions
```diff
@@ -130,6 +130,13 @@ def save(
         # Cleanup old checkpoints
         self._cleanup()
 
+        # save() is a collective: non-rank-0 ranks must not return until
+        # rank-0 has committed train_state.pt, metadata.json, and the
+        # latest symlink. Without this barrier, post-save hooks or readers
+        # on other ranks race rank-0's writes (especially on NFS/Lustre).
+        if dist.is_initialized():
+            dist.barrier()
+
     def wait(self) -> None:
         """Block until any pending async checkpoint save completes."""
         self._async_ckpt.wait()
```
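
To make the failure mode concrete, here is a minimal sketch (hypothetical code, not from this repo) of the kind of consumer the barrier protects: a post-save hook that runs on every rank and reads the metadata rank 0 just wrote.

```python
# Hypothetical post-save hook illustrating the race the barrier closes.
# Before the fix, non-rank-0 ranks could reach this read while rank 0 was
# still writing metadata.json, yielding FileNotFoundError or a partial
# read, especially under NFS/Lustre attribute caching.
import json
from pathlib import Path


def on_save_complete(step_dir: Path) -> dict:
    # Safe only if save() guarantees rank-0's writes are committed before
    # any rank returns, which is exactly what the added dist.barrier()
    # in save() provides.
    with open(step_dir / "metadata.json") as f:
        return json.load(f)
```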
```diff
@@ -178,10 +185,25 @@ def load(
         if "optimizer" in dcp_state:
             self.optimizer.load_state_dict(dcp_state["optimizer"])
 
-        # Load non-distributed state
+        # Load non-distributed state. On NFS/Lustre, independent stat()
+        # calls can disagree briefly across ranks; if some ranks enter
+        # this branch and others don't, the broadcast_object_list below
+        # hangs. Use a rank-0-authoritative existence check broadcast to
+        # all ranks so every rank takes the same branch.
         train_state_path = ckpt_dir / _TRAIN_STATE_FILE
-        if train_state_path.exists():
-            train_state = torch.load(train_state_path, map_location="cpu", weights_only=False)
+        if dist.is_initialized():
+            exists_flag = [train_state_path.exists() if self._rank == 0 else False]
+            dist.broadcast_object_list(exists_flag, src=0)
+            train_state_exists = bool(exists_flag[0])
+        else:
+            train_state_exists = train_state_path.exists()
+
+        if train_state_exists:
+            train_state = (
+                torch.load(train_state_path, map_location="cpu", weights_only=False)
+                if self._rank == 0 or not dist.is_initialized()
+                else None
+            )
 
         # Broadcast from rank 0 to all ranks
         if dist.is_initialized():
```
tests/distributed/test_checkpoint.py

Lines changed: 108 additions & 0 deletions
```diff
@@ -7,12 +7,15 @@
 
 import os
 import shutil
+import time
 from pathlib import Path
+from unittest.mock import patch
 
 import pytest
 import torch
 import torch.distributed as dist
 
+from kempnerforge.checkpoint import manager as mgr_mod
 from kempnerforge.checkpoint.manager import CheckpointManager
 from kempnerforge.config.schema import CheckpointConfig, ModelConfig
 from kempnerforge.distributed.parallel import apply_fsdp2
@@ -122,3 +125,108 @@ def test_latest_symlink(self, distributed_env, shared_tmp_dir):
         latest = Path(ckpt_dir) / "latest"
         assert latest.exists()
         assert latest.resolve().name == "step_20"
+
+
+class TestCheckpointSaveBarrier:
+    """save() must synchronize all ranks on the rank-0 metadata writes."""
+
+    def test_save_waits_for_rank0_writes(self, distributed_env, shared_tmp_dir):
+        """Non-rank-0 must not return from save() before rank 0 finishes.
+
+        Without the end-of-save barrier, non-rank-0 returns immediately
+        after the async DCP dispatch, while rank 0 is still writing
+        train_state.pt and the latest symlink. We force the race to
+        be measurable by slowing rank-0's torch.save.
+        """
+        mesh = distributed_env
+        ckpt_dir = shared_tmp_dir
+        rank = dist.get_rank()
+
+        model = Transformer(SMALL_CONFIG).cuda()
+        apply_fsdp2(model, mesh)
+        from kempnerforge.config.schema import OptimizerConfig
+
+        opt = build_optimizer(model, OptimizerConfig(lr=1e-3, fused=False))
+        cfg = CheckpointConfig(dir=ckpt_dir, keep_last_n=2)
+        mgr = CheckpointManager(cfg, model, opt)
+
+        real_torch_save = torch.save
+        sleep_sec = 0.5
+
+        def slow_on_rank0(*args, **kwargs):
+            if rank == 0:
+                time.sleep(sleep_sec)
+            return real_torch_save(*args, **kwargs)
+
+        # Barrier before timing so all ranks start save() at roughly the
+        # same instant; isolates the signal we care about.
+        dist.barrier()
+        t0 = time.perf_counter()
+        with patch.object(mgr_mod.torch, "save", side_effect=slow_on_rank0):
+            mgr.save(step=1, tokens_seen=100)
+        elapsed = time.perf_counter() - t0
+
+        # Every rank must have waited for rank-0's slow write.
+        assert elapsed >= 0.4 * sleep_sec, (
+            f"rank {rank}: save() returned after {elapsed:.3f}s — the "
+            f"end-of-save barrier is missing. Expected >= {0.4 * sleep_sec:.3f}s."
+        )
+
+        # And every rank observes rank-0's writes afterwards.
+        step_dir = Path(ckpt_dir) / "step_1"
+        assert (step_dir / "train_state.pt").exists(), (
+            f"rank {rank}: train_state.pt not visible after save()"
+        )
+        assert (step_dir / "metadata.json").exists(), (
+            f"rank {rank}: metadata.json not visible after save()"
+        )
+
+
```
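A note on the patching above: `real_torch_save` is captured before `patch.object(mgr_mod.torch, "save", ...)` takes effect. Since `mgr_mod.torch` is the same `torch` module object the test imported, the patch is global for its duration, so the wrapper must call the captured original rather than `torch.save`, or it would recurse into the mock.

The second test class continues the same diff:
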
```diff
+class TestCheckpointLoadDivergentExistence:
+    """load() must not hang if ranks disagree about train_state.pt existence.
+
+    Simulates attribute-cache skew (NFS/Lustre) by patching Path.exists
+    so non-rank-0 sees a missing file. Rank-0's answer must be authoritative
+    via broadcast; otherwise only some ranks enter the torch.load branch
+    and the subsequent broadcast_object_list deadlocks.
+    """
+
+    def test_load_does_not_hang_on_divergent_exists(self, distributed_env, shared_tmp_dir):
+        mesh = distributed_env
+        ckpt_dir = shared_tmp_dir
+        rank = dist.get_rank()
+
+        model = Transformer(SMALL_CONFIG).cuda()
+        apply_fsdp2(model, mesh)
+        from kempnerforge.config.schema import OptimizerConfig
+
+        opt = build_optimizer(model, OptimizerConfig(lr=1e-3, fused=False))
+        cfg = CheckpointConfig(dir=ckpt_dir, keep_last_n=2)
+        mgr = CheckpointManager(cfg, model, opt)
+
+        # Save so there's something to load.
+        mgr.save(step=1, tokens_seen=100)
+        dist.barrier()
+
+        # Patch exists() so non-rank-0 sees the file as missing. Without
+        # the authoritative broadcast, non-rank-0 skips torch.load but
+        # rank 0 enters and calls broadcast_object_list — deadlock.
+        real_exists = Path.exists
+
+        def skewed_exists(self):
+            if rank != 0 and self.name == "train_state.pt":
+                return False
+            return real_exists(self)
+
+        # An explicit watchdog timeout would be ideal; simpler is to rely
+        # on the process-group default timeout, which surfaces a deadlock
+        # as a RuntimeError rather than blocking the test runner forever.
+        # With the fix, load() completes promptly; without it, the test
+        # hangs until the PG timeout fires.
+        with patch.object(Path, "exists", skewed_exists):
+            step, tokens_seen, _ = mgr.load()
+
+        assert step == 1, f"rank {rank}: expected step=1 after load, got {step}"
+        assert tokens_seen == 100, (
+            f"rank {rank}: expected tokens_seen=100 after load, got {tokens_seen}"
+        )
```
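
Both tests need a multi-rank process group plus GPUs (the models are moved to CUDA and sharded with FSDP2). The exact launcher depends on how the repo's `distributed_env` fixture initializes the process group; an invocation along the lines of `torchrun --nproc_per_node=2 -m pytest tests/distributed/test_checkpoint.py` is the usual shape, but treat that command as an assumption and defer to the repo's CI configuration.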
