Skip to content

Commit 8b54843

Browse files
kovanclaude
andcommitted
Add TaskGroup.start() and TaskStatus for task readiness signaling
Add TaskStatus class with started() method and TaskGroup.start() async method. The task calls task_status.started(value) to signal it is ready, and start() returns that value to the caller. The task continues running in the group after signaling readiness. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 2a33256 commit 8b54843

2 files changed

Lines changed: 162 additions & 1 deletion

File tree

Lib/asyncio/taskgroups.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,40 @@
22
# license: PSFL.
33

44

5-
__all__ = ("TaskGroup",)
5+
__all__ = ("TaskGroup", "TaskStatus")
66

77
from . import events
88
from . import exceptions
99
from . import futures
1010
from . import tasks
1111

1212

13+
class TaskStatus:
14+
"""Status object passed to tasks started via :meth:`TaskGroup.start`.
15+
16+
The task calls :meth:`started` to signal readiness, passing an
17+
optional value back to the ``start()`` caller.
18+
"""
19+
20+
def __init__(self):
21+
self._future = None # set by TaskGroup.start()
22+
self._started = False
23+
24+
def started(self, value=None):
25+
"""Signal that the task is ready.
26+
27+
*value* is returned to the ``await TaskGroup.start(...)`` caller.
28+
May only be called once.
29+
"""
30+
if self._started:
31+
raise RuntimeError("task already signalled readiness")
32+
if self._future is None or self._future.done():
33+
raise RuntimeError(
34+
"TaskStatus is not associated with a pending start()")
35+
self._started = True
36+
self._future.set_result(value)
37+
38+
1339
class TaskGroup:
1440
"""Asynchronous context manager for managing groups of tasks.
1541
@@ -210,6 +236,39 @@ def create_task(self, coro, **kwargs):
210236
# task.exception().__traceback__->TaskGroup.create_task->task
211237
del task
212238

239+
async def start(self, coro_fn, *args, name=None, context=None):
240+
"""Start a task and wait until it signals readiness.
241+
242+
*coro_fn* is called as ``coro_fn(*args, task_status=task_status)``.
243+
The coroutine must call ``task_status.started(value)`` to signal
244+
readiness. The *value* passed to ``started()`` is returned by
245+
this method. The task continues running in the group after
246+
``started()`` is called.
247+
"""
248+
if not self._entered:
249+
raise RuntimeError(f"TaskGroup {self!r} has not been entered")
250+
if self._exiting and not self._tasks:
251+
raise RuntimeError(f"TaskGroup {self!r} is finished")
252+
if self._aborting:
253+
raise RuntimeError(f"TaskGroup {self!r} is shutting down")
254+
255+
task_status = TaskStatus()
256+
task_status._future = self._loop.create_future()
257+
258+
coro = coro_fn(*args, task_status=task_status)
259+
kwargs = {}
260+
if name is not None:
261+
kwargs['name'] = name
262+
if context is not None:
263+
kwargs['context'] = context
264+
task = self.create_task(coro, **kwargs)
265+
266+
try:
267+
return await task_status._future
268+
except BaseException:
269+
task.cancel()
270+
raise
271+
213272
# Since Python 3.8 Tasks propagate all exceptions correctly,
214273
# except for KeyboardInterrupt and SystemExit which are
215274
# still considered special.

Lib/test/test_asyncio/test_taskgroups.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,6 +1103,108 @@ async def throw_error():
11031103
await asyncio.sleep(0)
11041104

11051105

1106+
# -- TaskGroup.start() / TaskStatus ------------------------------------
1107+
1108+
async def test_start_basic(self):
1109+
"""start() returns the value passed to task_status.started()."""
1110+
async def server(task_status):
1111+
task_status.started(42)
1112+
await asyncio.sleep(0)
1113+
1114+
async with taskgroups.TaskGroup() as tg:
1115+
value = await tg.start(server)
1116+
self.assertEqual(value, 42)
1117+
1118+
async def test_start_none_value(self):
1119+
"""started() with no arg returns None."""
1120+
async def worker(task_status):
1121+
task_status.started()
1122+
await asyncio.sleep(0)
1123+
1124+
async with taskgroups.TaskGroup() as tg:
1125+
value = await tg.start(worker)
1126+
self.assertIsNone(value)
1127+
1128+
async def test_start_task_continues(self):
1129+
"""Task keeps running after started() is called."""
1130+
finished = False
1131+
1132+
async def worker(task_status):
1133+
nonlocal finished
1134+
task_status.started("ready")
1135+
await asyncio.sleep(0)
1136+
finished = True
1137+
1138+
async with taskgroups.TaskGroup() as tg:
1139+
value = await tg.start(worker)
1140+
self.assertEqual(value, "ready")
1141+
# TaskGroup waits for the task to finish
1142+
self.assertTrue(finished)
1143+
1144+
async def test_start_error_before_started(self):
1145+
"""Exception before started() propagates through the group."""
1146+
async def failing(task_status):
1147+
raise RuntimeError("boom")
1148+
1149+
with self.assertRaises(ExceptionGroup) as cm:
1150+
async with taskgroups.TaskGroup() as tg:
1151+
await tg.start(failing)
1152+
1153+
self.assertEqual(len(cm.exception.exceptions), 1)
1154+
self.assertIsInstance(cm.exception.exceptions[0], RuntimeError)
1155+
1156+
async def test_start_cancelled_before_started(self):
1157+
"""If the task is cancelled before started(), cancellation propagates."""
1158+
async def slow_start(task_status):
1159+
await asyncio.sleep(100)
1160+
task_status.started()
1161+
1162+
with self.assertRaises(TimeoutError):
1163+
async with asyncio.timeout(0.01):
1164+
async with taskgroups.TaskGroup() as tg:
1165+
await tg.start(slow_start)
1166+
1167+
async def test_start_already_started_error(self):
1168+
"""Calling started() twice raises RuntimeError."""
1169+
async def double_start(task_status):
1170+
task_status.started(1)
1171+
with self.assertRaises(RuntimeError):
1172+
task_status.started(2)
1173+
1174+
async with taskgroups.TaskGroup() as tg:
1175+
value = await tg.start(double_start)
1176+
self.assertEqual(value, 1)
1177+
1178+
async def test_start_multiple_tasks(self):
1179+
"""Multiple start() calls in the same group."""
1180+
async def worker(n, task_status):
1181+
task_status.started(n * 10)
1182+
await asyncio.sleep(0)
1183+
1184+
async with taskgroups.TaskGroup() as tg:
1185+
v1 = await tg.start(worker, 1)
1186+
v2 = await tg.start(worker, 2)
1187+
v3 = await tg.start(worker, 3)
1188+
1189+
self.assertEqual(v1, 10)
1190+
self.assertEqual(v2, 20)
1191+
self.assertEqual(v3, 30)
1192+
1193+
async def test_start_with_name(self):
1194+
"""start() passes name= to create_task."""
1195+
task_name = None
1196+
1197+
async def worker(task_status):
1198+
nonlocal task_name
1199+
task_name = asyncio.current_task().get_name()
1200+
task_status.started()
1201+
1202+
async with taskgroups.TaskGroup() as tg:
1203+
await tg.start(worker, name="my-worker")
1204+
1205+
self.assertEqual(task_name, "my-worker")
1206+
1207+
11061208
class TestTaskGroup(BaseTestTaskGroup, unittest.IsolatedAsyncioTestCase):
11071209
loop_factory = asyncio.EventLoop
11081210

0 commit comments

Comments
 (0)