|
1 | 1 | # Copyright (c) 2023 Robotics and AI Institute LLC dba RAI Institute. All rights reserved. |
2 | | -from threading import Event |
3 | | -from typing import Any, Awaitable, Callable, Optional, Protocol, TypeVar, Union, runtime_checkable |
| 2 | +import threading |
| 3 | +from concurrent.futures import ALL_COMPLETED, FIRST_COMPLETED, FIRST_EXCEPTION |
| 4 | +from typing import ( |
| 5 | + Any, |
| 6 | + Awaitable, |
| 7 | + Callable, |
| 8 | + Iterable, |
| 9 | + Iterator, |
| 10 | + NamedTuple, |
| 11 | + Optional, |
| 12 | + Protocol, |
| 13 | + Set, |
| 14 | + TypeVar, |
| 15 | + Union, |
| 16 | + cast, |
| 17 | + overload, |
| 18 | + runtime_checkable, |
| 19 | +) |
4 | 20 |
|
5 | 21 | from rclpy.clock import Clock |
6 | 22 | from rclpy.context import Context |
| 23 | +from rclpy.duration import Duration |
7 | 24 | from rclpy.utilities import get_default_context |
8 | 25 |
|
9 | 26 | from synchros2.clock import wait_for |
10 | 27 |
|
11 | 28 | T = TypeVar("T", covariant=True) |
12 | 29 |
|
13 | 30 |
|
| 31 | +@runtime_checkable |
14 | 32 | class FutureLike(Awaitable[T], Protocol[T]): |
15 | 33 | """A future-like awaitable object. |
16 | 34 |
|
@@ -61,58 +79,266 @@ def as_proper_future(instance: AnyFuture) -> FutureLike: |
61 | 79 | return instance |
62 | 80 |
|
63 | 81 |
|
| 82 | +class WaitResult(NamedTuple): |
| 83 | + """Result of waiting for multiple futures. |
| 84 | +
|
| 85 | + A named tuple with 'done' and 'not_done' sets of futures. |
| 86 | + """ |
| 87 | + |
| 88 | + ok: bool |
| 89 | + done: Set[FutureLike] |
| 90 | + not_done: Set[FutureLike] |
| 91 | + |
| 92 | + def __bool__(self) -> bool: |
| 93 | + """Equivalent to result.ok.""" |
| 94 | + return self.ok |
| 95 | + |
| 96 | + |
| 97 | +@overload |
64 | 98 | def wait_for_future( |
65 | 99 | future: AnyFuture, |
66 | 100 | timeout_sec: Optional[float] = None, |
67 | 101 | *, |
68 | 102 | clock: Optional[Clock] = None, |
69 | 103 | context: Optional[Context] = None, |
70 | | -) -> bool: |
71 | | - """Block while waiting for a future to become done |
| 104 | +) -> WaitResult: |
| 105 | + ... |
| 106 | + |
| 107 | + |
| 108 | +@overload |
| 109 | +def wait_for_future( |
| 110 | + future: Iterable[AnyFuture], |
| 111 | + timeout_sec: Optional[float] = None, |
| 112 | + *, |
| 113 | + return_when: str = ALL_COMPLETED, |
| 114 | + clock: Optional[Clock] = None, |
| 115 | + context: Optional[Context] = None, |
| 116 | +) -> WaitResult: |
| 117 | + ... |
| 118 | + |
| 119 | + |
| 120 | +def wait_for_future( |
| 121 | + future: Union[AnyFuture, Iterable[AnyFuture]], |
| 122 | + timeout_sec: Optional[float] = None, |
| 123 | + *, |
| 124 | + clock: Optional[Clock] = None, |
| 125 | + context: Optional[Context] = None, |
| 126 | + return_when: str = ALL_COMPLETED, |
| 127 | +) -> WaitResult: |
| 128 | + """Block while waiting for future(s) to become done. |
72 | 129 |
|
73 | 130 | Args: |
74 | | - future (Future): The future to be waited on |
75 | | - timeout_sec (Optional[float]): An optional timeout for how long to wait |
76 | | - clock (Optional[Clock]): An optional clock to use for timeout waits, |
77 | | - defaults to the clock of the current scope if any, otherwise the system clock |
78 | | - context (Optional[Context]): Current context (will use the default if none is given) |
| 131 | + future: A single future or an iterable of futures to wait on |
| 132 | + timeout_sec: An optional timeout for how long to wait |
| 133 | + clock: An optional clock to use for timeout waits, |
| 134 | + defaults to the clock of the current scope if any, otherwise the system clock |
| 135 | + context: Current context (will use the default if none is given) |
| 136 | + return_when: One of FIRST_COMPLETED, FIRST_EXCEPTION, or ALL_COMPLETED. |
| 137 | + Only applies when waiting for multiple futures. Defaults to ALL_COMPLETED. |
79 | 138 |
|
80 | 139 | Returns: |
81 | | - bool: True if successful, False if the timeout was triggered |
| 140 | + A result object indicating which futures are done and which are not, |
| 141 | + and whether the wait was successful (i.e. not timed out). |
| 142 | +
|
| 143 | + Examples: |
| 144 | + Single future: |
| 145 | + >>> result = wait_for_future(my_future, timeout_sec=5.0) |
| 146 | + >>> if result: |
| 147 | + ... value = my_future.result() |
| 148 | +
|
| 149 | + Multiple futures: |
| 150 | + >>> result = wait_for_future([f1, f2, f3], return_when=FIRST_COMPLETED) |
| 151 | + >>> for future in result.done: |
| 152 | + ... print(future.result()) |
82 | 153 | """ |
| 154 | + if return_when not in {FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED}: |
| 155 | + raise ValueError(f"Invalid return_when value: {return_when}") |
| 156 | + |
83 | 157 | if context is None: |
84 | 158 | context = get_default_context() |
| 159 | + |
85 | 160 | if clock is None: |
86 | 161 | import synchros2.scope |
87 | 162 |
|
88 | 163 | clock = synchros2.scope.clock() |
89 | | - event = Event() |
| 164 | + |
| 165 | + done_futures: Set[FutureLike] = set() |
| 166 | + if not isinstance(future, (FutureConvertible, FutureLike)): |
| 167 | + pending_futures = {as_proper_future(f) for f in future} |
| 168 | + else: |
| 169 | + pending_futures = {as_proper_future(future)} |
| 170 | + |
| 171 | + if not pending_futures: |
| 172 | + return WaitResult(ok=True, done=set(), not_done=set()) |
| 173 | + |
| 174 | + lock = threading.Lock() |
| 175 | + event = threading.Event() |
| 176 | + |
| 177 | + def _done_callback(future: FutureLike) -> None: |
| 178 | + with lock: |
| 179 | + if future in pending_futures: |
| 180 | + pending_futures.remove(future) |
| 181 | + done_futures.add(future) |
| 182 | + |
| 183 | + should_return = False |
| 184 | + if return_when == FIRST_COMPLETED: |
| 185 | + should_return = True |
| 186 | + elif return_when == FIRST_EXCEPTION: |
| 187 | + exception_occurred = future.exception() is not None |
| 188 | + should_return = exception_occurred or not pending_futures |
| 189 | + elif return_when == ALL_COMPLETED: |
| 190 | + should_return = not pending_futures |
| 191 | + |
| 192 | + if should_return: |
| 193 | + event.set() |
| 194 | + |
90 | 195 | context.on_shutdown(event.set) |
91 | | - proper_future = as_proper_future(future) |
92 | | - proper_future.add_done_callback(lambda _: event.set()) |
93 | | - if proper_future.cancelled(): |
94 | | - event.set() |
95 | | - wait_for(event, clock=clock, timeout_sec=timeout_sec) |
96 | | - return proper_future.done() |
| 196 | + for future in list(pending_futures): |
| 197 | + future.add_done_callback(_done_callback) |
| 198 | + if future.cancelled(): |
| 199 | + _done_callback(future) |
| 200 | + |
| 201 | + if not event.is_set(): |
| 202 | + wait_for(event, clock=clock, timeout_sec=timeout_sec) |
97 | 203 |
|
| 204 | + with lock: |
| 205 | + return WaitResult(ok=event.is_set(), done=done_futures.copy(), not_done=pending_futures.copy()) |
98 | 206 |
|
| 207 | + |
| 208 | +@overload |
99 | 209 | def unwrap_future( |
100 | 210 | future: AnyFuture, |
101 | 211 | timeout_sec: Optional[float] = None, |
102 | 212 | *, |
103 | 213 | clock: Optional[Clock] = None, |
104 | 214 | context: Optional[Context] = None, |
105 | 215 | ) -> Any: |
106 | | - """Fetch future result when it is done. |
| 216 | + ... |
| 217 | + |
| 218 | + |
| 219 | +@overload |
| 220 | +def unwrap_future( |
| 221 | + future: Iterable[AnyFuture], |
| 222 | + timeout_sec: Optional[float] = None, |
| 223 | + *, |
| 224 | + clock: Optional[Clock] = None, |
| 225 | + context: Optional[Context] = None, |
| 226 | + strict: bool = False, |
| 227 | +) -> Iterator[Any]: |
| 228 | + ... |
| 229 | + |
| 230 | + |
| 231 | +def unwrap_future( |
| 232 | + future: Union[AnyFuture, Iterable[AnyFuture]], |
| 233 | + timeout_sec: Optional[float] = None, |
| 234 | + *, |
| 235 | + clock: Optional[Clock] = None, |
| 236 | + context: Optional[Context] = None, |
| 237 | + strict: bool = False, |
| 238 | +) -> Union[Any, Iterator[Any]]: |
| 239 | + """Fetch future result(s) when done. |
| 240 | +
|
| 241 | + For a single future, blocks until the future is done and returns its result. |
| 242 | + For multiple futures, returns a generator that yields results as futures complete |
| 243 | + (like concurrent.futures.as_completed). |
| 244 | +
|
| 245 | + Note: This function may block and may raise if a future raises or it times out |
| 246 | + waiting. See wait_for_future() documentation for further reference on arguments. |
| 247 | +
|
| 248 | + Args: |
| 249 | + future: A single future or an iterable of futures |
| 250 | + timeout_sec: An optional timeout for how long to wait |
| 251 | + clock: An optional clock to use for timeout waits |
| 252 | + context: Current context (will use the default if none is given) |
| 253 | + strict: If True, yield results in order regardless of completion order. |
| 254 | + If False (default), yield results as they complete. |
| 255 | + Irrelevant when a single future is provided. |
| 256 | +
|
| 257 | + Returns: |
| 258 | + the result(s) of the future(s) when they are done. |
| 259 | +
|
| 260 | + Raises: |
| 261 | + ValueError: If timeout occurs before future(s) complete |
| 262 | +
|
| 263 | + Examples: |
| 264 | + Single future: |
| 265 | + >>> result = unwrap_future(my_future, timeout_sec=5.0) |
107 | 266 |
|
108 | | - Note this function may block and may raise if the future does or it times out |
109 | | - waiting for it. See wait_for_future() documentation for further reference on |
110 | | - arguments taken. |
| 267 | + Multiple futures (non-strict, as completed): |
| 268 | + >>> for result in unwrap_future([f1, f2, f3], timeout_sec=10.0): |
| 269 | + ... process(result) |
| 270 | +
|
| 271 | + Multiple futures (strict, in order): |
| 272 | + >>> for result in unwrap_future([f1, f2, f3], timeout_sec=10.0, strict=True): |
| 273 | + ... process(result) |
111 | 274 | """ |
112 | | - proper_future = as_proper_future(future) |
113 | | - if not wait_for_future(proper_future, timeout_sec, clock=clock, context=context): |
114 | | - raise ValueError("cannot unwrap future that is not done") |
115 | | - return proper_future.result() |
| 275 | + if context is None: |
| 276 | + context = get_default_context() |
| 277 | + |
| 278 | + if clock is None: |
| 279 | + import synchros2.scope |
| 280 | + |
| 281 | + clock = synchros2.scope.clock() |
| 282 | + |
| 283 | + if isinstance(future, (FutureConvertible, FutureLike)): |
| 284 | + proper_future = as_proper_future(future) |
| 285 | + if not wait_for_future(proper_future, timeout_sec, clock=clock, context=context): |
| 286 | + raise ValueError("cannot unwrap future that is not done") |
| 287 | + return proper_future.result() |
| 288 | + |
| 289 | + def _result_generator() -> Any: |
| 290 | + nonlocal future |
| 291 | + future = cast(Iterable[AnyFuture], future) |
| 292 | + pending_futures = [as_proper_future(f) for f in future] |
| 293 | + if not pending_futures: |
| 294 | + return |
| 295 | + |
| 296 | + deadline = None |
| 297 | + if timeout_sec is not None: |
| 298 | + assert clock is not None |
| 299 | + deadline = clock.now() + Duration(seconds=timeout_sec) |
| 300 | + |
| 301 | + if strict: |
| 302 | + for future in pending_futures: |
| 303 | + remaining_timeout_sec = None |
| 304 | + if deadline is not None: |
| 305 | + assert clock is not None |
| 306 | + remaining_duration = deadline - clock.now() |
| 307 | + if remaining_duration.nanoseconds <= 0: |
| 308 | + raise ValueError("timeout waiting for futures") |
| 309 | + remaining_timeout_sec = remaining_duration.nanoseconds / 1e9 |
| 310 | + |
| 311 | + if not wait_for_future(future, timeout_sec=remaining_timeout_sec, clock=clock, context=context): |
| 312 | + raise ValueError("timeout waiting for futures") |
| 313 | + yield future.result() |
| 314 | + return |
| 315 | + |
| 316 | + while pending_futures: |
| 317 | + remaining_timeout_sec = None |
| 318 | + if deadline is not None: |
| 319 | + assert clock is not None |
| 320 | + remaining_duration = deadline - clock.now() |
| 321 | + if remaining_duration.nanoseconds <= 0: |
| 322 | + raise ValueError("timeout waiting for futures") |
| 323 | + remaining_timeout_sec = remaining_duration.nanoseconds / 1e9 |
| 324 | + |
| 325 | + result = wait_for_future( |
| 326 | + pending_futures, |
| 327 | + timeout_sec=remaining_timeout_sec, |
| 328 | + clock=clock, |
| 329 | + context=context, |
| 330 | + return_when=FIRST_COMPLETED, |
| 331 | + ) |
| 332 | + |
| 333 | + if not result: |
| 334 | + raise ValueError("timeout waiting for futures") |
| 335 | + |
| 336 | + for future in result.done: |
| 337 | + if future in pending_futures: |
| 338 | + pending_futures.remove(future) |
| 339 | + yield future.result() |
| 340 | + |
| 341 | + return _result_generator() |
116 | 342 |
|
117 | 343 |
|
118 | 344 | wait_and_return_result = unwrap_future |
|
0 commit comments