
Commit 29b8495

fix(controller): tear down workers in reverse rank order with concurrent dispatch
When ``TrainController.destroy()`` asks the scheduler to kill its workers, rank-0 is now signalled last instead of first. Rank-0 hosts the global TCPStore server that all other ranks' ``ProcessGroupNCCL::HeartbeatMonitor`` threads still poll during their final cleanup; killing it first leaves peers observing a closed socket, which surfaces as [W TCPStore.cpp] recvValue failed ... no error [W ProcessGroupNCCL.cpp] ... HeartbeatMonitor::runLoop() in stderr at the very end of a successful run. Reverse rank order alone is necessary but not sufficient: the original ``LocalScheduler._cleanup_workers`` iterated workers serially and blocked on ``kill_process_tree(..., timeout=3, graceful=True)`` for each one. A 4-rank job therefore spent ~12s in cleanup, with only one rank inside its ``engine.destroy()`` path at a time. The CPU ``dist.barrier`` added in the companion FSDP commit could never rendezvous -- every rank timed out and the NCCL/TCPStore teardown race still fired. The local scheduler now dispatches SIGTERM to all workers concurrently via daemon threads and joins them in parallel, so every rank enters ``engine.destroy()`` within the same small window while rank-0 still receives its signal last (by a few milliseconds). Changes ------- * ``Scheduler.delete_workers`` grows an optional ``reverse_order: bool`` keyword, documented in the abstract API. Existing callers stay source-compatible. * ``LocalScheduler._cleanup_workers`` is restructured into three phases: synchronous port release, concurrent SIGTERM dispatch (one daemon thread per worker, started in caller-provided order), and parallel join with a bounded timeout. Per-worker SIGTERM -> wait -> SIGKILL escalation is preserved via the existing ``kill_process_tree`` helper. * ``RayScheduler`` honours the flag by iterating workers in reverse before invoking ``actor.destroy.remote()``. No concurrency change is needed because Ray's ``.remote()`` is already async-dispatched. 
* ``SlurmScheduler`` accepts the keyword for API parity but ignores it, since ``scancel`` tears down the whole job step atomically. * ``TrainController.destroy()`` now: - passes ``reverse_order=True`` with a ``TypeError`` fallback so third-party schedulers keep working; - inspects the ``asyncio.gather(..., return_exceptions=True)`` result and logs per-rank engine-destroy failures as warnings instead of silently discarding them; - documents the new two-phase teardown invariant in its docstring. * Mock schedulers in ``tests/test_train_controller.py`` and ``tests/test_rollout_controller.py`` accept the new kwarg. Tests ----- * ``tests/test_train_controller.py`` asserts ``delete_workers`` is called with ``reverse_order=True`` and verifies the ``TypeError`` fallback path for legacy schedulers. * ``tests/test_local_scheduler.py`` verifies that ``reverse_order=True`` produces the expected reverse iteration over workers. * Verified end-to-end with the HH-RLHF DPO example under ``scheduler.type=local``: the previously reproducible ``TCPStore.recvValue failed`` / HeartbeatMonitor warnings no longer appear on clean shutdown.
1 parent e7230e2 commit 29b8495

8 files changed

Lines changed: 230 additions & 19 deletions


areal/api/scheduler_api.py

Lines changed: 10 additions & 1 deletion
@@ -106,13 +106,22 @@ def get_workers(self, role: str, timeout: int | None = None) -> list[Worker]:
         raise NotImplementedError()
 
     @abc.abstractmethod
-    def delete_workers(self, role: str | None = None):
+    def delete_workers(self, role: str | None = None, reverse_order: bool = False):
         """Stop and clean up worker processes.
 
         Parameters
         ----------
         role : str, optional
             Specific role to delete. If None, all workers are deleted
+        reverse_order : bool, optional
+            If True, terminate workers in reverse order of their IDs so that
+            rank-0 (which typically owns the global TCPStore server) is the
+            last one to be killed. This helps avoid a noisy
+            ``TCPStore.recvValue failed`` warning emitted by NCCL's
+            HeartbeatMonitor background thread on non-zero ranks during
+            teardown. Implementations that tear down all workers as a single
+            atomic operation (e.g. ``scancel`` for Slurm) may safely ignore
+            this argument. Defaults to False for backward compatibility.
 
         Raises
         ------

areal/infra/controller/train_controller.py

Lines changed: 40 additions & 6 deletions
@@ -404,6 +404,19 @@ def destroy(self):
         """Destroy the controller and release GPU memory of models.
 
         Cleans up all resources including workers, engines, and internal state.
+
+        The teardown order is carefully chosen to avoid a noisy
+        ``TCPStore.recvValue failed`` warning from NCCL's HeartbeatMonitor
+        on non-zero ranks:
+
+        1. Remote engines' ``destroy()`` runs first so that every rank calls
+           ``dist.destroy_process_group()`` after a CPU barrier. This
+           guarantees all ranks finish NCCL abort together before any store
+           shuts down.
+        2. Workers are killed in reverse rank order so that rank-0 (owner
+           of the global TCPStore server) receives SIGTERM last. This
+           avoids the short window where non-zero ranks' HeartbeatMonitor
+           threads poll a store whose TCP listener has already been closed.
         """
         logger.info("Destroying TrainController...")
 
@@ -421,17 +434,38 @@ async def _destroy_all_engines():
                     )
                     for rank, worker in enumerate(self.workers)
                 ]
-                await asyncio.gather(*tasks, return_exceptions=True)
-
-            run_async_task(_destroy_all_engines)
+                return await asyncio.gather(*tasks, return_exceptions=True)
+
+            results = run_async_task(_destroy_all_engines)
+            # Surface per-worker failures instead of silently swallowing them.
+            for rank, res in enumerate(results or []):
+                if isinstance(res, BaseException):
+                    logger.warning(
+                        f"Engine destroy on rank {rank} raised "
+                        f"{type(res).__name__}: {res}"
+                    )
             logger.info("Engines destroyed")
         except Exception as e:
             logger.error(f"Error destroying engines: {e}")
 
-        # Then delete workers via scheduler
+        # Then delete workers via scheduler. Pass reverse_order=True so
+        # that rank-0 (TCPStore owner) is killed last; keep a TypeError
+        # fallback for third-party Scheduler implementations that do not
+        # yet support the new keyword.
         try:
-            logger.info("Deleting all workers...")
-            self.scheduler.delete_workers(role=self._worker_role)
+            logger.info("Deleting all workers (reverse rank order)...")
+            try:
+                self.scheduler.delete_workers(
+                    role=self._worker_role, reverse_order=True
+                )
+            except TypeError:
+                # Backward-compat path for custom schedulers that have not
+                # been updated to accept `reverse_order`.
+                logger.warning(
+                    "Scheduler.delete_workers does not accept reverse_order; "
+                    "falling back to legacy behaviour."
+                )
+                self.scheduler.delete_workers(role=self._worker_role)
             logger.info("Workers deleted")
         except Exception as e:
             logger.error(f"Error deleting workers: {e}")
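The ``try/except TypeError`` probe used in ``destroy()`` generalises to any optional-keyword rollout across independently maintained implementations. A minimal sketch -- the two scheduler classes here are hypothetical stand-ins, not the repository's classes:

```python
class NewScheduler:
    """Hypothetical scheduler that already understands reverse_order."""

    def delete_workers(self, role=None, reverse_order=False):
        return ("new", role, reverse_order)

class LegacyScheduler:
    """Hypothetical scheduler predating the keyword."""

    def delete_workers(self, role=None):
        return ("legacy", role)

def delete_with_fallback(scheduler, role):
    try:
        return scheduler.delete_workers(role=role, reverse_order=True)
    except TypeError:
        # Legacy signature: retry without the new keyword.
        return scheduler.delete_workers(role=role)

print(delete_with_fallback(NewScheduler(), "train"))     # ('new', 'train', True)
print(delete_with_fallback(LegacyScheduler(), "train"))  # ('legacy', 'train')
```

One caveat with this probe: a ``TypeError`` raised *inside* a new-style scheduler's body would also trigger the fallback and invoke the method twice. Checking ``inspect.signature(scheduler.delete_workers).parameters`` up front is a stricter alternative when that risk matters.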

areal/infra/scheduler/local.py

Lines changed: 91 additions & 6 deletions
@@ -1082,23 +1082,27 @@ def _check_worker_health(self, role: str):
                 stderr,
             )
 
-    def delete_workers(self, role: str | None = None):
+    def delete_workers(self, role: str | None = None, reverse_order: bool = False):
         """Delete workers and clean up resources.
 
         Parameters
         ----------
         role : str, optional
             Specific worker role to delete, or None to delete all
+        reverse_order : bool, optional
+            If True, terminate workers in reverse rank order so that rank-0
+            (owner of the global TCPStore) is signalled last. See
+            ``Scheduler.delete_workers`` for background.
         """
         if role is None:
             # Delete colocated roles first (they don't own processes)
             colocated_roles = list(self._colocated_roles.keys())
             for r in colocated_roles:
-                self.delete_workers(r)
+                self.delete_workers(r, reverse_order=reverse_order)
             # Then delete actual worker roles
             roles = list(self._workers.keys())
             for r in roles:
-                self.delete_workers(r)
+                self.delete_workers(r, reverse_order=reverse_order)
             return
 
@@ -1107,6 +1111,8 @@ def delete_workers(self, role: str | None = None):
         if role in self._workers:
             logger.info(f"Removing forked role '{role}' (managed by parent worker)")
             workers = self._workers[role]
+            if reverse_order:
+                workers = list(reversed(workers))
             self._cleanup_workers(
                 workers
             )  # Release ports, but process=None skips kill
@@ -1124,29 +1130,108 @@ def delete_workers(self, role: str | None = None):
         workers = self._workers[role]
         logger.info(f"Deleting {len(workers)} workers for role '{role}'")
 
+        if reverse_order:
+            workers = list(reversed(workers))
         self._cleanup_workers(workers)
 
         del self._workers[role]
 
         logger.info(f"Successfully deleted workers for role '{role}'")
 
     def _cleanup_workers(self, workers: list[WorkerInfo]):
+        """Tear down a batch of workers with coordinated teardown semantics.
+
+        The previous implementation iterated ``workers`` serially and called
+        ``kill_process_tree(..., timeout=3, graceful=True)`` on each one.
+        Because that helper blocks for up to ``timeout`` seconds between
+        SIGTERM and the fallback SIGKILL, a 4-rank job could spend ~12 s
+        killing workers one-by-one. During that window only a single rank
+        was executing its ``engine.destroy()`` path, so the CPU barrier
+        added in ``FSDPEngine.destroy()`` could never actually synchronise
+        -- every rank timed out on its barrier and the NCCL teardown race
+        that produced ``TCPStore.recvValue failed`` / HeartbeatMonitor
+        warnings was not fixed.
+
+        The corrected behaviour is:
+
+        1. Release port allocations synchronously (cheap, no I/O).
+        2. Send SIGTERM to every worker in the order provided by the
+           caller, with no blocking waits in between. ``delete_workers``
+           passes the list in reverse rank order when
+           ``reverse_order=True``, which preserves the "rank-0 signalled
+           last" guarantee while keeping the dispatch window in the
+           millisecond range.
+        3. Wait for every worker to exit in parallel using one thread per
+           worker. Each thread re-uses ``kill_process_tree`` so the
+           existing SIGTERM -> wait -> SIGKILL escalation is preserved
+           per-worker; we just no longer serialise the waits.
+
+        With this change every rank enters ``engine.destroy()`` within the
+        same small window, the CPU ``dist.barrier`` inside can actually
+        rendezvous, and the NCCL / TCPStore teardown becomes race-free.
+        """
+        import threading
+
+        # Phase 1: always release ports, regardless of whether the worker
+        # owns a process (forked workers have ``process is None``).
+        live_workers: list[WorkerInfo] = []
         for worker_info in workers:
             try:
                 for port_str in worker_info.worker.worker_ports:
                     self._allocated_ports.discard(int(port_str))
+            except Exception as e:
+                logger.error(
+                    f"Error releasing ports for worker {worker_info.worker.id}: {e}",
+                    exc_info=True,
+                )
+            if worker_info.process is not None:
+                live_workers.append(worker_info)
+            else:
+                logger.debug(f"Cleaned up worker {worker_info.worker.id}")
 
-                # Only kill process if we own it (non-forked workers)
-                if worker_info.process is not None:
-                    kill_process_tree(worker_info.process.pid, timeout=3, graceful=True)
+        if not live_workers:
+            return
 
+        # Phase 2: dispatch SIGTERM to every worker concurrently via
+        # background threads so that all ranks reach their teardown
+        # barrier within the same window. The list order is preserved as
+        # thread start order: when the caller requests reverse_order,
+        # rank-0 is the last thread to be started, which keeps the
+        # "rank-0 dies last" property while staying non-blocking.
+        def _finalize(worker_info: WorkerInfo) -> None:
+            try:
+                kill_process_tree(worker_info.process.pid, timeout=3, graceful=True)
                 logger.debug(f"Cleaned up worker {worker_info.worker.id}")
             except Exception as e:
                 logger.error(
                     f"Error cleaning up worker {worker_info.worker.id}: {e}",
                     exc_info=True,
                 )
 
+        threads: list[threading.Thread] = []
+        for worker_info in live_workers:
+            t = threading.Thread(
+                target=_finalize,
+                args=(worker_info,),
+                name=f"cleanup-{worker_info.worker.id}",
+                daemon=True,
+            )
+            t.start()
+            threads.append(t)
+
+        # Phase 3: wait for every cleanup thread. Each ``kill_process_tree``
+        # call internally waits up to ``timeout=3`` seconds for graceful
+        # shutdown and then SIGKILLs stragglers, so a small safety margin
+        # on ``join`` is sufficient.
+        join_timeout = 10.0
+        for t in threads:
+            t.join(timeout=join_timeout)
+            if t.is_alive():
+                logger.warning(
+                    f"Cleanup thread {t.name} did not finish within "
+                    f"{join_timeout}s; leaving it as daemon."
+                )
+
     def _read_log_tail(self, log_file: str, lines: int = 50) -> str:
         try:
             with open(log_file) as f:
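Each cleanup thread above relies on the repository's ``kill_process_tree`` helper for the SIGTERM -> wait -> SIGKILL escalation. The per-process core of that escalation can be sketched with the standard library alone; this hypothetical ``kill_child_sketch`` handles a single ``Popen`` child and ignores the target's own children, which the real helper walks:

```python
import signal
import subprocess

def kill_child_sketch(proc: subprocess.Popen, timeout: float = 3.0) -> None:
    """Graceful-then-forceful shutdown of one child process (POSIX)."""
    proc.terminate()                # SIGTERM: ask nicely first
    try:
        proc.wait(timeout=timeout)  # bounded graceful-shutdown window
    except subprocess.TimeoutExpired:
        proc.kill()                 # SIGKILL any straggler
        proc.wait()                 # reap so no zombie is left behind

# Demo: a sleeping child exits on SIGTERM well within the window.
child = subprocess.Popen(["sleep", "30"])
kill_child_sketch(child)
print(child.returncode == -signal.SIGTERM)  # True on POSIX
```

Because ``wait(timeout=...)`` bounds only the graceful phase, running one such call per worker in its own thread -- as the diff above does -- caps total cleanup time near the single-worker timeout instead of multiplying it by the rank count.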

areal/infra/scheduler/ray.py

Lines changed: 11 additions & 3 deletions
@@ -492,24 +492,28 @@ def get_workers(self, role: str, timeout: float | None = None) -> list[Worker]:
 
         return [wi.worker for wi in worker_info_list]
 
-    def delete_workers(self, role: str | None = None):
+    def delete_workers(self, role: str | None = None, reverse_order: bool = False):
         """
         Delete workers and clean up resources
 
         Parameters
         --------
         role: str, optional
             Specific worker role to delete, or None to delete all
+        reverse_order: bool, optional
+            If True, iterate workers in reverse rank order when issuing
+            ``actor.destroy.remote()`` so that rank-0 is signalled last.
+            Note: Ray kills are asynchronous, so ordering here is best-effort.
         """
         if role is None:
             # Delete colocated roles first (they're just mappings)
             colocated_roles = list(self._colocated_roles.keys())
             for r in colocated_roles:
-                self.delete_workers(r)
+                self.delete_workers(r, reverse_order=reverse_order)
             # Then delete actual worker roles
             roles = list(self._workers.keys())
             for r in roles:
-                self.delete_workers(r)
+                self.delete_workers(r, reverse_order=reverse_order)
             return
 
         # Handle colocated role
@@ -521,6 +525,8 @@ def delete_workers(self, role: str | None = None):
             logger.info(
                 f"Cleaning up {len(workers)} forked actors for role '{role}'"
             )
+            if reverse_order:
+                workers = list(reversed(workers))
             self._cleanup_forked_workers(workers)
             del self._workers[role]
         else:
@@ -536,6 +542,8 @@ def delete_workers(self, role: str | None = None):
         workers = self._workers[role]
         logger.info(f"Deleting {len(workers)} workers for role '{role}'")
 
+        if reverse_order:
+            workers = list(reversed(workers))
         self._cleanup_workers(workers)
 
         del self._workers[role]

areal/infra/scheduler/slurm.py

Lines changed: 6 additions & 1 deletion
@@ -1221,14 +1221,19 @@ def get_workers(self, role: str, timeout: float | None = None) -> list[Worker]:
 
         raise WorkerTimeoutError(role, timeout)
 
-    def delete_workers(self, role: str | None = None):
+    def delete_workers(self, role: str | None = None, reverse_order: bool = False):
         """Delete workers and cancel Slurm jobs.
 
         Parameters
         ----------
         role : str, optional
             Role to delete. If None, deletes all roles.
+        reverse_order : bool, optional
+            Accepted for API compatibility with other schedulers but ignored
+            here: Slurm tears down the entire job step atomically via
+            ``scancel``, so per-rank ordering cannot be enforced.
         """
+        del reverse_order  # unused, see docstring
         if role is None:
             # Delete colocated/forked roles first (they don't own Slurm jobs)
             colocated_roles = list(self._colocated_roles.keys())

tests/test_local_scheduler.py

Lines changed: 33 additions & 0 deletions
@@ -1100,6 +1100,39 @@ def test_delete_workers_nonexistent_role(self, scheduler):
         # Should not raise
         scheduler.delete_workers("nonexistent")
 
+    def test_delete_workers_reverse_order(self, scheduler, tmp_path, monkeypatch):
+        """With reverse_order=True, workers are cleaned up in reverse rank order.
+
+        This protects rank-0 (owner of the global TCPStore server) from being
+        torn down before non-zero ranks finish their final NCCL abort.
+        """
+        workers = [
+            create_worker_info(
+                worker_id=f"role1/{i}",
+                role="role1",
+                ports=[str(8000 + i)],
+                log_file=str(tmp_path / f"role1-{i}.log"),
+            )
+            for i in range(4)
+        ]
+        scheduler._workers["role1"] = workers
+        scheduler._allocated_ports = {8000, 8001, 8002, 8003}
+
+        observed_order: list[str] = []
+
+        original_cleanup = scheduler._cleanup_workers
+
+        def spy(workers_arg):
+            observed_order.extend(w.worker.id for w in workers_arg)
+            original_cleanup(workers_arg)
+
+        monkeypatch.setattr(scheduler, "_cleanup_workers", spy)
+
+        scheduler.delete_workers("role1", reverse_order=True)
+
+        assert observed_order == ["role1/3", "role1/2", "role1/1", "role1/0"]
+        assert "role1" not in scheduler._workers
+
     def test_cleanup_workers_releases_ports(self, scheduler, tmp_path):
         """Should release allocated ports when cleaning up workers."""
         worker = create_worker_info(
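The spy technique in the test above -- replace a bound method with a wrapper that records its argument before delegating, exactly what ``monkeypatch.setattr`` does for the pytest fixture -- can be reproduced without pytest. ``TinyScheduler`` here is a hypothetical stand-in for ``LocalScheduler`` with strings in place of ``WorkerInfo`` objects:

```python
class TinyScheduler:
    """Hypothetical stand-in for the scheduler under test."""

    def __init__(self, workers):
        self._workers = {"role1": list(workers)}

    def _cleanup_workers(self, workers):
        pass  # real port release / process kill elided

    def delete_workers(self, role, reverse_order=False):
        workers = self._workers[role]
        if reverse_order:
            workers = list(reversed(workers))
        self._cleanup_workers(workers)
        del self._workers[role]

sched = TinyScheduler([f"role1/{i}" for i in range(4)])
observed = []
original = sched._cleanup_workers  # keep the bound method to delegate to

def spy(workers):
    observed.extend(workers)  # record cleanup order
    original(workers)         # then run the real cleanup

sched._cleanup_workers = spy  # instance attribute shadows the method
sched.delete_workers("role1", reverse_order=True)
print(observed)  # ['role1/3', 'role1/2', 'role1/1', 'role1/0']
```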

tests/test_rollout_controller.py

Lines changed: 1 addition & 1 deletion
@@ -161,7 +161,7 @@ async def _async_call_engine_internal(self, worker_id, method, *args, **kwargs):
         await asyncio.sleep(0.001)
         return None
 
-    def delete_workers(self, role):
+    def delete_workers(self, role, reverse_order: bool = False):
         self.workers.clear()
         self._pending_results.clear()
         self._task_counter = 0
