Skip to content

Commit 80bba4b

Browse files
authored
Merge pull request #64 from KempnerInstitute/collective-timeouts
Enforce per-op timeouts on init and coordination collectives
2 parents 11d16e2 + 060345f commit 80bba4b

5 files changed

Lines changed: 348 additions & 9 deletions

File tree

kempnerforge/checkpoint/manager.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,11 @@ def load(
209209
else None
210210
)
211211

212-
# Broadcast from rank 0 to all ranks
212+
# Broadcast from rank 0 to all ranks. PyTorch 2.11's
213+
# broadcast_object_list does not accept async_op, so a per-op
214+
# timeout cannot be wired here — this call inherits the 1800s
215+
# process-group default. A wedged rank will still surface, just
216+
# later than the other fast-fail paths in this patch.
213217
if dist.is_initialized():
214218
object_list = [train_state if self._rank == 0 else None]
215219
dist.broadcast_object_list(object_list, src=0)

kempnerforge/distributed/setup.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,42 @@ def _set_nccl_env() -> None:
9090
os.environ.setdefault("NCCL_IB_DISABLE", "0")
9191
os.environ.setdefault("NCCL_NET_GDR_LEVEL", "2")
9292

93+
# Ensure NCCL actually enforces the process-group timeout. The default in
94+
# PyTorch 2.2+ is "1", but a user shell/SLURM prolog may override it to
95+
# "0", at which point the PG timeout becomes advisory and stuck collectives
96+
# can hang indefinitely. Set a safe default and warn loudly if the user
97+
# has explicitly disabled it.
98+
existing = os.environ.get("TORCH_NCCL_ASYNC_ERROR_HANDLING")
99+
if existing == "0":
100+
logger.warning(
101+
"TORCH_NCCL_ASYNC_ERROR_HANDLING=0 detected — NCCL timeouts "
102+
"are advisory; stuck collectives can hang indefinitely."
103+
)
104+
else:
105+
os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1")
106+
107+
108+
def _barrier_with_timeout(seconds: float, reason: str) -> None:
109+
"""dist.barrier with an explicit per-op timeout and a diagnostic log.
110+
111+
The process-group default timeout (``config.nccl_timeout_sec``) is sized
112+
for training collectives (minutes of reduce on large tensors). Init-path
113+
barriers should fail fast so mesh or env misconfiguration does not block
114+
a job for 30 minutes before surfacing a useful error.
115+
"""
116+
work = dist.barrier(async_op=True)
117+
try:
118+
done = work.wait(timeout=timedelta(seconds=seconds)) # type: ignore[reportOptionalMemberAccess]
119+
except RuntimeError as e:
120+
raise RuntimeError(
121+
f"Barrier timed out after {seconds}s during {reason}. "
122+
f"Common causes: MASTER_ADDR/MASTER_PORT disagreement across ranks, "
123+
f"a rank missing from the job, or the IB interface unreachable. "
124+
f"Underlying: {e}"
125+
) from e
126+
if done is False:
127+
raise RuntimeError(f"Barrier timed out after {seconds}s during {reason}.")
128+
93129

