-
Notifications
You must be signed in to change notification settings - Fork 493
fix: teardown tcpstore race #1244
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We may also want to fix @areal/experimental/engine/archon_engine.py except for fsdp and megatron
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -510,6 +510,19 @@ def destroy(self): | |
| # handles still exist and we expect another engine to | ||
| # clean up these groups. | ||
| if dist.is_initialized() and self.own_global_group: | ||
| # Pre-destroy synchronization on a CPU (gloo) group so that all | ||
| # ranks leave the NCCL collective phase together. Without this | ||
| # barrier, rank-0 (which owns the TCPStore server) may exit | ||
| # before peers finish their final NCCL abort, causing | ||
| # HeartbeatMonitor background threads on other ranks to observe | ||
| # "recvValue failed" on the already-closed store. | ||
| if getattr(self, "_cpu_group", None) is not None: | ||
| try: | ||
| dist.barrier(group=self._cpu_group) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| except Exception as e: # pragma: no cover - best-effort | ||
| self.logger.warning( | ||
| f"pre-destroy CPU barrier failed (ignored): {e}" | ||
| ) | ||
| mpu.destroy_model_parallel() | ||
| dist.destroy_process_group() | ||
| self.own_global_group = False | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1082,23 +1082,27 @@ def _check_worker_health(self, role: str): | |||||||||||||||||||||||||||||||||
| stderr, | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def delete_workers(self, role: str | None = None): | ||||||||||||||||||||||||||||||||||
| def delete_workers(self, role: str | None = None, reverse_order: bool = False): | ||||||||||||||||||||||||||||||||||
| """Delete workers and clean up resources. | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| Parameters | ||||||||||||||||||||||||||||||||||
| ---------- | ||||||||||||||||||||||||||||||||||
| role : str, optional | ||||||||||||||||||||||||||||||||||
| Specific worker role to delete, or None to delete all | ||||||||||||||||||||||||||||||||||
| reverse_order : bool, optional | ||||||||||||||||||||||||||||||||||
| If True, terminate workers in reverse rank order so that rank-0 | ||||||||||||||||||||||||||||||||||
| (owner of the global TCPStore) is signalled last. See | ||||||||||||||||||||||||||||||||||
| ``Scheduler.delete_workers`` for background. | ||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||
| if role is None: | ||||||||||||||||||||||||||||||||||
| # Delete colocated roles first (they don't own processes) | ||||||||||||||||||||||||||||||||||
| colocated_roles = list(self._colocated_roles.keys()) | ||||||||||||||||||||||||||||||||||
| for r in colocated_roles: | ||||||||||||||||||||||||||||||||||
| self.delete_workers(r) | ||||||||||||||||||||||||||||||||||
| self.delete_workers(r, reverse_order=reverse_order) | ||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When |
||||||||||||||||||||||||||||||||||
| # Then delete actual worker roles | ||||||||||||||||||||||||||||||||||
| roles = list(self._workers.keys()) | ||||||||||||||||||||||||||||||||||
| for r in roles: | ||||||||||||||||||||||||||||||||||
| self.delete_workers(r) | ||||||||||||||||||||||||||||||||||
| self.delete_workers(r, reverse_order=reverse_order) | ||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # Handle colocated/forked role | ||||||||||||||||||||||||||||||||||
|
|
@@ -1107,6 +1111,8 @@ def delete_workers(self, role: str | None = None): | |||||||||||||||||||||||||||||||||
| if role in self._workers: | ||||||||||||||||||||||||||||||||||
| logger.info(f"Removing forked role '{role}' (managed by parent worker)") | ||||||||||||||||||||||||||||||||||
| workers = self._workers[role] | ||||||||||||||||||||||||||||||||||
| if reverse_order: | ||||||||||||||||||||||||||||||||||
| workers = list(reversed(workers)) | ||||||||||||||||||||||||||||||||||
| self._cleanup_workers( | ||||||||||||||||||||||||||||||||||
| workers | ||||||||||||||||||||||||||||||||||
| ) # Release ports, but process=None skips kill | ||||||||||||||||||||||||||||||||||
|
|
@@ -1124,29 +1130,108 @@ def delete_workers(self, role: str | None = None): | |||||||||||||||||||||||||||||||||
| workers = self._workers[role] | ||||||||||||||||||||||||||||||||||
| logger.info(f"Deleting {len(workers)} workers for role '{role}'") | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if reverse_order: | ||||||||||||||||||||||||||||||||||
| workers = list(reversed(workers)) | ||||||||||||||||||||||||||||||||||
| self._cleanup_workers(workers) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| del self._workers[role] | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| logger.info(f"Successfully deleted workers for role '{role}'") | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def _cleanup_workers(self, workers: list[WorkerInfo]): | ||||||||||||||||||||||||||||||||||
| """Tear down a batch of workers with coordinated teardown semantics. | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| The previous implementation iterated ``workers`` serially and called | ||||||||||||||||||||||||||||||||||
| ``kill_process_tree(..., timeout=3, graceful=True)`` on each one. | ||||||||||||||||||||||||||||||||||
| Because that helper blocks for up to ``timeout`` seconds between | ||||||||||||||||||||||||||||||||||
| SIGTERM and the fallback SIGKILL, a 4-rank job could spend ~12 s | ||||||||||||||||||||||||||||||||||
| killing workers one-by-one. During that window only a single rank | ||||||||||||||||||||||||||||||||||
| was executing its ``engine.destroy()`` path, so the CPU barrier | ||||||||||||||||||||||||||||||||||
| added in ``FSDPEngine.destroy()`` could never actually synchronise | ||||||||||||||||||||||||||||||||||
| -- every rank timed out on its barrier and the NCCL teardown race | ||||||||||||||||||||||||||||||||||
| that produced ``TCPStore.recvValue failed`` / HeartbeatMonitor | ||||||||||||||||||||||||||||||||||
| warnings was not fixed. | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| The corrected behaviour is: | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| 1. Release port allocations synchronously (cheap, no I/O). | ||||||||||||||||||||||||||||||||||
| 2. Send SIGTERM to every worker in the order provided by the | ||||||||||||||||||||||||||||||||||
| caller, with no blocking waits in between. ``delete_workers`` | ||||||||||||||||||||||||||||||||||
| passes the list in reverse rank order when | ||||||||||||||||||||||||||||||||||
| ``reverse_order=True``, which preserves the "rank-0 signalled | ||||||||||||||||||||||||||||||||||
| last" guarantee while keeping the dispatch window in the | ||||||||||||||||||||||||||||||||||
| millisecond range. | ||||||||||||||||||||||||||||||||||
| 3. Wait for every worker to exit in parallel using one thread per | ||||||||||||||||||||||||||||||||||
| worker. Each thread re-uses ``kill_process_tree`` so the | ||||||||||||||||||||||||||||||||||
| existing SIGTERM -> wait -> SIGKILL escalation is preserved | ||||||||||||||||||||||||||||||||||
| per-worker; we just no longer serialise the waits. | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| With this change every rank enters ``engine.destroy()`` within the | ||||||||||||||||||||||||||||||||||
| same small window, the CPU ``dist.barrier`` inside can actually | ||||||||||||||||||||||||||||||||||
| rendezvous, and the NCCL / TCPStore teardown becomes race-free. | ||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||
| import threading | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # Phase 1: always release ports, regardless of whether the worker | ||||||||||||||||||||||||||||||||||
| # owns a process (forked workers have ``process is None``). | ||||||||||||||||||||||||||||||||||
| live_workers: list[WorkerInfo] = [] | ||||||||||||||||||||||||||||||||||
| for worker_info in workers: | ||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||
| for port_str in worker_info.worker.worker_ports: | ||||||||||||||||||||||||||||||||||
| self._allocated_ports.discard(int(port_str)) | ||||||||||||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||||||||||||
| logger.error( | ||||||||||||||||||||||||||||||||||
| f"Error releasing ports for worker {worker_info.worker.id}: {e}", | ||||||||||||||||||||||||||||||||||
| exc_info=True, | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| if worker_info.process is not None: | ||||||||||||||||||||||||||||||||||
| live_workers.append(worker_info) | ||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| logger.debug(f"Cleaned up worker {worker_info.worker.id}") | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # Only kill process if we own it (non-forked workers) | ||||||||||||||||||||||||||||||||||
| if worker_info.process is not None: | ||||||||||||||||||||||||||||||||||
| kill_process_tree(worker_info.process.pid, timeout=3, graceful=True) | ||||||||||||||||||||||||||||||||||
| if not live_workers: | ||||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # Phase 2: dispatch SIGTERM to every worker concurrently via | ||||||||||||||||||||||||||||||||||
| # background threads so that all ranks reach their teardown | ||||||||||||||||||||||||||||||||||
| # barrier within the same window. The list order is preserved as | ||||||||||||||||||||||||||||||||||
| # thread start order: when the caller requests reverse_order, | ||||||||||||||||||||||||||||||||||
| # rank-0 is the last thread to be started, which keeps the | ||||||||||||||||||||||||||||||||||
| # "rank-0 dies last" property while staying non-blocking. | ||||||||||||||||||||||||||||||||||
| def _finalize(worker_info: WorkerInfo) -> None: | ||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||
| kill_process_tree(worker_info.process.pid, timeout=3, graceful=True) | ||||||||||||||||||||||||||||||||||
| logger.debug(f"Cleaned up worker {worker_info.worker.id}") | ||||||||||||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||||||||||||
| logger.error( | ||||||||||||||||||||||||||||||||||
| f"Error cleaning up worker {worker_info.worker.id}: {e}", | ||||||||||||||||||||||||||||||||||
| exc_info=True, | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| threads: list[threading.Thread] = [] | ||||||||||||||||||||||||||||||||||
| for worker_info in live_workers: | ||||||||||||||||||||||||||||||||||
| t = threading.Thread( | ||||||||||||||||||||||||||||||||||
| target=_finalize, | ||||||||||||||||||||||||||||||||||
| args=(worker_info,), | ||||||||||||||||||||||||||||||||||
| name=f"cleanup-{worker_info.worker.id}", | ||||||||||||||||||||||||||||||||||
| daemon=True, | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| t.start() | ||||||||||||||||||||||||||||||||||
| threads.append(t) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| # Phase 3: wait for every cleanup thread. Each ``kill_process_tree`` | ||||||||||||||||||||||||||||||||||
| # call internally waits up to ``timeout=3`` seconds for graceful | ||||||||||||||||||||||||||||||||||
| # shutdown and then SIGKILLs stragglers, so a small safety margin | ||||||||||||||||||||||||||||||||||
| # on ``join`` is sufficient. | ||||||||||||||||||||||||||||||||||
| join_timeout = 10.0 | ||||||||||||||||||||||||||||||||||
| for t in threads: | ||||||||||||||||||||||||||||||||||
| t.join(timeout=join_timeout) | ||||||||||||||||||||||||||||||||||
| if t.is_alive(): | ||||||||||||||||||||||||||||||||||
| logger.warning( | ||||||||||||||||||||||||||||||||||
| f"Cleanup thread {t.name} did not finish within " | ||||||||||||||||||||||||||||||||||
| f"{join_timeout}s; leaving it as daemon." | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
Comment on lines
+1226
to
+1233
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Joining threads sequentially with a fixed timeout per thread can lead to significant delays if multiple workers hang (e.g.,
Suggested change
|
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| def _read_log_tail(self, log_file: str, lines: int = 50) -> str: | ||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||
| with open(log_file) as f: | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The CPU barrier uses the default process group timeout (typically 30 minutes). If a rank has crashed or is unresponsive, this will cause all other ranks to hang for a long time during teardown. Consider using a shorter timeout for this specific synchronization point to improve robustness during failure scenarios.