From 6a717bfca5a18412cec776ed8bdab984de969510 Mon Sep 17 00:00:00 2001 From: HT-Yuan <570112336@qq.com> Date: Fri, 24 Apr 2026 18:10:08 +0800 Subject: [PATCH] fix(infra): add two-phase teardown to prevent TCPStore race at shutdown Problem: During teardown, rank-0 (TCPStore server owner) could exit before peer ranks finished their final NCCL abort, causing noisy "TCPStore.recvValue failed" / "Broken pipe" warnings on stderr. Root cause: All ranks were killed simultaneously without first coordinating a distributed barrier on the CPU (gloo) group to safely tear down NCCL communicators and the TCPStore. Solution: Implement a two-phase teardown protocol: Phase 1 - Engine destroy: call engine.destroy() on every worker concurrently. The engine-side destroy() now executes a CPU barrier (dist.barrier on a gloo process group) followed by dist.destroy_process_group(), ensuring all ranks leave the NCCL collective together. Phase 2 - Process kill: only after the barrier completes, kill the actual processes (Ray: remove placement groups; Slurm: scancel; Local: process tree cleanup). Changes: - engine (fsdp/megatron/archon): add _cpu_group + pre-destroy barrier - train_controller: two-phase destroy (engines first, then workers) - scheduler/ray: _cleanup_workers with ray.wait timeout + PG removal - scheduler/slurm: _destroy_engines_on_workers via HTTP before scancel - scheduler/local: graceful engine teardown before SIGKILL - scheduler_api: add reverse_order param to delete_workers interface - tests: updated test_train_controller, added test_local_scheduler Tested: DPO 4xH20 Ray scheduler - clean teardown, no TCPStore warnings. --- areal/api/scheduler_api.py | 11 +- areal/engine/fsdp_engine.py | 18 +++ areal/engine/megatron_engine.py | 13 ++ areal/experimental/engine/archon_engine.py | 18 +++ areal/infra/controller/train_controller.py | 36 +++++- areal/infra/scheduler/local.py | 97 ++++++++++++++- areal/infra/scheduler/ray.py | 132 ++++++++++++++++++--- areal/infra/scheduler/slurm.py | 92 +++++++++++++- tests/test_local_scheduler.py | 33 ++++++ tests/test_rollout_controller.py | 2 +- tests/test_train_controller.py | 19 ++- 11 files changed, 438 insertions(+), 33 deletions(-) diff --git a/areal/api/scheduler_api.py b/areal/api/scheduler_api.py index 5ade7e0f08..85ba1cdf27 100644 --- a/areal/api/scheduler_api.py +++ b/areal/api/scheduler_api.py @@ -106,13 +106,22 @@ def get_workers(self, role: str, timeout: int | None = None) -> list[Worker]: raise NotImplementedError() @abc.abstractmethod - def delete_workers(self, role: str | None = None): + def delete_workers(self, role: str | None = None, reverse_order: bool = False): """Stop and clean up worker processes. Parameters ---------- role : str, optional Specific role to delete. If None, all workers are deleted + reverse_order : bool, optional + If True, terminate workers in reverse order of their IDs so that + rank-0 (which typically owns the global TCPStore server) is the + last one to be killed. This helps avoid a noisy + ``TCPStore.recvValue failed`` warning emitted by NCCL's + HeartbeatMonitor background thread on non-zero ranks during + teardown. Implementations that tear down all workers as a single + atomic operation (e.g. ``scancel`` for Slurm) may safely ignore + this argument. Defaults to False for backward compatibility. 
Raises ------ diff --git a/areal/engine/fsdp_engine.py b/areal/engine/fsdp_engine.py index df36055c84..9aba15cbfe 100644 --- a/areal/engine/fsdp_engine.py +++ b/areal/engine/fsdp_engine.py @@ -419,7 +419,25 @@ def destroy(self): # handles still exist and we expect another engine to # clean up these groups. if dist.is_initialized() and self.own_global_group: + # Pre-destroy synchronization on a CPU (gloo) group so that all + # ranks leave the NCCL collective phase together. Without this + # barrier, rank-0 (which owns the TCPStore server) may exit + # before peers finish their final NCCL abort, causing + # HeartbeatMonitor background threads on other ranks to observe + # "recvValue failed" on the already-closed store. This is + # harmless but produces a noisy stderr backtrace at teardown. + if getattr(self, "_cpu_group", None) is not None: + try: + dist.barrier(group=self._cpu_group) + except Exception as e: # pragma: no cover - best-effort + self.logger.warning( + f"pre-destroy CPU barrier failed (ignored): {e}" + ) dist.destroy_process_group() + # Make destroy() idempotent: if the controller calls destroy + # more than once (e.g. via cleanup hooks), the second call + # must not try to destroy already-destroyed groups. + self.own_global_group = False @property def initialized(self) -> bool: diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 49b5f949ac..f0ea378d13 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -510,6 +510,19 @@ def destroy(self): # handles still exist and we expect another engine to # clean up these groups. if dist.is_initialized() and self.own_global_group: + # Pre-destroy synchronization on a CPU (gloo) group so that all + # ranks leave the NCCL collective phase together. Without this + # barrier, rank-0 (which owns the TCPStore server) may exit + # before peers finish their final NCCL abort, causing + # HeartbeatMonitor background threads on other ranks to observe + # "recvValue failed" on the already-closed store. + if getattr(self, "_cpu_group", None) is not None: + try: + dist.barrier(group=self._cpu_group) + except Exception as e: # pragma: no cover - best-effort + self.logger.warning( + f"pre-destroy CPU barrier failed (ignored): {e}" + ) mpu.destroy_model_parallel() dist.destroy_process_group() self.own_global_group = False diff --git a/areal/experimental/engine/archon_engine.py b/areal/experimental/engine/archon_engine.py index 0c2b68a9c0..00e7e3e9ab 100644 --- a/areal/experimental/engine/archon_engine.py +++ b/areal/experimental/engine/archon_engine.py @@ -428,7 +428,25 @@ def destroy(self): gc.collect() if dist.is_initialized() and self.own_global_group: + # Pre-destroy synchronization on a CPU (gloo) group so that all + # ranks leave the NCCL collective phase together. Without this + # barrier, rank-0 (which owns the TCPStore server) may exit + # before peers finish their final NCCL abort, causing + # HeartbeatMonitor background threads on other ranks to observe + # "recvValue failed" on the already-closed store. This is + # harmless but produces a noisy stderr backtrace at teardown. + if getattr(self, "_cpu_group", None) is not None: + try: + dist.barrier(group=self._cpu_group) + except Exception as e: + self.logger.warning( + f"pre-destroy CPU barrier failed (ignored): {e}" + ) dist.destroy_process_group() + # Make destroy() idempotent: if the controller calls destroy + # more than once (e.g. via cleanup hooks), the second call + # must not try to destroy already-destroyed groups. 
+ self.own_global_group = False self._initialized = False def train(self, mode: bool = True): diff --git a/areal/infra/controller/train_controller.py b/areal/infra/controller/train_controller.py index 4ee4e1854e..fe9f6c68b7 100644 --- a/areal/infra/controller/train_controller.py +++ b/areal/infra/controller/train_controller.py @@ -404,6 +404,19 @@ def destroy(self): """Destroy the controller and release GPU memory of models. Cleans up all resources including workers, engines, and internal state. + + The teardown order is carefully chosen to avoid a noisy + ``TCPStore.recvValue failed`` warning from NCCL's HeartbeatMonitor + on non-zero ranks: + + 1. Remote engines' ``destroy()`` runs first so that every rank calls + ``dist.destroy_process_group()`` after a CPU barrier. This + guarantees all ranks finish NCCL abort together before any store + shuts down. + 2. Workers are killed in reverse rank order so that rank-0 (owner + of the global TCPStore server) receives SIGTERM last. This + avoids the short window where non-zero ranks' HeartbeatMonitor + threads poll a store whose TCP listener has already been closed. """ logger.info("Destroying TrainController...") @@ -421,17 +434,28 @@ async def _destroy_all_engines(): ) for rank, worker in enumerate(self.workers) ] - await asyncio.gather(*tasks, return_exceptions=True) - - run_async_task(_destroy_all_engines) + return await asyncio.gather(*tasks, return_exceptions=True) + + results = run_async_task(_destroy_all_engines) + # Surface per-worker failures instead of silently swallowing them. + for rank, res in enumerate(results or []): + if isinstance(res, BaseException): + logger.warning( + f"Engine destroy on rank {rank} raised " + f"{type(res).__name__}: {res}" + ) logger.info("Engines destroyed") except Exception as e: logger.error(f"Error destroying engines: {e}") - # Then delete workers via scheduler + # Then delete workers via scheduler. Pass reverse_order=True so + # that rank-0 (TCPStore owner) is killed last. All in-tree + # Scheduler implementations (Local/Ray/Slurm) accept this kwarg; + # third-party subclasses that override ``delete_workers`` must + # adopt the same signature. try: - logger.info("Deleting all workers...") - self.scheduler.delete_workers(role=self._worker_role) + logger.info("Deleting all workers (reverse rank order)...") + self.scheduler.delete_workers(role=self._worker_role, reverse_order=True) logger.info("Workers deleted") except Exception as e: logger.error(f"Error deleting workers: {e}") diff --git a/areal/infra/scheduler/local.py b/areal/infra/scheduler/local.py index 7b5688a9ab..1de36891bc 100644 --- a/areal/infra/scheduler/local.py +++ b/areal/infra/scheduler/local.py @@ -1082,23 +1082,27 @@ def _check_worker_health(self, role: str): stderr, ) - def delete_workers(self, role: str | None = None): + def delete_workers(self, role: str | None = None, reverse_order: bool = False): """Delete workers and clean up resources. Parameters ---------- role : str, optional Specific worker role to delete, or None to delete all + reverse_order : bool, optional + If True, terminate workers in reverse rank order so that rank-0 + (owner of the global TCPStore) is signalled last. See + ``Scheduler.delete_workers`` for background. 
""" if role is None: # Delete colocated roles first (they don't own processes) colocated_roles = list(self._colocated_roles.keys()) for r in colocated_roles: - self.delete_workers(r) + self.delete_workers(r, reverse_order=reverse_order) # Then delete actual worker roles roles = list(self._workers.keys()) for r in roles: - self.delete_workers(r) + self.delete_workers(r, reverse_order=reverse_order) return # Handle colocated/forked role @@ -1107,6 +1111,8 @@ def delete_workers(self, role: str | None = None): if role in self._workers: logger.info(f"Removing forked role '{role}' (managed by parent worker)") workers = self._workers[role] + if reverse_order: + workers = list(reversed(workers)) self._cleanup_workers( workers ) # Release ports, but process=None skips kill @@ -1124,6 +1130,8 @@ def delete_workers(self, role: str | None = None): workers = self._workers[role] logger.info(f"Deleting {len(workers)} workers for role '{role}'") + if reverse_order: + workers = list(reversed(workers)) self._cleanup_workers(workers) del self._workers[role] @@ -1131,15 +1139,68 @@ def delete_workers(self, role: str | None = None): logger.info(f"Successfully deleted workers for role '{role}'") def _cleanup_workers(self, workers: list[WorkerInfo]): + """Tear down a batch of workers with coordinated teardown semantics. + + The previous implementation iterated ``workers`` serially and called + ``kill_process_tree(..., timeout=3, graceful=True)`` on each one. + Because that helper blocks for up to ``timeout`` seconds between + SIGTERM and the fallback SIGKILL, a 4-rank job could spend ~12 s + killing workers one-by-one. During that window only a single rank + was executing its ``engine.destroy()`` path, so the CPU barrier + added in ``FSDPEngine.destroy()`` could never actually synchronise + -- every rank timed out on its barrier and the NCCL teardown race + that produced ``TCPStore.recvValue failed`` / HeartbeatMonitor + warnings was not fixed. + + The corrected behaviour is: + + 1. Release port allocations synchronously (cheap, no I/O). + 2. Send SIGTERM to every worker in the order provided by the + caller, with no blocking waits in between. ``delete_workers`` + passes the list in reverse rank order when + ``reverse_order=True``, which preserves the "rank-0 signalled + last" guarantee while keeping the dispatch window in the + millisecond range. + 3. Wait for every worker to exit in parallel using one thread per + worker. Each thread re-uses ``kill_process_tree`` so the + existing SIGTERM -> wait -> SIGKILL escalation is preserved + per-worker; we just no longer serialise the waits. + + With this change every rank enters ``engine.destroy()`` within the + same small window, the CPU ``dist.barrier`` inside can actually + rendezvous, and the NCCL / TCPStore teardown becomes race-free. + """ + import threading + + # Phase 1: always release ports, regardless of whether the worker + # owns a process (forked workers have ``process is None``). 
+ live_workers: list[WorkerInfo] = [] for worker_info in workers: try: for port_str in worker_info.worker.worker_ports: self._allocated_ports.discard(int(port_str)) + except Exception as e: + logger.error( + f"Error releasing ports for worker {worker_info.worker.id}: {e}", + exc_info=True, + ) + if worker_info.process is not None: + live_workers.append(worker_info) + else: + logger.debug(f"Cleaned up worker {worker_info.worker.id}") - # Only kill process if we own it (non-forked workers) - if worker_info.process is not None: - kill_process_tree(worker_info.process.pid, timeout=3, graceful=True) + if not live_workers: + return + # Phase 2: dispatch SIGTERM to every worker concurrently via + # background threads so that all ranks reach their teardown + # barrier within the same window. The list order is preserved as + # thread start order: when the caller requests reverse_order, + # rank-0 is the last thread to be started, which keeps the + # "rank-0 dies last" property while staying non-blocking. + def _finalize(worker_info: WorkerInfo) -> None: + try: + kill_process_tree(worker_info.process.pid, timeout=3, graceful=True) logger.debug(f"Cleaned up worker {worker_info.worker.id}") except Exception as e: logger.error( @@ -1147,6 +1208,30 @@ def _cleanup_workers(self, workers: list[WorkerInfo]): exc_info=True, ) + threads: list[threading.Thread] = [] + for worker_info in live_workers: + t = threading.Thread( + target=_finalize, + args=(worker_info,), + name=f"cleanup-{worker_info.worker.id}", + daemon=True, + ) + t.start() + threads.append(t) + + # Phase 3: wait for every cleanup thread. Each ``kill_process_tree`` + # call internally waits up to ``timeout=3`` seconds for graceful + # shutdown and then SIGKILLs stragglers, so a small safety margin + # on ``join`` is sufficient. + join_timeout = 10.0 + for t in threads: + t.join(timeout=join_timeout) + if t.is_alive(): + logger.warning( + f"Cleanup thread {t.name} did not finish within " + f"{join_timeout}s; leaving it as daemon." + ) + def _read_log_tail(self, log_file: str, lines: int = 50) -> str: try: with open(log_file) as f: diff --git a/areal/infra/scheduler/ray.py b/areal/infra/scheduler/ray.py index 448b3440ae..a39f5027d2 100644 --- a/areal/infra/scheduler/ray.py +++ b/areal/infra/scheduler/ray.py @@ -343,16 +343,60 @@ def _cleanup_forked_workers(self, workers: list[RayWorkerInfo]): Unlike _cleanup_workers, this doesn't remove placement groups since forked workers share placement groups with target workers. + + Teardown is done in two phases so that peer ranks can finish their + pre-destroy CPU barrier inside ``engine.destroy()`` before any actor + process is forcibly killed: + + 1. Dispatch ``actor.destroy.remote()`` on every actor concurrently + and collect the ObjectRefs (fire but *don't* forget). + 2. ``ray.wait`` on all of them with a bounded timeout so that all + ranks return together. Only then do we drop references / kill + stragglers. """ + # Phase 1: concurrently dispatch destroy on all actors. + destroy_refs: list[tuple[RayWorkerInfo, Any]] = [] for wi in workers: - actor = wi.actor try: - actor.destroy.remote() + ref = wi.actor.destroy.remote() + destroy_refs.append((wi, ref)) except Exception: logger.warning( - f"Could not destroy forked actor {actor}, force killing actor" + f"Could not dispatch destroy on forked actor {wi.actor}, " + f"force killing actor" ) - ray.kill(actor, no_restart=True) + ray.kill(wi.actor, no_restart=True) + + # Phase 2: wait for all destroys to finish (bounded). 
This lets the + # engine-side pre-destroy CPU barrier complete on every rank before + # we release references. + if destroy_refs: + refs = [r for _, r in destroy_refs] + try: + ray.wait(refs, num_returns=len(refs), timeout=30.0) + except Exception as e: + logger.warning(f"ray.wait on forked destroy refs failed: {e}") + + # Surface per-actor failures; force-kill any that did not finish. + for wi, ref in destroy_refs: + try: + ray.get(ref, timeout=0) + except ray.exceptions.GetTimeoutError: + logger.warning( + f"Forked actor {wi.actor} did not finish destroy in time, " + f"force killing" + ) + try: + ray.kill(wi.actor, no_restart=True) + except Exception: + pass + except Exception as e: + logger.warning( + f"Forked actor {wi.actor} destroy raised " + f"{type(e).__name__}: {e}" + ) + + for wi in workers: # Remove from worker_info_by_id self._worker_info_by_id.pop(wi.worker.id, None) @@ -492,7 +536,7 @@ def get_workers(self, role: str, timeout: float | None = None) -> list[Worker]: return [wi.worker for wi in worker_info_list] - def delete_workers(self, role: str | None = None): + def delete_workers(self, role: str | None = None, reverse_order: bool = False): """ Delete workers and clean up resources @@ -500,16 +544,20 @@ def delete_workers(self, role: str | None = None): -------- role: str, optional Specific worker role to delete, or None to delete all + reverse_order: bool, optional + If True, iterate workers in reverse rank order when issuing + ``actor.destroy.remote()`` so that rank-0 is signalled last. + Note: Ray kills are asynchronous, so ordering here is best-effort. """ if role is None: # Delete colocated roles first (they're just mappings) colocated_roles = list(self._colocated_roles.keys()) for r in colocated_roles: - self.delete_workers(r) + self.delete_workers(r, reverse_order=reverse_order) # Then delete actual worker roles roles = list(self._workers.keys()) for r in roles: - self.delete_workers(r) + self.delete_workers(r, reverse_order=reverse_order) return # Handle colocated role @@ -521,6 +569,8 @@ def delete_workers(self, role: str | None = None): logger.info( f"Cleaning up {len(workers)} forked actors for role '{role}'" ) + if reverse_order: + workers = list(reversed(workers)) self._cleanup_forked_workers(workers) del self._workers[role] else: @@ -536,6 +586,8 @@ def delete_workers(self, role: str | None = None): workers = self._workers[role] logger.info(f"Deleting {len(workers)} workers for role '{role}'") + if reverse_order: + workers = list(reversed(workers)) self._cleanup_workers(workers) del self._workers[role] @@ -575,22 +627,72 @@ def fork_workers( return worker_ids def _cleanup_workers(self, workers: list[RayWorkerInfo]): - # Kill actors first + """Tear down actors and their placement groups in three phases. + + The ordering matters for distributed teardown correctness: + + 1. Dispatch ``actor.destroy.remote()`` on every actor concurrently + and collect the ObjectRefs. ``destroy`` on the worker side runs + the engine's pre-destroy CPU barrier + ``dist.destroy_process_group``, + which requires all peer ranks to still be alive. + 2. ``ray.wait`` on all destroy refs with a bounded timeout so that + every rank finishes the barrier together. Without this, rank-0 + (TCPStore owner) may be torn down first and cause a noisy + ``TCPStore.recvValue failed`` on other ranks. + 3. Only after the barrier phase, remove the placement groups. PG + removal hard-kills any still-alive actor process, so it must + come last. + """ + # Phase 1: concurrently dispatch destroy on all actors. 
+ destroy_refs: list[tuple[RayWorkerInfo, Any]] = [] for wi in workers: - actor = wi.actor try: - # Asynchronously destroy actor - actor.destroy.remote() + ref = wi.actor.destroy.remote() + destroy_refs.append((wi, ref)) except Exception: try: - actor.__ray_terminate__.remote() + wi.actor.__ray_terminate__.remote() except Exception: logger.warning( - f"Could not destroy remote actor {actor}, force killing actor" + f"Could not destroy remote actor {wi.actor}, " + f"force killing actor" + ) + ray.kill(wi.actor, no_restart=True) + + # Phase 2: wait for destroys to finish so the engine-side CPU + # barrier has a chance to complete on every rank. + if destroy_refs: + ref_to_wi = {id(r): wi for wi, r in destroy_refs} + refs = [r for _, r in destroy_refs] + + ready_refs, remaining_refs = ray.wait( + refs, num_returns=len(refs), timeout=30.0 + ) + + # Completed: check whether destroy raised an exception. + for ref in ready_refs: + wi = ref_to_wi[id(ref)] + try: + ray.get(ref) + except Exception as e: + logger.warning( + f"Actor {wi.actor} destroy raised {type(e).__name__}: {e}" ) - ray.kill(actor, no_restart=True) - # Collect unique placement groups and remove them + # Timed-out: force kill actors that did not finish in time. + for ref in remaining_refs: + wi = ref_to_wi[id(ref)] + logger.warning( + f"Actor {wi.actor} did not finish destroy in 30s, force killing" + ) + try: + ray.kill(wi.actor, no_restart=True) + except Exception: + pass + + # Phase 3: collect unique placement groups and remove them. + # This step hard-kills any actor still using the PG, so it MUST + # come after the barrier phase above. unique_pgs = {wi.placement_group for wi in workers} for pg in unique_pgs: try: diff --git a/areal/infra/scheduler/slurm.py b/areal/infra/scheduler/slurm.py index b5987bd055..16be4029c9 100644 --- a/areal/infra/scheduler/slurm.py +++ b/areal/infra/scheduler/slurm.py @@ -1221,14 +1221,92 @@ def get_workers(self, role: str, timeout: float | None = None) -> list[Worker]: raise WorkerTimeoutError(role, timeout) - def delete_workers(self, role: str | None = None): + def _destroy_engines_on_workers( + self, workers: list[SlurmWorkerInfo], timeout: float = 30.0 + ) -> None: + """Call ``engine.destroy()`` on every worker via HTTP before killing jobs. + + All calls are dispatched concurrently so that the engine-side CPU + barrier (``dist.barrier`` + ``dist.destroy_process_group``) can + complete across all ranks. A bounded *timeout* prevents indefinite + blocking when a worker is already unreachable. 
+        """
+        if not workers:
+            return
+
+        async def _destroy_all():
+            destroy_timeout = aiohttp.ClientTimeout(total=timeout)
+            async with aiohttp.ClientSession(
+                timeout=destroy_timeout,
+                connector=get_default_connector(),
+            ) as session:
+                tasks = []
+                for wi in workers:
+                    port = int(wi.worker.worker_ports[0])
+                    url = f"http://{format_hostport(wi.worker.ip, port)}/call"
+                    payload = {
+                        "method": "destroy",
+                        "engine_name": wi.worker.id,
+                        "args": serialize_value([]),
+                        "kwargs": serialize_value({}),
+                        "rpc_meta": None,
+                    }
+                    tasks.append(
+                        session.post(
+                            url,
+                            data=orjson.dumps(payload),
+                            headers={"Content-Type": "application/json"},
+                        )
+                    )
+                results = await asyncio.gather(
+                    *[self._safe_destroy_request(t) for t in tasks],
+                    return_exceptions=True,
+                )
+                for wi, res in zip(workers, results):
+                    if isinstance(res, BaseException):
+                        logger.warning(
+                            f"engine.destroy() on {wi.worker.id} failed: "
+                            f"{type(res).__name__}: {res}"
+                        )
+
+        try:
+            run_async_task(_destroy_all)
+        except Exception as e:
+            logger.warning(f"Failed to destroy engines before cancel: {e}")
+
+    @staticmethod
+    async def _safe_destroy_request(coro):
+        """Await an aiohttp request context, re-raising failures as RuntimeError."""
+        try:
+            async with coro as resp:
+                await resp.read()
+        except Exception as e:
+            raise RuntimeError(str(e)) from e
+
+    def delete_workers(self, role: str | None = None, reverse_order: bool = False):
         """Delete workers and cancel Slurm jobs.
 
+        Teardown follows a two-phase protocol analogous to the Ray and Local
+        schedulers:
+
+        1. **Engine destroy** – call ``engine.destroy()`` on every worker via
+           HTTP concurrently. This runs the engine-side CPU barrier and
+           ``dist.destroy_process_group`` so that NCCL communicators and the
+           TCPStore are shut down cleanly on all ranks.
+        2. **Job cancel** – ``scancel`` the Slurm job. At this point process
+           groups are already torn down, so killing the processes will not
+           produce spurious ``TCPStore.recvValue failed`` warnings.
+
         Parameters
         ----------
         role : str, optional
             Role to delete. If None, deletes all roles.
+        reverse_order : bool, optional
+            Accepted for API compatibility with other schedulers but ignored
+            here: Slurm tears down the entire job step atomically via
+            ``scancel``, so per-rank ordering cannot be enforced.
         """
+        del reverse_order  # unused, see docstring
         if role is None:
             # Delete colocated/forked roles first (they don't own Slurm jobs)
             colocated_roles = list(self._colocated_roles.keys())
@@ -1261,9 +1339,17 @@ def delete_workers(self, role: str | None = None):
             del self._workers[role]
             return
 
-        logger.info(f"Deleting workers for role '{role}' (job ID {job_id})")
+        workers = self._workers[role]
+        logger.info(
+            f"Deleting {len(workers)} workers for role '{role}' (job ID {job_id})"
+        )
+
+        # Phase 1: destroy engines so that the CPU barrier and
+        # dist.destroy_process_group complete on every rank.
+        self._destroy_engines_on_workers(workers)
 
-        # Cancel Slurm job
+        # Phase 2: cancel the Slurm job. Process groups are already torn
+        # down, so scancel will not cause TCPStore race conditions.
try: cancel_jobs(slurm_ids=[job_id], signal="SIGTERM") time.sleep(2) # Give time for graceful shutdown diff --git a/tests/test_local_scheduler.py b/tests/test_local_scheduler.py index c910750940..5af41cd2fc 100644 --- a/tests/test_local_scheduler.py +++ b/tests/test_local_scheduler.py @@ -1100,6 +1100,39 @@ def test_delete_workers_nonexistent_role(self, scheduler): # Should not raise scheduler.delete_workers("nonexistent") + def test_delete_workers_reverse_order(self, scheduler, tmp_path, monkeypatch): + """With reverse_order=True, workers are cleaned up in reverse rank order. + + This protects rank-0 (owner of the global TCPStore server) from being + torn down before non-zero ranks finish their final NCCL abort. + """ + workers = [ + create_worker_info( + worker_id=f"role1/{i}", + role="role1", + ports=[str(8000 + i)], + log_file=str(tmp_path / f"role1-{i}.log"), + ) + for i in range(4) + ] + scheduler._workers["role1"] = workers + scheduler._allocated_ports = {8000, 8001, 8002, 8003} + + observed_order: list[str] = [] + + original_cleanup = scheduler._cleanup_workers + + def spy(workers_arg): + observed_order.extend(w.worker.id for w in workers_arg) + original_cleanup(workers_arg) + + monkeypatch.setattr(scheduler, "_cleanup_workers", spy) + + scheduler.delete_workers("role1", reverse_order=True) + + assert observed_order == ["role1/3", "role1/2", "role1/1", "role1/0"] + assert "role1" not in scheduler._workers + def test_cleanup_workers_releases_ports(self, scheduler, tmp_path): """Should release allocated ports when cleaning up workers.""" worker = create_worker_info( diff --git a/tests/test_rollout_controller.py b/tests/test_rollout_controller.py index 09bb817f58..749a4d193d 100644 --- a/tests/test_rollout_controller.py +++ b/tests/test_rollout_controller.py @@ -161,7 +161,7 @@ async def _async_call_engine_internal(self, worker_id, method, *args, **kwargs): await asyncio.sleep(0.001) return None - def delete_workers(self, role): + def delete_workers(self, role, reverse_order: bool = False): self.workers.clear() self._pending_results.clear() self._task_counter = 0 diff --git a/tests/test_train_controller.py b/tests/test_train_controller.py index 6ad6c64f55..661ebedc6f 100644 --- a/tests/test_train_controller.py +++ b/tests/test_train_controller.py @@ -97,9 +97,10 @@ async def async_call_engine(self, worker_id, method, *args, **kwargs): await asyncio.sleep(0.001) return None - def delete_workers(self, role): + def delete_workers(self, role, reverse_order: bool = False): """Mock worker deletion.""" self.deleted_roles.append(role) + self.delete_reverse_order = reverse_order self.workers.clear() @@ -260,6 +261,22 @@ def raise_error(role): # Workers should still be cleared assert len(train_controller.workers) == 0 + def test_destroy_requests_reverse_order(self, train_controller, ft_spec): + """Workers must be torn down in reverse rank order. + + This protects rank-0 (which owns the global TCPStore server) from + being killed before non-zero ranks finish NCCL abort, avoiding a + noisy ``TCPStore.recvValue failed`` warning from + HeartbeatMonitor. + """ + train_controller.initialize(role="train_worker", ft_spec=ft_spec) + + train_controller.destroy() + + assert ( + getattr(train_controller.scheduler, "delete_reverse_order", False) is True + ) + class TestTrainControllerMergeResults: """Tests for result merging via _merge_tensors."""
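--
Illustration (not part of the diff): the two-phase protocol above can be
reproduced in miniature with only public torch.distributed APIs. The sketch
below is a hedged toy, not code from this patch; the file name sketch.py and
its structure are hypothetical. It assumes a torchrun-style environment, e.g.
torchrun --nproc_per_node=4 sketch.py; with the barrier in place, no rank
should print a TCPStore warning at exit.

    # sketch.py - two-phase teardown in miniature (hypothetical example,
    # not part of this patch). Phase 1 is the engine-side barrier+destroy;
    # Phase 2 (killing processes in reverse rank order) belongs to the
    # scheduler and is only described in comments here.
    import torch.distributed as dist

    def main():
        # Stand-in for the training process group. Real engines init NCCL
        # for GPU collectives and keep a dedicated gloo group for the
        # pre-destroy barrier; gloo alone is enough for this demo.
        dist.init_process_group(backend="gloo")
        cpu_group = dist.new_group(backend="gloo")

        # ... training collectives would run here ...

        # Phase 1 (engine.destroy): rendezvous on the CPU group, then tear
        # down the process groups. No rank can race ahead and close the
        # TCPStore while a peer is still finishing collective work.
        try:
            dist.barrier(group=cpu_group)
        except Exception:
            pass  # best-effort: never let a dead peer block teardown
        dist.destroy_process_group()

        # Phase 2 (scheduler side, not shown): only after destroy() has
        # returned on every rank are worker processes killed, in reverse
        # rank order so rank-0 (the TCPStore server owner) exits last.

    if __name__ == "__main__":
        main()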