@@ -90,6 +90,42 @@ def _set_nccl_env() -> None:
9090 os .environ .setdefault ("NCCL_IB_DISABLE" , "0" )
9191 os .environ .setdefault ("NCCL_NET_GDR_LEVEL" , "2" )
9292
93+ # Ensure NCCL actually enforces the process-group timeout. The default in
94+ # PyTorch 2.2+ is "1", but a user shell/SLURM prolog may override it to
95+ # "0", at which point the PG timeout becomes advisory and stuck collectives
96+ # can hang indefinitely. Set a safe default and warn loudly if the user
97+ # has explicitly disabled it.
98+ existing = os .environ .get ("TORCH_NCCL_ASYNC_ERROR_HANDLING" )
99+ if existing == "0" :
100+ logger .warning (
101+ "TORCH_NCCL_ASYNC_ERROR_HANDLING=0 detected — NCCL timeouts "
102+ "are advisory; stuck collectives can hang indefinitely."
103+ )
104+ else :
105+ os .environ .setdefault ("TORCH_NCCL_ASYNC_ERROR_HANDLING" , "1" )
106+
107+
108+ def _barrier_with_timeout (seconds : float , reason : str ) -> None :
109+ """dist.barrier with an explicit per-op timeout and a diagnostic log.
110+
111+ The process-group default timeout (``config.nccl_timeout_sec``) is sized
112+ for training collectives (minutes of reduce on large tensors). Init-path
113+ barriers should fail fast so mesh or env misconfiguration does not block
114+ a job for 30 minutes before surfacing a useful error.
115+ """
116+ work = dist .barrier (async_op = True )
117+ try :
118+ done = work .wait (timeout = timedelta (seconds = seconds )) # type: ignore[reportOptionalMemberAccess]
119+ except RuntimeError as e :
120+ raise RuntimeError (
121+ f"Barrier timed out after { seconds } s during { reason } . "
122+ f"Common causes: MASTER_ADDR/MASTER_PORT disagreement across ranks, "
123+ f"a rank missing from the job, or the IB interface unreachable. "
124+ f"Underlying: { e } "
125+ ) from e
126+ if done is False :
127+ raise RuntimeError (f"Barrier timed out after { seconds } s during { reason } ." )
128+
93129
94130def _set_seed (seed : int , rank : int , pp_rank : int = 0 ) -> None :
95131 """Set deterministic seeds for reproducibility.
@@ -223,8 +259,10 @@ def init_distributed(config: DistributedConfig, seed: int = 42) -> DeviceMesh |
223259 mesh_dim_names = tuple (mesh_dims ),
224260 )
225261
226- # Ensure all ranks have finished mesh creation before proceeding
227- dist .barrier ()
262+ # Ensure all ranks have finished mesh creation before proceeding.
263+ # A 60s bound fails fast on mesh misconfiguration rather than inheriting
264+ # the 1800s PG timeout.
265+ _barrier_with_timeout (60.0 , reason = "DeviceMesh construction" )
228266
229267 # Set seed (vary by PP rank for different dropout/stochastic depth per stage)
230268 pp_rank = 0
0 commit comments