Skip to content

Commit 6a717bf

Browse files
committed
fix(infra): add two-phase teardown to prevent TCPStore race at shutdown
Problem: During teardown, rank-0 (TCPStore server owner) could exit before peer ranks finished their final NCCL abort, causing noisy "TCPStore.recvValue failed" / "Broken pipe" warnings on stderr.

Root cause: All ranks were killed simultaneously without first coordinating a distributed barrier on the CPU (gloo) group to safely tear down NCCL communicators and the TCPStore.

Solution: Implement a two-phase teardown protocol:

Phase 1 - Engine destroy: call engine.destroy() on every worker concurrently. The engine-side destroy() now executes a CPU barrier (dist.barrier on a gloo process group) followed by dist.destroy_process_group(), ensuring all ranks leave the NCCL collective together.

Phase 2 - Process kill: only after the barrier completes, kill the actual processes (Ray: remove placement groups; Slurm: scancel; Local: process tree cleanup).

Changes:
- engine (fsdp/megatron/archon): add _cpu_group + pre-destroy barrier
- train_controller: two-phase destroy (engines first, then workers)
- scheduler/ray: _cleanup_workers with ray.wait timeout + PG removal
- scheduler/slurm: _destroy_engines_on_workers via HTTP before scancel
- scheduler/local: graceful engine teardown before SIGKILL
- scheduler_api: add reverse_order param to delete_workers interface
- tests: updated test_train_controller, added test_local_scheduler

Tested: DPO 4xH20 Ray scheduler - clean teardown, no TCPStore warnings.
1 parent 70acd22 commit 6a717bf

11 files changed

Lines changed: 438 additions & 33 deletions

File tree

areal/api/scheduler_api.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,13 +106,22 @@ def get_workers(self, role: str, timeout: int | None = None) -> list[Worker]:
106106
raise NotImplementedError()
107107

