Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 160 additions & 36 deletions rclpy/rclpy/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,13 @@

from collections import deque
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from contextlib import ExitStack
from dataclasses import dataclass
from functools import partial
import inspect
import os
from threading import Condition
from threading import Lock
from threading import RLock
import threading
import time
from types import TracebackType
from typing import Any
class _WorkTracker:
    """
    Track which threads are currently executing callbacks.

    Work is recorded per thread through :meth:`track_callback`, and
    :meth:`wait` blocks until every thread with in-flight work is either
    done or is itself inside (or past) a successful :meth:`wait`.
    """

    def __init__(self) -> None:
        # Guards both collections below; notified whenever either changes.
        self._work_condition = threading.Condition()
        # Per-thread reentrant counter of in-flight work. A thread has an
        # entry iff it is currently inside a track_callback() context; the
        # value is the reentrance depth. The set of keys is the set of
        # threads currently running a callback.
        self._executing_thread_counts: Dict[threading.Thread, int] = {}
        # Threads whose in-flight callback (if any) is committed to
        # finishing rather than making further progress -- i.e., the
        # thread is parked in wait() or has already returned successfully
        # from wait() and is finishing the surrounding shutdown. Their
        # callback work should not block other waiters' drain checks;
        # otherwise two callbacks on different worker threads both
        # calling Executor.shutdown() would deadlock on each other.
        # An entry is removed when the owning callback ends (exit of the
        # outermost track_callback() context), when wait() times out, or,
        # for external callers with no in-flight callback, immediately
        # when their wait() returns.
        self._waiting_threads: Set[threading.Thread] = set()

    @contextmanager
    def track_callback(self) -> Generator[None, None, None]:
        """
        Track an in-flight callback for the duration of the context.

        The owning thread is captured at enter time and used to decrement
        the per-thread count when the context exits -- even if the exit
        runs on a different thread. That happens when a coroutine using
        this context manager is suspended at an inner ``await`` and then
        closed via GC (e.g. during executor teardown): ``coro.close()``
        raises ``GeneratorExit`` at the suspension point on whatever
        thread the GC happened to run on, and the ``with`` block's
        unwinding -- including this finally -- runs on that thread, not
        the original worker thread. Looking up
        ``threading.current_thread()`` in the finally instead would
        target the wrong thread's entry: the owner's count would never be
        decremented, and the lookup raises ``KeyError`` when the exiting
        thread has no entry of its own.
        """
        owner = threading.current_thread()
        with self._work_condition:
            self._executing_thread_counts[owner] = (
                self._executing_thread_counts.get(owner, 0) + 1)
        try:
            yield
        finally:
            with self._work_condition:
                count = self._executing_thread_counts[owner] - 1
                if count == 0:
                    del self._executing_thread_counts[owner]
                    # The thread's callback has ended, so it's no longer
                    # "committed to finishing" -- drop its waiter
                    # membership.
                    self._waiting_threads.discard(owner)
                else:
                    self._executing_thread_counts[owner] = count
                # Wake waiters so they re-evaluate their drain predicate.
                self._work_condition.notify_all()

    def wait(self, timeout_sec: Optional[float] = None) -> bool:
        """
        Wait until all work completes.

        Work being executed by the calling thread is excluded from the wait,
        since that work is necessarily blocked on this call returning. Work
        being executed by any other thread that has itself entered wait() is
        also excluded, so concurrent shutdown() calls from inside callbacks
        on different worker threads don't deadlock waiting for each other.

        :param timeout_sec: Seconds to wait. Block forever if None or negative. Don't wait if 0
        :type timeout_sec: float or None
        :rtype: bool True if all work completed
        """
        if timeout_sec is not None and timeout_sec < 0.0:
            timeout_sec = None
        current = threading.current_thread()

        def other_work_drained() -> bool:
            # True once every thread with in-flight work is itself in the
            # waiting set, i.e., committed to finishing rather than making
            # progress on its callback.
            return self._executing_thread_counts.keys() <= self._waiting_threads

        with self._work_condition:
            added_self = current not in self._waiting_threads
            if added_self:
                self._waiting_threads.add(current)
                # A new waiter may have just satisfied an existing
                # waiter's condition (its in-flight work is now excluded).
                self._work_condition.notify_all()
            drained = False
            try:
                drained = self._work_condition.wait_for(other_work_drained, timeout_sec)
            finally:
                # On success, keep the waiter membership while a callback
                # is still in flight on this thread -- removing it now
                # would let other concurrent waiters' predicates flip back
                # to False and re-block until our callback ends, even
                # though our callback is committed to finishing (we are
                # past wait() and the rest of shutdown is cleanup). The
                # exit of track_callback() drops the membership when the
                # callback ends.
                #
                # Discard here in the remaining cases:
                # - external callers with no in-flight callback, for whom
                #   no track_callback() exit will ever run;
                # - a wait that timed out (or raised): the caller is not
                #   committed to finishing, so its in-flight work must
                #   keep counting against other waiters.
                if added_self and (
                        not drained or current not in self._executing_thread_counts):
                    self._waiting_threads.discard(current)
                    self._work_condition.notify_all()
            return drained


