Enforce per-op timeouts on init and coordination collectives#64
Conversation
check_nccl_health accepted a timeout_sec parameter and then called dist.all_reduce synchronously, inheriting the 1800s PG default instead. Every caller asking for a 10s liveness probe got a 30-minute one. Switch to async_op=True + Work.wait(timeout=timedelta(...)) so the caller's bound is actually enforced, and handle both the RuntimeError and legacy False-return timeout paths. Add _barrier_with_timeout for the DeviceMesh-construction barrier (60s) so a stuck rank during init fails fast instead of after 30min. Wrap the PP eval loss broadcast in a 300s budget that falls back to nan loss, and the train_state broadcast_object_list in a 600s budget that raises a diagnostic RuntimeError — both previously inherited the PG default. Force TORCH_NCCL_ASYNC_ERROR_HANDLING=1 in _set_nccl_env when unset, and warn loudly when a user has explicitly set it to 0 (at which point NCCL timeouts become advisory and none of the other bounds matter).
dist.broadcast_object_list does not accept async_op in PyTorch 2.11, so the sub-fix added in 330e158 raised TypeError on every distributed load. The init-path barrier, NCCL health, and eval-loss paths in that patch all remain fast-fail; only the object broadcast inherits the 1800s PG default now, which matches the pre-patch behavior.
PyTorch's distributed stubs type dist.barrier/all_reduce/broadcast as returning Work | None because the sync path (async_op=False) returns None. Pyright cannot narrow the return when async_op=True is a literal kwarg, so every .wait(...) on the async handle flags reportOptionalMemberAccess. Add the same # type: ignore the rest of the codebase uses at three sites: _barrier_with_timeout, the eval-loss broadcast, and check_nccl_health.
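The complaint can be reproduced torch-free: a collective typed like `dist.all_reduce` returns `Work | None` because the synchronous path returns `None`, and the checker cannot narrow on the `async_op=True` literal. The stub below is illustrative; the targeted ignore mirrors the convention the comment describes.

```python
# Minimal, torch-free reproduction of the narrowing issue. Work and
# all_reduce are stand-ins typed like the torch.distributed stubs.
from typing import Optional

class Work:
    """Stand-in for torch.distributed.Work."""
    def wait(self, timeout=None) -> bool:
        return True

def all_reduce(tensor, async_op: bool = False) -> Optional[Work]:
    # Sync path completes in place and returns None; async returns a handle.
    return Work() if async_op else None

work = all_reduce([1.0], async_op=True)
done = work.wait()  # type: ignore  # async_op=True guarantees a Work handle
```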
Closes #63.
What
Make declared per-op timeouts on init and coordination collectives actually enforced. Today they fall through to the 1800s process-group default, so a wedged peer hangs the job for 30 minutes regardless of the caller's intent.
- `kempnerforge/resilience/health.py::check_nccl_health` switches to `async_op=True` + `Work.wait(timeout=timedelta(seconds=timeout_sec))` and returns `False` on either timeout signaling path (`RuntimeError` on modern PyTorch, `False` return on legacy backends).
- `kempnerforge/distributed/setup.py` adds `_barrier_with_timeout(seconds, reason)` and uses it for the DeviceMesh-construction barrier with a 60s budget. Timeouts raise `RuntimeError` with the reason string attached so the failure points at the right phase of init.
- `kempnerforge/training/eval.py` wraps the PP eval loss broadcast in a 300s budget with a `nan` fallback so a wedged PP stage surfaces as a clear bad-iteration signal instead of a half-hour hang.
- `kempnerforge/distributed/setup.py::_set_nccl_env` setdefaults `TORCH_NCCL_ASYNC_ERROR_HANDLING=1` and logs a warning when a user has explicitly set it to `0`. Without this, every per-op bound in this patch is advisory.

Known limit: train_state broadcast
The original commit also wired a 600s `async_op=True` budget around the train_state broadcast in `CheckpointManager.load`. PyTorch 2.11's `dist.broadcast_object_list` does not accept `async_op` (TypeError at runtime), so the second commit on this branch reverts that one site to the plain synchronous form and leaves an explanatory comment. The other three call sites still get fast-fail; a wedged rank surfaces via those before the broadcast hits the 1800s default.

Tests
`tests/unit/test_collective_timeouts.py` (10 new tests, all passing locally):

- `TestCheckNcclHealthTimeout` covers `Work.wait(timedelta(...))` forwarding, the `RuntimeError` and `False`-return timeout paths, and the not-initialized early return.
- `TestBarrierWithTimeout` covers the reason-string error message, the `False`-return path, and the success path.
- `TestNcclAsyncErrorHandlingEnvGuard` covers the unset/explicit-1/explicit-0 cases for `TORCH_NCCL_ASYNC_ERROR_HANDLING`.

The test file documents the absence of a `TestCheckpointBroadcastTimeout` class with a comment pointing at the PyTorch 2.11 limitation.
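The env guard exercised by `TestNcclAsyncErrorHandlingEnvGuard` can be sketched as follows; operating on a plain dict instead of `os.environ` keeps the sketch side-effect free, and the function and logger names are illustrative.

```python
# Hedged sketch of the _set_nccl_env guard: force
# TORCH_NCCL_ASYNC_ERROR_HANDLING=1 when unset, and warn loudly (without
# overriding) when the user explicitly disabled async error handling.
import logging

logger = logging.getLogger("nccl_env_guard")  # name is illustrative

def guard_async_error_handling(env: dict) -> None:
    """setdefault the flag to "1"; an explicit "0" is respected but warned
    about, since it makes every per-op timeout bound advisory."""
    value = env.setdefault("TORCH_NCCL_ASYNC_ERROR_HANDLING", "1")
    if value == "0":
        logger.warning(
            "TORCH_NCCL_ASYNC_ERROR_HANDLING=0: NCCL timeouts are advisory "
            "and per-op timeout bounds will not be enforced."
        )

env = {}
guard_async_error_handling(env)
print(env["TORCH_NCCL_ASYNC_ERROR_HANDLING"])  # → 1
```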