11 changes: 10 additions & 1 deletion areal/api/scheduler_api.py
@@ -106,13 +106,22 @@ def get_workers(self, role: str, timeout: int | None = None) -> list[Worker]:
raise NotImplementedError()

@abc.abstractmethod
def delete_workers(self, role: str | None = None):
def delete_workers(self, role: str | None = None, reverse_order: bool = False):
"""Stop and clean up worker processes.

Parameters
----------
role : str, optional
Specific role to delete. If None, all workers are deleted.
reverse_order : bool, optional
If True, terminate workers in reverse order of their IDs so that
rank-0 (which typically owns the global TCPStore server) is the
last one to be killed. This helps avoid a noisy
``TCPStore.recvValue failed`` warning emitted by NCCL's
HeartbeatMonitor background thread on non-zero ranks during
teardown. Implementations that tear down all workers as a single
atomic operation (e.g. ``scancel`` for Slurm) may safely ignore
this argument. Defaults to False for backward compatibility.

Raises
------
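To illustrate the contract described in this docstring, here is a minimal sketch of a scheduler that honors reverse_order; the ReverseOrderScheduler class and its _procs bookkeeping are hypothetical names for this example only, not part of AReaL.

import os
import signal


class ReverseOrderScheduler:
    """Hypothetical scheduler keeping one (rank, pid) entry per worker."""

    def __init__(self) -> None:
        # (rank, pid) pairs in rank order; rank 0 comes first.
        self._procs: list[tuple[int, int]] = []

    def delete_workers(self, role: str | None = None, reverse_order: bool = False):
        procs = self._procs if not reverse_order else list(reversed(self._procs))
        # With reverse_order=True, rank 0 is signalled last, so the TCPStore
        # it hosts stays reachable until every other rank has torn down NCCL.
        for _rank, pid in procs:
            os.kill(pid, signal.SIGTERM)
        self._procs.clear()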
18 changes: 18 additions & 0 deletions areal/engine/fsdp_engine.py
@@ -419,7 +419,25 @@ def destroy(self):
# handles still exist and we expect another engine to
# clean up these groups.
if dist.is_initialized() and self.own_global_group:
# Pre-destroy synchronization on a CPU (gloo) group so that all
# ranks leave the NCCL collective phase together. Without this
# barrier, rank-0 (which owns the TCPStore server) may exit
# before peers finish their final NCCL abort, causing
# HeartbeatMonitor background threads on other ranks to observe
# "recvValue failed" on the already-closed store. This is
# harmless but produces a noisy stderr backtrace at teardown.
if getattr(self, "_cpu_group", None) is not None:
try:
dist.barrier(group=self._cpu_group)
Contributor (review comment, medium): The CPU barrier uses the default process group timeout (typically 30 minutes). If a rank has crashed or is unresponsive, this will cause all other ranks to hang for a long time during teardown. Consider using a shorter timeout for this specific synchronization point to improve robustness during failure scenarios.

except Exception as e: # pragma: no cover - best-effort
self.logger.warning(
f"pre-destroy CPU barrier failed (ignored): {e}"
)
dist.destroy_process_group()
# Make destroy() idempotent: if the controller calls destroy
# more than once (e.g. via cleanup hooks), the second call
# must not try to destroy already-destroyed groups.
self.own_global_group = False

@property
def initialized(self) -> bool:
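The reviewer's timeout concern could be addressed by building the CPU group with its own short timeout. The diff does not show where _cpu_group is created, so the following is only a sketch under that assumption: new_group with a gloo backend and monitored_barrier are standard torch.distributed APIs, and the 30-second budget is an arbitrary illustrative value.

from datetime import timedelta

import torch.distributed as dist


def make_cpu_group():
    # Assumes the default process group is already initialized. A dedicated
    # gloo group with a short timeout, so a teardown barrier cannot inherit
    # the 30-minute default.
    return dist.new_group(backend="gloo", timeout=timedelta(seconds=30))


def pre_destroy_barrier(cpu_group) -> None:
    # monitored_barrier is gloo-only and reports which ranks failed to
    # arrive instead of hanging for the full default timeout.
    try:
        dist.monitored_barrier(group=cpu_group, timeout=timedelta(seconds=30))
    except RuntimeError as e:
        print(f"pre-destroy CPU barrier failed (ignored): {e}")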
13 changes: 13 additions & 0 deletions areal/engine/megatron_engine.py
Collaborator comment: We may also want to fix areal/experimental/engine/archon_engine.py in addition to fsdp and megatron.

Contributor (author) reply: done

@@ -510,6 +510,19 @@ def destroy(self):
# handles still exist and we expect another engine to
# clean up these groups.
if dist.is_initialized() and self.own_global_group:
# Pre-destroy synchronization on a CPU (gloo) group so that all
# ranks leave the NCCL collective phase together. Without this
# barrier, rank-0 (which owns the TCPStore server) may exit
# before peers finish their final NCCL abort, causing
# HeartbeatMonitor background threads on other ranks to observe
# "recvValue failed" on the already-closed store.
if getattr(self, "_cpu_group", None) is not None:
try:
dist.barrier(group=self._cpu_group)
Contributor (review comment, medium): Similar to the FSDP engine, this CPU barrier uses the default process group timeout. An unresponsive rank could cause a significant hang during teardown. A shorter timeout for this barrier would be safer.

except Exception as e: # pragma: no cover - best-effort
self.logger.warning(
f"pre-destroy CPU barrier failed (ignored): {e}"
)
mpu.destroy_model_parallel()
dist.destroy_process_group()
self.own_global_group = False
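As a side note on ordering in the Megatron hunk above: model-, tensor-, and pipeline-parallel groups are sub-groups of the global process group, so they are destroyed before the global group and its store go away. A compact sketch of that sequence, assuming megatron.core is installed:

import torch.distributed as dist
from megatron.core import parallel_state as mpu


def teardown(own_global_group: bool) -> None:
    if dist.is_initialized() and own_global_group:
        # Sub-groups first, then the global group that backs them.
        mpu.destroy_model_parallel()
        dist.destroy_process_group()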
18 changes: 18 additions & 0 deletions areal/experimental/engine/archon_engine.py
@@ -428,7 +428,25 @@ def destroy(self):
gc.collect()

if dist.is_initialized() and self.own_global_group:
# Pre-destroy synchronization on a CPU (gloo) group so that all
# ranks leave the NCCL collective phase together. Without this
# barrier, rank-0 (which owns the TCPStore server) may exit
# before peers finish their final NCCL abort, causing
# HeartbeatMonitor background threads on other ranks to observe
# "recvValue failed" on the already-closed store. This is
# harmless but produces a noisy stderr backtrace at teardown.
if getattr(self, "_cpu_group", None) is not None:
try:
dist.barrier(group=self._cpu_group)
except Exception as e:
self.logger.warning(
f"pre-destroy CPU barrier failed (ignored): {e}"
)
dist.destroy_process_group()
# Make destroy() idempotent: if the controller calls destroy
# more than once (e.g. via cleanup hooks), the second call
# must not try to destroy already-destroyed groups.
self.own_global_group = False
self._initialized = False

def train(self, mode: bool = True):
36 changes: 30 additions & 6 deletions areal/infra/controller/train_controller.py
@@ -404,6 +404,19 @@ def destroy(self):
"""Destroy the controller and release GPU memory of models.

Cleans up all resources including workers, engines, and internal state.

The teardown order is carefully chosen to avoid a noisy
``TCPStore.recvValue failed`` warning from NCCL's HeartbeatMonitor
on non-zero ranks:

1. Remote engines' ``destroy()`` runs first so that every rank calls
``dist.destroy_process_group()`` after a CPU barrier. This
guarantees all ranks finish NCCL abort together before any store
shuts down.
2. Workers are killed in reverse rank order so that rank-0 (owner
of the global TCPStore server) receives SIGTERM last. This
avoids the short window where non-zero ranks' HeartbeatMonitor
threads poll a store whose TCP listener has already been closed.
"""
logger.info("Destroying TrainController...")

@@ -421,17 +434,28 @@ async def _destroy_all_engines():
)
for rank, worker in enumerate(self.workers)
]
await asyncio.gather(*tasks, return_exceptions=True)

run_async_task(_destroy_all_engines)
return await asyncio.gather(*tasks, return_exceptions=True)

results = run_async_task(_destroy_all_engines)
# Surface per-worker failures instead of silently swallowing them.
for rank, res in enumerate(results or []):
if isinstance(res, BaseException):
logger.warning(
f"Engine destroy on rank {rank} raised "
f"{type(res).__name__}: {res}"
)
logger.info("Engines destroyed")
except Exception as e:
logger.error(f"Error destroying engines: {e}")

# Then delete workers via scheduler
# Then delete workers via scheduler. Pass reverse_order=True so
# that rank-0 (TCPStore owner) is killed last. All in-tree
# Scheduler implementations (Local/Ray/Slurm) accept this kwarg;
# third-party subclasses that override ``delete_workers`` must
# adopt the same signature.
try:
logger.info("Deleting all workers...")
self.scheduler.delete_workers(role=self._worker_role)
logger.info("Deleting all workers (reverse rank order)...")
self.scheduler.delete_workers(role=self._worker_role, reverse_order=True)
logger.info("Workers deleted")
except Exception as e:
logger.error(f"Error deleting workers: {e}")
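For background on why rank-0 is called the TCPStore owner: with the usual env:// rendezvous, torch.distributed hosts the store on rank 0 and every other rank connects as a client, so killing rank 0 first closes the listener that the remaining ranks' heartbeat threads still poll. A minimal sketch of that topology; the host and port values are placeholders, not AReaL configuration.

from datetime import timedelta

import torch.distributed as dist


def make_store(rank: int, world_size: int) -> dist.TCPStore:
    # Rank 0 hosts the TCP listener; all other ranks connect to it.
    return dist.TCPStore(
        host_name="127.0.0.1",
        port=29500,
        world_size=world_size,
        is_master=(rank == 0),
        timeout=timedelta(seconds=60),
    )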
97 changes: 91 additions & 6 deletions areal/infra/scheduler/local.py
@@ -1082,23 +1082,27 @@ def _check_worker_health(self, role: str):
stderr,
)

def delete_workers(self, role: str | None = None):
def delete_workers(self, role: str | None = None, reverse_order: bool = False):
"""Delete workers and clean up resources.

Parameters
----------
role : str, optional
Specific worker role to delete, or None to delete all
reverse_order : bool, optional
If True, terminate workers in reverse rank order so that rank-0
(owner of the global TCPStore) is signalled last. See
``Scheduler.delete_workers`` for background.
"""
if role is None:
# Delete colocated roles first (they don't own processes)
colocated_roles = list(self._colocated_roles.keys())
for r in colocated_roles:
self.delete_workers(r)
self.delete_workers(r, reverse_order=reverse_order)
Contributor (review comment, high): When reverse_order is True, the roles themselves should also be iterated in reverse order (e.g., reversed(colocated_roles)). This ensures that roles created earlier, which typically contain the global TCPStore owner (rank-0), are processed last, maintaining the teardown guarantee across the entire cluster rather than just within each role.

# Then delete actual worker roles
roles = list(self._workers.keys())
for r in roles:
self.delete_workers(r)
self.delete_workers(r, reverse_order=reverse_order)
Contributor (review comment, high): Similar to colocated roles, the actual worker roles should be iterated in reverse order when reverse_order is True to ensure the role containing rank-0 is the last one to be terminated.
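A small sketch of what the two review comments above are asking for: reverse the role iteration itself when reverse_order is set, so the role created first (which typically holds rank-0 and the global TCPStore) is torn down last. The helper name is invented for illustration; the merged code simply iterates the role dicts directly.

def ordered_roles(colocated: list[str], owned: list[str], reverse_order: bool) -> list[str]:
    # Colocated roles are always handled before the roles that own processes.
    # With reverse_order, each list is reversed so the earliest-created role
    # (usually containing rank-0) comes last.
    if reverse_order:
        colocated = list(reversed(colocated))
        owned = list(reversed(owned))
    return colocated + owned

A caller would then loop over ordered_roles(...) and invoke delete_workers(r, reverse_order=reverse_order) for each role.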

return

# Handle colocated/forked role
@@ -1107,6 +1111,8 @@ def delete_workers(self, role: str | None = None):
if role in self._workers:
logger.info(f"Removing forked role '{role}' (managed by parent worker)")
workers = self._workers[role]
if reverse_order:
workers = list(reversed(workers))
self._cleanup_workers(
workers
) # Release ports, but process=None skips kill
@@ -1124,29 +1130,108 @@ def delete_workers(self, role: str | None = None):
workers = self._workers[role]
logger.info(f"Deleting {len(workers)} workers for role '{role}'")

if reverse_order:
workers = list(reversed(workers))
self._cleanup_workers(workers)

del self._workers[role]

logger.info(f"Successfully deleted workers for role '{role}'")

def _cleanup_workers(self, workers: list[WorkerInfo]):
"""Tear down a batch of workers with coordinated teardown semantics.

The previous implementation iterated ``workers`` serially and called
``kill_process_tree(..., timeout=3, graceful=True)`` on each one.
Because that helper blocks for up to ``timeout`` seconds between
SIGTERM and the fallback SIGKILL, a 4-rank job could spend ~12 s
killing workers one-by-one. During that window only a single rank
was executing its ``engine.destroy()`` path, so the CPU barrier
added in ``FSDPEngine.destroy()`` could never actually synchronise
-- every rank timed out on its barrier and the NCCL teardown race
that produced ``TCPStore.recvValue failed`` / HeartbeatMonitor
warnings was not fixed.

The corrected behaviour is:

1. Release port allocations synchronously (cheap, no I/O).
2. Send SIGTERM to every worker in the order provided by the
caller, with no blocking waits in between. ``delete_workers``
passes the list in reverse rank order when
``reverse_order=True``, which preserves the "rank-0 signalled
last" guarantee while keeping the dispatch window in the
millisecond range.
3. Wait for every worker to exit in parallel using one thread per
worker. Each thread re-uses ``kill_process_tree`` so the
existing SIGTERM -> wait -> SIGKILL escalation is preserved
per-worker; we just no longer serialise the waits.

With this change every rank enters ``engine.destroy()`` within the
same small window, the CPU ``dist.barrier`` inside can actually
rendezvous, and the NCCL / TCPStore teardown becomes race-free.
"""
import threading

# Phase 1: always release ports, regardless of whether the worker
# owns a process (forked workers have ``process is None``).
live_workers: list[WorkerInfo] = []
for worker_info in workers:
try:
for port_str in worker_info.worker.worker_ports:
self._allocated_ports.discard(int(port_str))
except Exception as e:
logger.error(
f"Error releasing ports for worker {worker_info.worker.id}: {e}",
exc_info=True,
)
if worker_info.process is not None:
live_workers.append(worker_info)
else:
logger.debug(f"Cleaned up worker {worker_info.worker.id}")

# Only kill process if we own it (non-forked workers)
if worker_info.process is not None:
kill_process_tree(worker_info.process.pid, timeout=3, graceful=True)
if not live_workers:
return

# Phase 2: dispatch SIGTERM to every worker concurrently via
# background threads so that all ranks reach their teardown
# barrier within the same window. The list order is preserved as
# thread start order: when the caller requests reverse_order,
# rank-0 is the last thread to be started, which keeps the
# "rank-0 dies last" property while staying non-blocking.
def _finalize(worker_info: WorkerInfo) -> None:
try:
kill_process_tree(worker_info.process.pid, timeout=3, graceful=True)
logger.debug(f"Cleaned up worker {worker_info.worker.id}")
except Exception as e:
logger.error(
f"Error cleaning up worker {worker_info.worker.id}: {e}",
exc_info=True,
)

threads: list[threading.Thread] = []
for worker_info in live_workers:
t = threading.Thread(
target=_finalize,
args=(worker_info,),
name=f"cleanup-{worker_info.worker.id}",
daemon=True,
)
t.start()
threads.append(t)

# Phase 3: wait for every cleanup thread. Each ``kill_process_tree``
# call internally waits up to ``timeout=3`` seconds for graceful
# shutdown and then SIGKILLs stragglers, so a small safety margin
# on ``join`` is sufficient.
join_timeout = 10.0
for t in threads:
t.join(timeout=join_timeout)
if t.is_alive():
logger.warning(
f"Cleanup thread {t.name} did not finish within "
f"{join_timeout}s; leaving it as daemon."
)
Contributor (review comment on lines +1226 to +1233, medium): Joining threads sequentially with a fixed timeout per thread can lead to significant delays if multiple workers hang (e.g., N × 10 seconds). Since the cleanup threads are already running in parallel and kill_process_tree has its own internal timeout, it is more efficient to use a global deadline for joining all threads.

Suggested change (replace the per-thread join loop above with a single global deadline):
join_deadline = time.time() + 10.0
for t in threads:
t.join(timeout=max(0, join_deadline - time.time()))
if t.is_alive():
logger.warning(
f"Cleanup thread {t.name} did not finish within "
f"the timeout; leaving it as daemon."
)
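An alternative way to get the same global-deadline behaviour, sketched with concurrent.futures instead of manual threads. kill_process_tree and the WorkerInfo attributes (process.pid, worker.id) are the project's existing names; everything else here is illustrative.

from concurrent.futures import ThreadPoolExecutor, wait


def cleanup_in_parallel(live_workers, kill_process_tree, logger, deadline_s: float = 10.0) -> None:
    # Sketch: terminate every worker concurrently, then apply one deadline
    # to the whole batch instead of up to 10 s per straggler.
    if not live_workers:
        return
    pool = ThreadPoolExecutor(max_workers=len(live_workers), thread_name_prefix="cleanup")
    futures = {
        pool.submit(kill_process_tree, w.process.pid, timeout=3, graceful=True): w
        for w in live_workers
    }
    _done, not_done = wait(futures, timeout=deadline_s)
    for fut in not_done:
        logger.warning(
            f"Cleanup of worker {futures[fut].worker.id} did not finish "
            f"within {deadline_s}s; leaving it running."
        )
    # Do not block on stragglers; their threads keep running in the pool.
    pool.shutdown(wait=False)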


def _read_log_tail(self, log_file: str, lines: int = 50) -> str:
try:
with open(log_file) as f:
Expand Down