Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion kempnerforge/checkpoint/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,11 @@ def load(
if train_state_path.exists():
train_state = torch.load(train_state_path, map_location="cpu", weights_only=False)

# Broadcast from rank 0 to all ranks
# Broadcast from rank 0 to all ranks. PyTorch 2.11's
# broadcast_object_list does not accept async_op, so a per-op
# timeout cannot be wired here — this call inherits the 1800s
# process-group default. A wedged rank will still surface, just
# later than the other fast-fail paths in this patch.
if dist.is_initialized():
object_list = [train_state if self._rank == 0 else None]
dist.broadcast_object_list(object_list, src=0)
Expand Down
42 changes: 40 additions & 2 deletions kempnerforge/distributed/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,42 @@ def _set_nccl_env() -> None:
os.environ.setdefault("NCCL_IB_DISABLE", "0")
os.environ.setdefault("NCCL_NET_GDR_LEVEL", "2")

# Ensure NCCL actually enforces the process-group timeout. The default in
# PyTorch 2.2+ is "1", but a user shell/SLURM prolog may override it to
# "0", at which point the PG timeout becomes advisory and stuck collectives
# can hang indefinitely. Set a safe default and warn loudly if the user
# has explicitly disabled it.
existing = os.environ.get("TORCH_NCCL_ASYNC_ERROR_HANDLING")
if existing == "0":
logger.warning(
"TORCH_NCCL_ASYNC_ERROR_HANDLING=0 detected — NCCL timeouts "
"are advisory; stuck collectives can hang indefinitely."
)
else:
os.environ.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1")


def _barrier_with_timeout(seconds: float, reason: str) -> None:
    """Run ``dist.barrier`` under an explicit, caller-chosen time bound.

    The barrier is launched with ``async_op=True`` and the returned work
    handle is waited on for at most ``seconds``, rather than relying on the
    process-group default timeout (``config.nccl_timeout_sec``), which is
    sized for long-running training collectives. Init-path barriers should
    fail fast so that a mesh or environment misconfiguration surfaces as a
    useful error instead of stalling the job for the full default timeout.

    Args:
        seconds: Upper bound, in seconds, handed to ``wait(timeout=...)``.
        reason: Short label for the call site; embedded in the error text.

    Raises:
        RuntimeError: If the wait raises ``RuntimeError`` (re-raised with
            troubleshooting hints and the original error chained as the
            cause), or if the wait returns ``False``.
    """
    handle = dist.barrier(async_op=True)
    budget = timedelta(seconds=seconds)
    try:
        finished = handle.wait(timeout=budget)  # type: ignore[reportOptionalMemberAccess]
    except RuntimeError as e:
        # Wrap the backend error so the log points at likely root causes.
        message = (
            f"Barrier timed out after {seconds}s during {reason}. "
            f"Common causes: MASTER_ADDR/MASTER_PORT disagreement across ranks, "
            f"a rank missing from the job, or the IB interface unreachable. "
            f"Underlying: {e}"
        )
        raise RuntimeError(message) from e
    else:
        # Some backends signal a timeout by returning False instead of raising.
        if finished is False:
            raise RuntimeError(f"Barrier timed out after {seconds}s during {reason}.")


def _set_seed(seed: int, rank: int, pp_rank: int = 0) -> None:
"""Set deterministic seeds for reproducibility.
Expand Down Expand Up @@ -209,8 +245,10 @@ def init_distributed(config: DistributedConfig, seed: int = 42) -> DeviceMesh |
mesh_dim_names=tuple(mesh_dims),
)

# Ensure all ranks have finished mesh creation before proceeding
dist.barrier()
# Ensure all ranks have finished mesh creation before proceeding.
# A 60s bound fails fast on mesh misconfiguration rather than inheriting
# the 1800s PG timeout.
_barrier_with_timeout(60.0, reason="DeviceMesh construction")

# Set seed (vary by PP rank for different dropout/stochastic depth per stage)
pp_rank = 0
Expand Down
26 changes: 22 additions & 4 deletions kempnerforge/resilience/health.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import logging
from dataclasses import dataclass, field
from datetime import timedelta

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -211,19 +212,36 @@ def check_gpu_health(device: int = 0) -> dict[str, bool | str]:
def check_nccl_health(timeout_sec: float = 10.0) -> bool:
"""Check NCCL communication health via a lightweight all-reduce.

The all-reduce runs with ``async_op=True`` so ``work.wait(timeout=...)``
enforces the caller's bound rather than falling back to the
process-group default timeout (``nccl_timeout_sec``, 1800s). Without
that, this function would sit for 30 minutes on a single stuck peer
regardless of the ``timeout_sec`` argument.

Args:
timeout_sec: Timeout for the collective operation.
timeout_sec: Per-operation timeout for the collective. Returns
False if the all-reduce does not complete within this budget.

Returns:
True if the all-reduce succeeded, False on timeout or error.
True on success, False on timeout, error, or world-size mismatch.
"""
if not dist.is_initialized():
return True # No distributed, nothing to check

try:
tensor = torch.ones(1, device="cuda")
# Use a work handle with timeout
dist.all_reduce(tensor)
work = dist.all_reduce(tensor, async_op=True)
# In current PyTorch Work.wait(timeout=...) raises RuntimeError on
# timeout; older/alternate backends may return False instead. Handle
# both so the timeout is honored regardless of version.
try:
done = work.wait(timeout=timedelta(seconds=timeout_sec)) # type: ignore[reportOptionalMemberAccess]
except RuntimeError as e:
logger.warning(f"NCCL health check timed out after {timeout_sec}s: {e}")
return False
if done is False:
logger.warning(f"NCCL health check timed out after {timeout_sec}s")
return False
torch.cuda.synchronize()
expected = dist.get_world_size()
return abs(tensor.item() - expected) < 1e-5
Expand Down
34 changes: 32 additions & 2 deletions kempnerforge/training/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,20 @@

from __future__ import annotations

import logging
import math
from datetime import timedelta

import torch
import torch.distributed as dist

logger = logging.getLogger(__name__)

# Per-operation timeout for the PP eval loss broadcast. Shorter than the
# 1800s process-group default so a diverged PP stage surfaces fast rather
# than freezing eval for half an hour.
_EVAL_BROADCAST_TIMEOUT_SEC = 300.0


@torch.no_grad()
def run_eval(
Expand Down Expand Up @@ -77,8 +86,29 @@ def run_eval(
avg_loss = 0.0

loss_tensor = torch.tensor([avg_loss], device=device)
dist.broadcast(loss_tensor, group_src=pp_size - 1, group=pp_group) # type: ignore[reportOptionalOperand]
avg_loss = loss_tensor[0].item()
work = dist.broadcast(
loss_tensor,
group_src=pp_size - 1, # type: ignore[reportOptionalOperand]
group=pp_group,
async_op=True,
)
try:
done = work.wait(timeout=timedelta(seconds=_EVAL_BROADCAST_TIMEOUT_SEC)) # type: ignore[reportOptionalMemberAccess]
except RuntimeError as e:
logger.error(
f"Eval loss broadcast timed out after {_EVAL_BROADCAST_TIMEOUT_SEC}s; "
f"a PP stage is likely wedged. Reporting nan loss. Underlying: {e}"
)
avg_loss = float("nan")
else:
if done is False:
logger.error(
f"Eval loss broadcast timed out after {_EVAL_BROADCAST_TIMEOUT_SEC}s; "
"reporting nan loss."
)
avg_loss = float("nan")
else:
avg_loss = loss_tensor[0].item()
else:
# --- Standard eval path ---
total_loss = 0.0
Expand Down
Loading
Loading