|
1 | 1 | """Query class for handling bidirectional control protocol.""" |
2 | 2 |
|
| 3 | +import asyncio |
3 | 4 | import json |
4 | 5 | import logging |
5 | 6 | import os |
|
30 | 31 |
|
31 | 32 | logger = logging.getLogger(__name__) |
32 | 33 |
|
| 34 | +_READER_SHUTDOWN_TIMEOUT = 5.0 |
| 35 | + |
33 | 36 |
|
34 | 37 | def _convert_hook_output_for_cli(hook_output: dict[str, Any]) -> dict[str, Any]: |
35 | 38 | """Convert Python-safe field names to CLI-expected field names. |
@@ -116,6 +119,14 @@ def __init__( |
116 | 119 | float(os.environ.get("CLAUDE_CODE_STREAM_CLOSE_TIMEOUT", "60000")) / 1000.0 |
117 | 120 | ) # Convert ms to seconds |
118 | 121 |
|
| 122 | + # Reader lifecycle — _run_reader owns the anyio task group so that |
| 123 | + # enter and exit always happen in the same asyncio task, avoiding |
| 124 | + # the cross-task RuntimeError from anyio's CancelScope. |
| 125 | + self._reader_task: asyncio.Task[None] | None = None |
| 126 | + self._reader_ready = asyncio.Event() |
| 127 | + self._reader_done = asyncio.Event() |
| 128 | + self._reader_start_exc: BaseException | None = None |
| 129 | + |
119 | 130 | async def initialize(self) -> dict[str, Any] | None: |
120 | 131 | """Initialize control protocol if in streaming mode. |
121 | 132 |
|
@@ -164,10 +175,39 @@ async def initialize(self) -> dict[str, Any] | None: |
164 | 175 |
|
165 | 176 | async def start(self) -> None: |
166 | 177 | """Start reading messages from transport.""" |
167 | | - if self._tg is None: |
168 | | - self._tg = anyio.create_task_group() |
169 | | - await self._tg.__aenter__() |
170 | | - self._tg.start_soon(self._read_messages) |
| 178 | + if self._reader_task is not None: |
| 179 | + return |
| 180 | + |
| 181 | + self._reader_ready.clear() |
| 182 | + self._reader_done.clear() |
| 183 | + self._reader_start_exc = None |
| 184 | + self._reader_task = asyncio.create_task(self._run_reader()) |
| 185 | + await self._reader_ready.wait() |
| 186 | + if self._reader_start_exc is not None: |
| 187 | + raise self._reader_start_exc |
| 188 | + |
| 189 | + async def _run_reader(self) -> None: |
| 190 | + """Owns the anyio task group — enter and exit happen in this task. |
| 191 | +
|
| 192 | + The task group is entered and exited here so that anyio's |
| 193 | + CancelScope never sees a cross-task mismatch. Child tasks |
| 194 | + (_read_messages, _handle_control_request, stream_input) are |
| 195 | + started inside this group via _tg.start_soon(). When the |
| 196 | + transport closes, _read_messages finishes, which lets the |
| 197 | + task group exit naturally. |
| 198 | + """ |
| 199 | + try: |
| 200 | + async with anyio.create_task_group() as tg: |
| 201 | + self._tg = tg |
| 202 | + self._reader_ready.set() |
| 203 | + tg.start_soon(self._read_messages) |
| 204 | + except BaseException as exc: |
| 205 | + if not self._reader_ready.is_set(): |
| 206 | + self._reader_start_exc = exc |
| 207 | + self._reader_ready.set() |
| 208 | + finally: |
| 209 | + self._tg = None |
| 210 | + self._reader_done.set() |
171 | 211 |
|
172 | 212 | async def _read_messages(self) -> None: |
173 | 213 | """Read messages from transport and route them.""" |
@@ -657,19 +697,28 @@ async def receive_messages(self) -> AsyncIterator[dict[str, Any]]: |
657 | 697 | yield message |
658 | 698 |
|
659 | 699 | async def close(self) -> None: |
660 | | - """Close the query and transport.""" |
| 700 | + """Close the query and transport. |
| 701 | +
|
| 702 | + Closes the transport first so _read_messages exits naturally |
| 703 | + (sentinel/EOF unblocks the queue), then waits for the reader |
| 704 | + task to finish. Falls back to cancellation if the reader does |
| 705 | + not exit within the timeout. |
| 706 | + """ |
661 | 707 | self._closed = True |
662 | | - if self._tg: |
663 | | - self._tg.cancel_scope.cancel() |
664 | | - # Set a deadline to prevent _deliver_cancellation() busy-loop |
665 | | - # when tasks don't respond to cancellation cleanly. |
666 | | - # Uses the task group's own scope (not a nested scope) to avoid |
667 | | - # "not the current cancel scope" errors from anyio. |
668 | | - self._tg.cancel_scope.deadline = anyio.current_time() + 5.0 |
669 | | - with suppress(anyio.get_cancelled_exc_class()): |
670 | | - await self._tg.__aexit__(None, None, None) |
671 | 708 | await self.transport.close() |
672 | 709 |
|
| 710 | + if self._reader_task is not None: |
| 711 | + try: |
| 712 | + await asyncio.wait_for( |
| 713 | + self._reader_done.wait(), |
| 714 | + timeout=_READER_SHUTDOWN_TIMEOUT, |
| 715 | + ) |
| 716 | + except asyncio.TimeoutError: |
| 717 | + self._reader_task.cancel() |
| 718 | + with suppress(asyncio.CancelledError): |
| 719 | + await self._reader_task |
| 720 | + self._reader_task = None |
| 721 | + |
673 | 722 | # Make Query an async iterator |
674 | 723 | def __aiter__(self) -> AsyncIterator[dict[str, Any]]: |
675 | 724 | """Return async iterator for messages.""" |
|
0 commit comments