|
18 | 18 |
|
19 | 19 | _MASTER_ADDR = "127.0.0.1" |
20 | 20 |
|
| 21 | +# Sentinel exit code used by init_and_run_process to signal a port conflict |
| 22 | +# (DistNetworkError during init_process_group). spawn_multiprocess_job detects |
| 23 | +# this code and retries with a fresh port, recovering from the TOCTOU race |
| 24 | +# between _is_port_available() and dist.init_process_group(). |
| 25 | +_PORT_CONFLICT_EXIT_CODE = 2 |
| 26 | + |
| 27 | + |
| 28 | +class _PortConflictError(RuntimeError): |
| 29 | + """Raised internally when a spawned process exits due to a port conflict.""" |
| 30 | + |
21 | 31 |
|
22 | 32 | class _DistGroup: |
23 | 33 | """Global instance to set/get the default process group for distributed ops.""" |
@@ -251,9 +261,15 @@ def init_and_run_process( |
251 | 261 | job, rank, size, port, port_recv_conn=None, port_send_conns=None, **kwargs |
252 | 262 | ): |
253 | 263 | try: |
254 | | - initialize_or_skip( |
255 | | - rank, size, port, port_recv_conn=port_recv_conn, port_send_conns=port_send_conns |
256 | | - ) |
| 264 | + try: |
| 265 | + initialize_or_skip( |
| 266 | + rank, size, port, port_recv_conn=port_recv_conn, port_send_conns=port_send_conns |
| 267 | + ) |
| 268 | + except dist.DistNetworkError: |
| 269 | + # Port conflict: init_process_group failed to bind (EADDRINUSE). |
| 270 | + # Exit with a sentinel code so spawn_multiprocess_job can retry |
| 271 | + # with a fresh port rather than treating this as a test failure. |
| 272 | + sys.exit(_PORT_CONFLICT_EXIT_CODE) |
257 | 273 | job(rank, size, **kwargs) |
258 | 274 | except Exception as e: |
259 | 275 | # Close the input and output queues to parent process can exit. |
@@ -354,13 +370,39 @@ def _join_multiprocess_job(processes): |
354 | 370 | # Check exitcode via hasattr rather than isinstance(p, mp.Process), because |
355 | 371 | # spawn-context processes (SpawnProcess) don't inherit from mp.Process. |
356 | 372 | if hasattr(p, "exitcode"): |
| 373 | + if p.exitcode == _PORT_CONFLICT_EXIT_CODE: |
| 374 | + raise _PortConflictError( |
| 375 | + f"Process {p.pid} exited with port conflict code {p.exitcode}" |
| 376 | + ) |
357 | 377 | assert p.exitcode == 0, f"Process {p.pid} exited with code {p.exitcode}" |
358 | 378 |
|
359 | 379 |
|
360 | | -def spawn_multiprocess_job(job: Callable[[int, int], None], size: Optional[int] = None): |
361 | | - processes = _start_multiprocess_job(job, size) |
362 | | - if processes: |
363 | | - _join_multiprocess_job(processes) |
| 380 | +def spawn_multiprocess_job( |
| 381 | + job: Callable[[int, int], None], size: Optional[int] = None, max_retries: int = 5 |
| 382 | +): |
| 383 | + for attempt in range(max_retries): |
| 384 | + processes = _start_multiprocess_job(job, size) |
| 385 | + if not processes: |
| 386 | + break |
| 387 | + try: |
| 388 | + _join_multiprocess_job(processes) |
| 389 | + break # success |
| 390 | + except _PortConflictError: |
| 391 | + # Kill any surviving sibling processes and retry with a fresh port. |
| 392 | + # This recovers from the TOCTOU race between _is_port_available() and |
| 393 | + # dist.init_process_group() where an external process grabbed the port. |
| 394 | + for p in processes: |
| 395 | + if p.is_alive(): |
| 396 | + p.terminate() |
| 397 | + p.join(timeout=5) |
| 398 | + if attempt == max_retries - 1: |
| 399 | + raise RuntimeError( |
| 400 | + f"Failed to initialize distributed group after {max_retries} " |
| 401 | + "attempts due to repeated port conflicts" |
| 402 | + ) |
| 403 | + ad_logger.warning( |
| 404 | + f"Port conflict on attempt {attempt + 1}/{max_retries}, retrying with new port..." |
| 405 | + ) |
364 | 406 | cleanup() |
365 | 407 |
|
366 | 408 |
|
|
0 commit comments