108108
@abc.abstractmethod
109-
def delete_workers(self, role: str | None = None):
109+
def delete_workers(self, role: str | None = None, reverse_order: bool = False):
110110
"""Stop and clean up worker processes.
111111
112112
Parameters
113113
----------
114114
role : str, optional
115115
Specific role to delete. If None, all workers are deleted
116+
reverse_order : bool, optional
117+
If True, terminate workers in reverse order of their IDs so that
118+
rank-0 (which typically owns the global TCPStore server) is the
119+
last one to be killed. This helps avoid a noisy
120+
``TCPStore.recvValue failed`` warning emitted by NCCL's
121+
HeartbeatMonitor background thread on non-zero ranks during
122+
teardown. Implementations that tear down all workers as a single
123+
atomic operation (e.g. ``scancel`` for Slurm) may safely ignore
124+
this argument. Defaults to False for backward compatibility.
116125
117126
Raises
118127
------

areal/engine/fsdp_engine.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,25 @@ def destroy(self):
419419
# handles still exist and we expect another engine to
420420
# clean up these groups.
421421
if dist.is_initialized() and self.own_global_group:
422+
# Pre-destroy synchronization on a CPU (gloo) group so that all
423+
# ranks leave the NCCL collective phase together. Without this
424+
# barrier, rank-0 (which owns the TCPStore server) may exit
425+
# before peers finish their final NCCL abort, causing
426+
# HeartbeatMonitor background threads on other ranks to observe
427+
# "recvValue failed" on the already-closed store. This is
428+
# harmless but produces a noisy stderr backtrace at teardown.
429+
if getattr(self, "_cpu_group", None) is not None:
430+
try:
431+
dist.barrier(group=self._cpu_group)
432+
except Exception as e: # pragma: no cover - best-effort
433+
self.logger.warning(
434+
f"pre-destroy CPU barrier failed (ignored): {e}"
435+
)
422436
dist.destroy_process_group()
437+
# Make destroy() idempotent: if the controller calls destroy
438+
# more than once (e.g. via cleanup hooks), the second call
439+
# must not try to destroy already-destroyed groups.
440+
self.own_global_group = False
423441

424442
@property
425443
def initialized(self) -> bool:

areal/engine/megatron_engine.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,19 @@ def destroy(self):
510510
# handles still exist and we expect another engine to
511511
# clean up these groups.
512512
if dist.is_initialized() and self.own_global_group:
513+
# Pre-destroy synchronization on a CPU (gloo) group so that all
514+
# ranks leave the NCCL collective phase together. Without this
515+
# barrier, rank-0 (which owns the TCPStore server) may exit
516+
# before peers finish their final NCCL abort, causing
517+
# HeartbeatMonitor background threads on other ranks to observe
518+
# "recvValue failed" on the already-closed store.
519+
if getattr(self, "_cpu_group", None) is not None:
520+
try:
521+
dist.barrier(group=self._cpu_group)
522+
except Exception as e: # pragma: no cover - best-effort
523+
self.logger.warning(
524+
f"pre-destroy CPU barrier failed (ignored): {e}"
525+
)
513526
mpu.destroy_model_parallel()
514527
dist.destroy_process_group()
515528
self.own_global_group = False

areal/experimental/engine/archon_engine.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,25 @@ def destroy(self):
428428
gc.collect()
429429

430430
if dist.is_initialized() and self.own_global_group:
431+
# Pre-destroy synchronization on a CPU (gloo) group so that all
432+
# ranks leave the NCCL collective phase together. Without this
433+
# barrier, rank-0 (which owns the TCPStore server) may exit
434+
# before peers finish their final NCCL abort, causing
435+
# HeartbeatMonitor background threads on other ranks to observe
436+
# "recvValue failed" on the already-closed store. This is
437+
# harmless but produces a noisy stderr backtrace at teardown.
438+
if getattr(self, "_cpu_group", None) is not None:
439+
try:
440+
dist.barrier(group=self._cpu_group)
441+
except Exception as e:
442+
self.logger.warning(
443+
f"pre-destroy CPU barrier failed (ignored): {e}"
444+
)
431445
dist.destroy_process_group()
446+
# Make destroy() idempotent: if the controller calls destroy
447+
# more than once (e.g. via cleanup hooks), the second call
448+
# must not try to destroy already-destroyed groups.
449+
self.own_global_group = False
432450
self._initialized = False
433451

434452
def train(self, mode: bool = True):

areal/infra/controller/train_controller.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,19 @@ def destroy(self):
404404
"""Destroy the controller and release GPU memory of models.
405405
406406
Cleans up all resources including workers, engines, and internal state.
407+
408+
The teardown order is carefully chosen to avoid a noisy
409+
``TCPStore.recvValue failed`` warning from NCCL's HeartbeatMonitor
410+
on non-zero ranks:
411+
412+
1. Remote engines' ``destroy()`` runs first so that every rank calls
413+
``dist.destroy_process_group()`` after a CPU barrier. This
414+
is meant to let all ranks finish NCCL abort together before any store
415+
shuts down.
416+
2. Workers are killed in reverse rank order so that rank-0 (owner
417+
of the global TCPStore server) receives SIGTERM last. This
418+
avoids the short window where non-zero ranks' HeartbeatMonitor
419+
threads poll a store whose TCP listener has already been closed.
407420
"""
408421
logger.info("Destroying TrainController...")
409422

@@ -421,17 +434,28 @@ async def _destroy_all_engines():
421434
)
422435
for rank, worker in enumerate(self.workers)
423436
]
424-
await asyncio.gather(*tasks, return_exceptions=True)
425-
426-
run_async_task(_destroy_all_engines)
437+
return await asyncio.gather(*tasks, return_exceptions=True)
438+
439+
results = run_async_task(_destroy_all_engines)
440+
# Surface per-worker failures instead of silently swallowing them.
441+
for rank, res in enumerate(results or []):
442+
if isinstance(res, BaseException):
443+
logger.warning(
444+
f"Engine destroy on rank {rank} raised "
445+
f"{type(res).__name__}: {res}"
446+
)
427447
logger.info("Engines destroyed")
428448
except Exception as e:
429449
logger.error(f"Error destroying engines: {e}")
430450

431-
# Then delete workers via scheduler
451+
# Then delete workers via scheduler. Pass reverse_order=True so
452+
# that rank-0 (TCPStore owner) is killed last. All in-tree
453+
# Scheduler implementations (Local/Ray/Slurm) accept this kwarg;
454+
# third-party subclasses that override ``delete_workers`` must
455+
# adopt the same signature.
432456
try:
433-
logger.info("Deleting all workers...")
434-
self.scheduler.delete_workers(role=self._worker_role)
457+
logger.info("Deleting all workers (reverse rank order)...")
458+
self.scheduler.delete_workers(role=self._worker_role, reverse_order=True)
435459
logger.info("Workers deleted")
436460
except Exception as e:
437461
logger.error(f"Error deleting workers: {e}")

areal/infra/scheduler/local.py

Lines changed: 91 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,23 +1082,27 @@ def _check_worker_health(self, role: str):
10821082
stderr,
10831083
)
10841084

1085-
def delete_workers(self, role: str | None = None):
1085+
def delete_workers(self, role: str | None = None, reverse_order: bool = False):
10861086
"""Delete workers and clean up resources.
10871087
10881088
Parameters
10891089
----------
10901090
role : str, optional
10911091
Specific worker role to delete, or None to delete all
1092+
reverse_order : bool, optional
1093+
If True, terminate workers in reverse rank order so that rank-0
1094+
(owner of the global TCPStore) is signalled last. See
1095+
``Scheduler.delete_workers`` for background.
10921096
"""
10931097
if role is None:
10941098
# Delete colocated roles first (they don't own processes)
10951099
colocated_roles = list(self._colocated_roles.keys())
10961100
for r in colocated_roles:
1097-
self.delete_workers(r)
1101+
self.delete_workers(r, reverse_order=reverse_order)
10981102
# Then delete actual worker roles
10991103
roles = list(self._workers.keys())
11001104
for r in roles:
1101-
self.delete_workers(r)
1105+
self.delete_workers(r, reverse_order=reverse_order)
11021106
return
11031107

11041108
# Handle colocated/forked role
@@ -1107,6 +1111,8 @@ def delete_workers(self, role: str | None = None):
11071111
if role in self._workers:
11081112
logger.info(f"Removing forked role '{role}' (managed by parent worker)")
11091113
workers = self._workers[role]
1114+
if reverse_order:
1115+
workers = list(reversed(workers))
11101116
self._cleanup_workers(
11111117
workers
11121118
) # Release ports, but process=None skips kill
@@ -1124,29 +1130,108 @@ def delete_workers(self, role: str | None = None):
11241130
workers = self._workers[role]
11251131
logger.info(f"Deleting {len(workers)} workers for role '{role}'")
11261132

1133+
if reverse_order:
1134+
workers = list(reversed(workers))
11271135
self._cleanup_workers(workers)
11281136

11291137
del self._workers[role]
11301138

11311139
logger.info(f"Successfully deleted workers for role '{role}'")
11321140

11331141
def _cleanup_workers(self, workers: list[WorkerInfo]):
1142+
"""Tear down a batch of workers with coordinated teardown semantics.
1143+
1144+
The previous implementation iterated ``workers`` serially and called
1145+
``kill_process_tree(..., timeout=3, graceful=True)`` on each one.
1146+
Because that helper blocks for up to ``timeout`` seconds between
1147+
SIGTERM and the fallback SIGKILL, a 4-rank job could spend ~12 s
1148+
killing workers one-by-one. During that window only a single rank
1149+
was executing its ``engine.destroy()`` path, so the CPU barrier
1150+
added in ``FSDPEngine.destroy()`` could never actually synchronise
1151+
-- every rank timed out on its barrier and the NCCL teardown race
1152+
that produced ``TCPStore.recvValue failed`` / HeartbeatMonitor
1153+
warnings was not fixed.
1154+
1155+
The corrected behaviour is:
1156+
1157+
1. Release port allocations synchronously (cheap, no I/O).
1158+
2. Send SIGTERM to every worker in the order provided by the
1159+
caller, with no blocking waits in between. ``delete_workers``
1160+
passes the list in reverse rank order when
1161+
``reverse_order=True``, which preserves the "rank-0 signalled
1162+
last" guarantee while keeping the dispatch window in the
1163+
millisecond range.
1164+
3. Wait for every worker to exit in parallel using one thread per
1165+
worker. Each thread re-uses ``kill_process_tree`` so the
1166+
existing SIGTERM -> wait -> SIGKILL escalation is preserved
1167+
per-worker; we just no longer serialise the waits.
1168+
1169+
With this change every rank enters ``engine.destroy()`` within the
1170+
same small window, the CPU ``dist.barrier`` inside can actually
1171+
rendezvous, and the NCCL / TCPStore teardown becomes race-free.
1172+
"""
1173+
import threading
1174+
1175+
# Phase 1: always release ports, regardless of whether the worker
1176+
# owns a process (forked workers have ``process is None``).
1177+
live_workers: list[WorkerInfo] = []
11341178
for worker_info in workers:
11351179
try:
11361180
for port_str in worker_info.worker.worker_ports:
11371181
self._allocated_ports.discard(int(port_str))
1182+
except Exception as e:
1183+
logger.error(
1184+
f"Error releasing ports for worker {worker_info.worker.id}: {e}",
1185+
exc_info=True,
1186+
)
1187+
if worker_info.process is not None:
1188+
live_workers.append(worker_info)
1189+
else:
1190+
logger.debug(f"Cleaned up worker {worker_info.worker.id}")
11381191

1139-
# Only kill process if we own it (non-forked workers)
1140-
if worker_info.process is not None:
1141-
kill_process_tree(worker_info.process.pid, timeout=3, graceful=True)
1192+
if not live_workers:
1193+
return
11421194

1195+
# Phase 2: dispatch SIGTERM to every worker concurrently via
1196+
# background threads so that all ranks reach their teardown
1197+
# barrier within the same window. The list order is preserved as
1198+
# thread start order: when the caller requests reverse_order,
1199+
# rank-0 is the last thread to be started, which keeps the
1200+
# "rank-0 dies last" property while staying non-blocking.
1201+
def _finalize(worker_info: WorkerInfo) -> None:
1202+
try:
1203+
kill_process_tree(worker_info.process.pid, timeout=3, graceful=True)
11431204
logger.debug(f"Cleaned up worker {worker_info.worker.id}")
11441205
except Exception as e:
11451206
logger.error(
11461207
f"Error cleaning up worker {worker_info.worker.id}: {e}",
11471208
exc_info=True,
11481209
)
11491210

1211+
threads: list[threading.Thread] = []
1212+
for worker_info in live_workers:
1213+
t = threading.Thread(
1214+
target=_finalize,
1215+
args=(worker_info,),
1216+
name=f"cleanup-{worker_info.worker.id}",
1217+
daemon=True,
1218+
)
1219+
t.start()
1220+
threads.append(t)
1221+
1222+
# Phase 3: wait for every cleanup thread. Each ``kill_process_tree``
1223+
# call internally waits up to ``timeout=3`` seconds for graceful
1224+
# shutdown and then SIGKILLs stragglers, so a small safety margin
1225+
# on ``join`` is sufficient.
1226+
join_timeout = 10.0
1227+
for t in threads:
1228+
t.join(timeout=join_timeout)
1229+
if t.is_alive():
1230+
logger.warning(
1231+
f"Cleanup thread {t.name} did not finish within "
1232+
f"{join_timeout}s; leaving it as daemon."
1233+
)
1234+
11501235
def _read_log_tail(self, log_file: str, lines: int = 50) -> str:
11511236
try:
11521237
with open(log_file) as f:

0 commit comments

Comments
 (0)