94130
def _set_seed(seed: int, rank: int, pp_rank: int = 0) -> None:
95131
"""Set deterministic seeds for reproducibility.
@@ -223,8 +259,10 @@ def init_distributed(config: DistributedConfig, seed: int = 42) -> DeviceMesh |
223259
mesh_dim_names=tuple(mesh_dims),
224260
)
225261

226-
# Ensure all ranks have finished mesh creation before proceeding
227-
dist.barrier()
262+
# Ensure all ranks have finished mesh creation before proceeding.
263+
# A 60s bound fails fast on mesh misconfiguration rather than inheriting
264+
# the 1800s PG timeout.
265+
_barrier_with_timeout(60.0, reason="DeviceMesh construction")
228266

229267
# Set seed (vary by PP rank for different dropout/stochastic depth per stage)
230268
pp_rank = 0

kempnerforge/resilience/health.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import logging
1212
from dataclasses import dataclass, field
13+
from datetime import timedelta
1314

1415
import torch
1516
import torch.distributed as dist
@@ -211,19 +212,36 @@ def check_gpu_health(device: int = 0) -> dict[str, bool | str]:
211212
def check_nccl_health(timeout_sec: float = 10.0) -> bool:
212213
"""Check NCCL communication health via a lightweight all-reduce.
213214
215+
The all-reduce runs with ``async_op=True`` so ``work.wait(timeout=...)``
216+
enforces the caller's bound rather than falling back to the
217+
process-group default timeout (``nccl_timeout_sec``, 1800s). Without
218+
that, this function would sit for 30 minutes on a single stuck peer
219+
regardless of the ``timeout_sec`` argument.
220+
214221
Args:
215-
timeout_sec: Timeout for the collective operation.
222+
timeout_sec: Per-operation timeout for the collective. Returns
223+
False if the all-reduce does not complete within this budget.
216224
217225
Returns:
218-
True if the all-reduce succeeded, False on timeout or error.
226+
True on success, False on timeout, error, or world-size mismatch.
219227
"""
220228
if not dist.is_initialized():
221229
return True # No distributed, nothing to check
222230

223231
try:
224232
tensor = torch.ones(1, device="cuda")
225-
# Use a work handle with timeout
226-
dist.all_reduce(tensor)
233+
work = dist.all_reduce(tensor, async_op=True)
234+
# In current PyTorch Work.wait(timeout=...) raises RuntimeError on
235+
# timeout; older/alternate backends may return False instead. Handle
236+
# both so the timeout is honored regardless of version.
237+
try:
238+
done = work.wait(timeout=timedelta(seconds=timeout_sec)) # type: ignore[reportOptionalMemberAccess]
239+
except RuntimeError as e:
240+
logger.warning(f"NCCL health check timed out after {timeout_sec}s: {e}")
241+
return False
242+
if done is False:
243+
logger.warning(f"NCCL health check timed out after {timeout_sec}s")
244+
return False
227245
torch.cuda.synchronize()
228246
expected = dist.get_world_size()
229247
return abs(tensor.item() - expected) < 1e-5

kempnerforge/training/eval.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,20 @@
77

88
from __future__ import annotations
99

10+
import logging
1011
import math
12+
from datetime import timedelta
1113

1214
import torch
1315
import torch.distributed as dist
1416

17+
logger = logging.getLogger(__name__)
18+
19+
# Per-operation timeout for the PP eval loss broadcast. Shorter than the
20+
# 1800s process-group default so a diverged PP stage surfaces fast rather
21+
# than freezing eval for half an hour.
22+
_EVAL_BROADCAST_TIMEOUT_SEC = 300.0
23+
1524

1625
@torch.no_grad()
1726
def run_eval(
@@ -77,8 +86,29 @@ def run_eval(
7786
avg_loss = 0.0
7887

7988
loss_tensor = torch.tensor([avg_loss], device=device)
80-
dist.broadcast(loss_tensor, group_src=pp_size - 1, group=pp_group) # type: ignore[reportOptionalOperand]
81-
avg_loss = loss_tensor[0].item()
89+
work = dist.broadcast(
90+
loss_tensor,
91+
group_src=pp_size - 1, # type: ignore[reportOptionalOperand]
92+
group=pp_group,
93+
async_op=True,
94+
)
95+
try:
96+
done = work.wait(timeout=timedelta(seconds=_EVAL_BROADCAST_TIMEOUT_SEC)) # type: ignore[reportOptionalMemberAccess]
97+
except RuntimeError as e:
98+
logger.error(
99+
f"Eval loss broadcast timed out after {_EVAL_BROADCAST_TIMEOUT_SEC}s; "
100+
f"a PP stage is likely wedged. Reporting nan loss. Underlying: {e}"
101+
)
102+
avg_loss = float("nan")
103+
else:
104+
if done is False:
105+
logger.error(
106+
f"Eval loss broadcast timed out after {_EVAL_BROADCAST_TIMEOUT_SEC}s; "
107+
"reporting nan loss."
108+
)
109+
avg_loss = float("nan")
110+
else:
111+
avg_loss = loss_tensor[0].item()
82112
else:
83113
# --- Standard eval path ---
84114
total_loss = 0.0

0 commit comments

Comments
 (0)