Expand Down Expand Up @@ -212,20 +281,20 @@ def __init__(self, *, context: Optional[Context] = None) -> None:
super().__init__()
self._context = get_default_context() if context is None else context
self._nodes: Set[Node] = set()
self._nodes_lock = RLock()
self._nodes_lock = threading.RLock()
# all tasks that are not complete or canceled
self._pending_tasks: Dict[Task[Any], TaskData] = {}
# tasks that are ready to execute
self._ready_tasks: Deque[Task[Any]] = deque()
self._tasks_lock = Lock()
self._tasks_lock = threading.Lock()
# This is triggered when wait_for_ready_callbacks should rebuild the wait list
self._guard: Optional[GuardCondition] = GuardCondition(
callback=None, callback_group=None, context=self._context)
# True if shutdown has been called
self._is_shutdown = False
self._work_tracker = _WorkTracker()
# Protect against shutdown() being called in parallel in two threads
self._shutdown_lock = Lock()
self._shutdown_lock = threading.Lock()
# State for wait_for_ready_callbacks to reuse generator
self._cb_iter: Optional[YieldedCallback] = None
self._last_args: Optional[tuple[object, ...]] = None
Expand All @@ -238,24 +307,28 @@ def __init__(self, *, context: Optional[Context] = None) -> None:
# True when the executor is spinning
self._is_spinning = False
# Protects access to _is_spinning
self._is_spinning_lock = Lock()
self._is_spinning_cond = threading.Condition()
self._spinning_thread: Optional[threading.Thread] = None

def _enter_spin(self) -> None:
    """
    Mark the executor as spinning and prevent concurrent spins.

    The calling thread is recorded as the spinning thread.

    :raises RuntimeError: if the executor is already marked as spinning.
    """
    # The flag and the owning thread are updated together under the
    # condition's lock so readers never observe one without the other.
    with self._is_spinning_cond:
        if self._is_spinning:
            raise RuntimeError('Executor is already spinning')
        self._is_spinning = True
        self._spinning_thread = threading.current_thread()

def _exit_spin(self) -> None:
    """
    Clear the spinning flag and wake anyone waiting for the spin to end.

    Also forgets which thread was spinning.
    """
    with self._is_spinning_cond:
        self._is_spinning = False
        self._spinning_thread = None
        # Wake threads blocked on this condition waiting for spinning to stop.
        self._is_spinning_cond.notify_all()

@property
def is_spinning(self) -> bool:
    """Return whether the executor is currently spinning."""
    # Read under the condition's lock, the same lock the flag is written under.
    with self._is_spinning_cond:
        return self._is_spinning

