Skip to content

Commit f0bbc2b

Browse files
authored
Improve consistency of timers and rates (#203)
Signed-off-by: Michel Hidalgo <mhidalgo@theaiinstitute.com>
1 parent 4dce9d6 commit f0bbc2b

5 files changed

Lines changed: 341 additions & 21 deletions

File tree

synchros2/synchros2/executors.py

Lines changed: 55 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import rclpy.callback_groups
1818
import rclpy.executors
1919
import rclpy.node
20+
import rclpy.timer
2021

2122
from synchros2.futures import FutureLike
2223
from synchros2.utilities import bind_to_thread, fqn
@@ -592,6 +593,7 @@ def __init__(
592593
max_thread_idle_time: typing.Optional[float] = None,
593594
max_threads_per_callback_group: typing.Optional[int] = None,
594595
*,
596+
num_threads_for_timers: typing.Optional[int] = None,
595597
context: typing.Optional[rclpy.context.Context] = None,
596598
logger: typing.Optional[logging.Logger] = None,
597599
) -> None:
@@ -607,24 +609,41 @@ def __init__(
607609
max_threads_per_callback_group: optional maximum number of concurrent callbacks the
608610
default thread pool should service for a given callback group. Useful to avoid
609611
reentrant callback groups from starving the default thread pool.
612+
num_threads_for_timers: optional number of threads to dedicate to timer callbacks.
613+
Defaults to 10% of all available threads, which may be 0 if there are less than
614+
10 threads, in which case timer callbacks will be serviced by the default thread pool.
610615
context: An optional instance of the ros context.
611616
logger: An optional logger instance.
612617
"""
613618
super().__init__(context=context)
614619
if logger is None:
615620
logger = rclpy.logging.get_logger(fqn(self.__class__))
621+
if max_threads is None:
622+
max_threads = 32 * (os.cpu_count() or 1)
623+
if num_threads_for_timers is None:
624+
num_threads_for_timers = max_threads // 10
625+
if num_threads_for_timers == 0:
626+
logger.warning("Not enough threads available, timers will be serviced by the default thread pool")
627+
max_threads -= num_threads_for_timers
616628
self._logger = logger
617629
self._is_shutdown = False
618630
self._spin_lock = threading.Lock()
619631
self._shutdown_lock = threading.RLock()
620-
self._thread_pools = [
621-
AutoScalingThreadPool(
622-
max_workers=max_threads,
623-
max_idle_time=max_thread_idle_time,
632+
self._default_thread_pool = AutoScalingThreadPool(
633+
max_workers=max_threads,
634+
max_idle_time=max_thread_idle_time,
635+
submission_quota=max_threads_per_callback_group,
636+
logger=self._logger,
637+
)
638+
self._timers_thread_pool: typing.Optional[AutoScalingThreadPool] = None
639+
if num_threads_for_timers != 0:
640+
self._timers_thread_pool = AutoScalingThreadPool(
641+
min_workers=num_threads_for_timers,
642+
max_workers=num_threads_for_timers,
624643
submission_quota=max_threads_per_callback_group,
625644
logger=self._logger,
626-
),
627-
]
645+
)
646+
self._static_thread_pools: typing.List[AutoScalingThreadPool] = []
628647
self._callback_group_affinity: weakref.WeakKeyDictionary[
629648
rclpy.callback_groups.CallbackGroup,
630649
AutoScalingThreadPool,
@@ -637,12 +656,21 @@ def __init__(
637656
@property
def default_thread_pool(self) -> AutoScalingThreadPool:
    """Autoscaling thread pool used by default."""
    default_pool = self._default_thread_pool
    return default_pool
660+
661+
@property
def timers_thread_pool(self) -> typing.Optional[AutoScalingThreadPool]:
    """Thread pool dedicated to timer callbacks, or None if there is none."""
    timers_pool = self._timers_thread_pool
    return timers_pool
641665

642666
@property
def thread_pools(self) -> typing.List[AutoScalingThreadPool]:
    """All thread pools in use.

    The default thread pool comes first, then the timers thread pool (if any),
    then static thread pools. A new list is built on every access.
    """
    builtin_pools = [self._default_thread_pool]
    if self._timers_thread_pool is not None:
        builtin_pools.append(self._timers_thread_pool)
    return [*builtin_pools, *self._static_thread_pools]
646674

647675
def add_static_thread_pool(self, num_threads: typing.Optional[int] = None) -> AutoScalingThreadPool:
648676
"""Add a thread pool that keeps a steady number of workers."""
@@ -653,8 +681,8 @@ def add_static_thread_pool(self, num_threads: typing.Optional[int] = None) -> Au
653681
max_workers=num_threads,
654682
logger=self._logger,
655683
)
656-
self._thread_pools.append(thread_pool)
657-
self._logger.debug(f"Added static thread pool #{len(self._thread_pools) - 1}")
684+
self._static_thread_pools.append(thread_pool)
685+
self._logger.debug(f"Added static thread pool #{len(self._static_thread_pools) - 1}")
658686
return thread_pool
659687

660688
def bind(self, callback_group: rclpy.callback_groups.CallbackGroup, thread_pool: AutoScalingThreadPool) -> None:
@@ -663,9 +691,13 @@ def bind(self, callback_group: rclpy.callback_groups.CallbackGroup, thread_pool:
663691
Thread pool must be known to the executor. That is, instantiated through add_*_thread_pool() methods.
664692
"""
665693
with self._shutdown_lock:
666-
if thread_pool not in self._thread_pools:
694+
if thread_pool not in self._static_thread_pools:
695+
if thread_pool is self._default_thread_pool:
696+
raise ValueError("cannot rebind to default thread pool")
697+
if thread_pool is self._timers_thread_pool:
698+
raise ValueError("cannot bind to timers thread pool")
667699
raise ValueError("thread pool unknown to executor")
668-
thread_pool_index = self._thread_pools.index(thread_pool)
700+
thread_pool_index = self._static_thread_pools.index(thread_pool)
669701
callback_group_name = f"{fqn(type(callback_group))}@{id(callback_group)}"
670702
self._logger.debug(f"Binding {callback_group_name} to thread pool #{thread_pool_index}...")
671703
self._callback_group_affinity[callback_group] = thread_pool
@@ -698,14 +730,16 @@ def _do_spin_once(self, *args: typing.Any, **kwargs: typing.Any) -> None:
698730
# dispatch and be missed. Fortunately, this will only delay dispatch until the
699731
# next spin cycle.
700732
if task not in self._work_in_progress or (self._work_in_progress[task].done() and not task.done()):
701-
if task.callback_group is not None:
702-
if task.callback_group not in self._callback_group_affinity:
703-
self._callback_group_affinity[task.callback_group] = self._thread_pools[0]
733+
if task.callback_group is not None and task.callback_group in self._callback_group_affinity:
704734
thread_pool = self._callback_group_affinity[task.callback_group]
735+
thread_pool_index = self._static_thread_pools.index(thread_pool)
736+
self._logger.debug(f"Task '{task}' submitted to static thread pool #{thread_pool_index}")
737+
elif self._timers_thread_pool is not None and isinstance(task.entity, rclpy.timer.Timer):
738+
thread_pool = self._timers_thread_pool
739+
self._logger.debug(f"Task '{task}' submitted to timers thread pool")
705740
else:
706-
thread_pool = self._thread_pools[0]
707-
thread_pool_index = self._thread_pools.index(thread_pool)
708-
self._logger.debug(f"Task '{task}' submitted to thread pool #{thread_pool_index}")
741+
thread_pool = self._default_thread_pool
742+
self._logger.debug(f"Task '{task}' submitted to default thread pool")
709743
self._work_in_progress[task] = thread_pool.submit(task)
710744
for task in list(self._work_in_progress):
711745
if not task.done():
@@ -781,10 +815,11 @@ def shutdown(self, timeout_sec: typing.Optional[float] = None) -> bool:
781815
# must be waited on. Work tracking in rclpy.executors.Executor
782816
# base implementation is subject to races, so block thread pool
783817
# submissions and wait for all futures to finish. Then shutdown.
784-
done = all(thread_pool.wait(timeout_sec) for thread_pool in self._thread_pools)
818+
819+
done = all(thread_pool.wait(timeout_sec) for thread_pool in self.thread_pools)
785820
if done:
786821
assert super().shutdown(timeout_sec=0)
787-
for thread_pool in self._thread_pools:
822+
for thread_pool in self.thread_pools:
788823
thread_pool.shutdown()
789824
self._is_shutdown = True
790825
if done:

synchros2/synchros2/node.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,22 @@
33
import functools
44
from typing import Any, Callable, Iterable, Optional, Type
55

6+
try:
7+
from typing import override # type: ignore[attr-defined]
8+
except ImportError:
9+
from typing_extensions import override # type: ignore[import]
10+
11+
612
from rclpy.callback_groups import CallbackGroup
13+
from rclpy.clock import Clock
714
from rclpy.exceptions import InvalidHandle
815
from rclpy.node import Node as BaseNode
16+
from rclpy.timer import Rate
917
from rclpy.waitable import Waitable
1018

1119
from synchros2.callback_groups import NonReentrantCallbackGroup
1220
from synchros2.logging import MemoizingRcutilsLogger
21+
from synchros2.time import SteadyRate
1322

1423

1524
def suppressed(exception: Type[BaseException], func: Callable) -> Callable:
@@ -55,6 +64,32 @@ def default_callback_group(self) -> CallbackGroup:
5564
# NOTE(hidmic): this overrides the hardcoded default group in rclpy.node.Node implementation
5665
return self._default_callback_group_override
5766

67+
@override
def create_rate(
    self,
    frequency: float,
    clock: Optional[Clock] = None,
) -> Rate:
    """Create a Rate object.

    The returned rate is a SteadyRate, which sleeps on the clock directly.

    :param frequency: The frequency the Rate runs at (Hz). Must be positive.
    :param clock: The clock the Rate gets time from. Defaults to the node clock.
    :return: a rate bound to the given (or the node) clock and to the node context.
    :raises ValueError: if ``frequency`` is not positive.
    """
    # NOTE: the overridden rclpy implementation rejects non-positive frequencies.
    # Keep that contract instead of failing later with a ZeroDivisionError (or
    # silently building a rate with a negative period).
    if frequency <= 0:
        raise ValueError("frequency must be > 0")
    if clock is None:
        clock = self.get_clock()
    return SteadyRate(frequency, clock, context=self._context)
81+
82+
@override
def destroy_rate(self, rate: Rate) -> bool:
    """Destroy a Rate object created by the node.

    Steady rates are destroyed in place. Any other rate is handed over
    to the base implementation.

    :param rate: the rate to destroy.
    :return: ``True`` if successful, ``False`` otherwise.
    """
    if not isinstance(rate, SteadyRate):
        return super().destroy_rate(rate)
    rate.destroy()
    return True
92+
5893
@property
5994
def waitables(self) -> Iterable[Waitable]:
6095
"""Get patched node waitables.

synchros2/synchros2/time.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,20 @@
11
# Copyright (c) 2024 Robotics and AI Institute LLC dba RAI Institute. All rights reserved.
22

3+
import threading
from datetime import datetime, timedelta
from typing import Optional, Union

try:
    from typing import override  # type: ignore[attr-defined]
except ImportError:
    from typing_extensions import override  # type: ignore[import]

from rclpy.clock import Clock
from rclpy.context import Context
from rclpy.duration import Duration
from rclpy.exceptions import ROSInterruptException
from rclpy.time import Time
from rclpy.timer import Rate
from rclpy.utilities import get_default_context
818

919

1020
def as_proper_time(time: Union[int, float, datetime, Time]) -> Time:
@@ -57,3 +67,66 @@ def as_proper_duration(duration: Union[int, float, timedelta, Duration]) -> Dura
5767
if not isinstance(duration, Duration):
5868
raise ValueError(f"unsupported duration type: {duration}")
5969
return duration
70+
71+
72+
class SteadyRate(Rate):
    """An rclpy.Rate equivalent that uses clock functionality directly, without timer overhead.

    Every sleep blocks on the clock until the current deadline. Once the last
    concurrent sleeper wakes up, the deadline is advanced by one period, or
    re-anchored to the current time if the wake up happened before the deadline
    or more than one period past it.
    """

    def __init__(self, frequency: float, clock: Clock, *, context: Optional[Context] = None) -> None:
        """Initializes the rate.

        Args:
            frequency: frequency of the rate, in Hz. Must be positive.
            clock: clock to get time from and to sleep on.
            context: optional ROS context, the default context is used if none is given.

        Raises:
            ValueError: if the frequency is not positive.
        """
        # NOTE: SteadyRate subclasses Rate for type consistency but does not use any of its functionality.
        # Thus, we skip the constructor call entirely.
        if frequency <= 0:
            # A zero frequency has no period and a negative one yields deadlines in the past.
            raise ValueError("frequency must be > 0")
        self._clock = clock
        if context is None:
            context = get_default_context()
        self._context = context
        self._period = as_proper_duration(1.0 / frequency)
        self._deadline = self._clock.now() + self._period

        # Guards the sleeper count and deadline updates in between sleeps.
        self._lock = threading.Lock()
        self._num_sleepers = 0

        self._is_shutdown = False
        self._is_destroyed = False
        self._context.on_shutdown(self._on_shutdown)

    @override
    def _on_shutdown(self) -> None:
        """Flags the rate as shutdown and destroys it."""
        self._is_shutdown = True
        self.destroy()

    @override
    def destroy(self) -> None:
        """Destroy the rate."""
        self._is_destroyed = True

    @override
    def _presleep(self) -> None:
        """Checks that the rate can sleep and registers a new sleeper.

        Raises:
            ROSInterruptException: if the rate has been shutdown.
            RuntimeError: if the rate has been destroyed.
        """
        if self._is_shutdown:
            raise ROSInterruptException()
        if self._is_destroyed:
            raise RuntimeError("SteadyRate cannot sleep because it has been destroyed")
        with self._lock:
            self._num_sleepers += 1

    @override
    def _postsleep(self) -> None:
        """Unregisters a sleeper and, if it was the last one, schedules the next deadline.

        Raises:
            ROSInterruptException: if the rate has been shutdown.
        """
        with self._lock:
            self._num_sleepers -= 1
            if self._num_sleepers == 0:
                # Only the last sleeper to wake up moves the deadline forward,
                # so that concurrent sleepers all share the same period.
                now = self._clock.now()
                next_deadline = self._deadline + self._period
                if now < self._deadline or now > next_deadline:
                    # Woke up early or overran a full period: start over from
                    # the current time instead of trying to catch up.
                    next_deadline = now + self._period
                self._deadline = next_deadline
        if self._is_shutdown:
            self.destroy()
            raise ROSInterruptException()

    @override
    def sleep(self) -> None:
        """Block until the current period is over.

        Raises:
            ROSInterruptException: if the rate has been shutdown.
            RuntimeError: if the rate has been destroyed.
        """
        self._presleep()
        try:
            self._clock.sleep_until(self._deadline, context=self._context)
        finally:
            # Always unregister the sleeper, even if sleeping was interrupted.
            self._postsleep()

synchros2/test/test_executors.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,45 @@ def deferred() -> bool:
241241
assert future.result()
242242

243243

244+
def test_autoscaling_executor_with_timers_thread_pool(ros_context: Context, ros_node: Node) -> None:
    """Asserts that timer callbacks and plain tasks are kept apart.

    Timer callbacks must be serviced by the (single threaded) timers thread
    pool, whereas any other work must be serviced by the default thread pool.
    """
    autoscaling_executor = AutoScalingMultiThreadedExecutor(
        context=ros_context,
        num_threads_for_timers=1,
        logger=logging.root,
    )
    with background(autoscaling_executor) as executor:
        assert executor.timers_thread_pool is not None
        executor.add_node(ros_node)

        timer_threads: List[threading.Thread] = []
        task_threads: List[threading.Thread] = []

        def timer_callback() -> None:
            timer_threads.append(threading.current_thread())

        def task_callback() -> None:
            task_threads.append(threading.current_thread())
            time.sleep(0.05)

        ros_node.create_timer(0.05, timer_callback, ReentrantCallbackGroup())

        num_tasks = 5
        for _ in range(num_tasks):
            executor.create_task(task_callback)

        # Give both the timer and the tasks a chance to run a few times.
        time.sleep(1.0)

        assert timer_threads
        assert task_threads
        timers_pool_thread = timer_threads[0]
        # All timer callbacks must run on the same single timers-pool thread
        assert all(thread is timers_pool_thread for thread in timer_threads)
        # Task callbacks must never have run on the timers-pool thread
        assert all(thread is not timers_pool_thread for thread in task_threads)
281+
282+
244283
@pytest.mark.filterwarnings("ignore")
245284
def test_background_executor_shows_errors(ros_context: Context, ros_node: Node) -> None:
246285
"""Asserts that an background executor does not swallow callback exceptions."""

0 commit comments

Comments
 (0)