@property
Expand Down Expand Up @@ -315,19 +388,55 @@ def shutdown(self, timeout_sec: Optional[float] = None) -> bool:
timeout expires before all outstanding work is done.
"""
with self._shutdown_lock:
if not self._is_shutdown:
initiated_shutdown = not self._is_shutdown
if initiated_shutdown:
self._is_shutdown = True
# Tell executor it's been shut down
if self._guard:
self._guard.trigger()
if not self._is_shutdown:
if not self._work_tracker.wait(timeout_sec):
return False
# The timeout applies to the whole shutdown operation — both the
# callback drain and the spinner exit — not to each wait
# individually. Convert it into a deadline once; each wait below
# gets only the time remaining against that deadline.
if timeout_sec is None or timeout_sec < 0:
deadline: Optional[float] = None # block forever
else:
deadline = time.monotonic() + timeout_sec

def remaining_timeout() -> Optional[float]:
if deadline is None:
return None
return max(0.0, deadline - time.monotonic())

# Wait for any in-flight callbacks on OTHER threads to drain. Done
# unconditionally (not just for the initiating call) so that:
# - concurrent shutdown() calls don't race past the wait and start
# destroying state while callbacks are still running, and
# - a caller who got False back from a timed-out shutdown() can
# simply call shutdown() again (with a longer or no timeout) and
# have the second call actually wait + finish cleanup.
# _work_tracker.wait excludes work being executed by the calling
# thread, so this is safe from inside a callback — it will not
# self-deadlock.
if not self._work_tracker.wait(remaining_timeout()):
return False

# Clean up stuff that won't be used anymore
with self._nodes_lock:
self._nodes = set()

with self._is_spinning_cond:
if self._spinning_thread is not threading.current_thread():
# Wait for the spin thread to acknowledge shutdown and
# exit before we destroy the guards (which the spinner
# may still be holding in its wait_set). If the wait
# times out, return False per the contract — don't
# destroy resources that the spinner might still touch.
if not self._is_spinning_cond.wait_for(
lambda: not self._is_spinning,
timeout=remaining_timeout()):
return False

with self._shutdown_lock:
if self._guard:
self._guard.destroy()
Expand Down Expand Up @@ -668,7 +777,7 @@ async def handler(entity: 'EntityT', gc: GuardCondition, is_shutdown: bool,
entity._executor_event = False
gc.trigger()
return
with work_tracker:
with work_tracker.track_callback():
# The take_from_wait_list method here is expected to return either an async def
# method or None if there is no work to do.
call_coroutine = take_from_wait_list(entity)
Expand Down Expand Up @@ -1087,7 +1196,7 @@ def __init__(
'Use the SingleThreadedExecutor instead.')
self._futures: List[Future[Any]] = []
self._executor = ThreadPoolExecutor(num_threads)
self._futures_lock = Lock()
self._futures_lock = threading.Lock()

def _spin_once_impl(
self,
Expand Down Expand Up @@ -1157,10 +1266,25 @@ def shutdown(
:param timeout_sec: Seconds to wait. Block forever if ``None`` or negative.
Don't wait if 0.
:param wait_for_threads: If true, this function will block until all executor threads
have joined.
have joined. When shutdown() is called from inside a callback running on one of
this executor's worker threads, the *current* thread is necessarily excluded from
that join (Python cannot join a thread with itself) -- the rest of the callback
will finish after this returns and the worker will exit naturally.
:return: ``True`` if all outstanding callbacks finished executing, or ``False`` if the
timeout expires before all outstanding work is done.
"""
success: bool = super().shutdown(timeout_sec)
self._executor.shutdown(wait=wait_for_threads)
# Always tell the pool to shut down without waiting: if shutdown()
# was called from inside a callback running on one of these
# workers, letting ThreadPoolExecutor.shutdown(wait=True) join the
# current thread would raise RuntimeError. We do the joins below
# ourselves so we can skip the current thread.
self._executor.shutdown(wait=False)
if wait_for_threads:
current = threading.current_thread()
# Snapshot before iterating; _threads is stable post-shutdown
# (no new workers are spawned) but we copy defensively.
for t in list(self._executor._threads):
if t is not current:
t.join()
return success
Loading