Skip to content

Commit 194474c

Browse files
committed
fix(query): restore trio compatibility via sniffio dispatch
PR #746 replaced anyio.TaskGroup with asyncio.create_task() in Query to fix #378 (100% CPU spin) and #454 (cross-task cancel-scope RuntimeError), but asyncio.get_running_loop() raises 'RuntimeError: no running event loop' under trio, breaking ClaudeSDKClient.connect() for trio users. This adds _internal/_task_compat.py with a TaskHandle abstraction that dispatches via sniffio: asyncio uses loop.create_task() (unchanged from PR #746, so #378/#454 stay fixed); trio uses trio.lowlevel.spawn_system_task with a per-task CancelScope. Query.start/spawn_task/close now use TaskHandle and no longer import asyncio. Adds trio-backend tests for Query and ClaudeSDKClient.connect.
1 parent c1eb34e commit 194474c

6 files changed

Lines changed: 462 additions & 14 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ classifiers = [
2626
keywords = ["claude", "ai", "sdk", "anthropic"]
2727
dependencies = [
2828
"anyio>=4.0.0",
29+
"sniffio>=1.0.0",
2930
"typing_extensions>=4.0.0; python_version<'3.11'",
3031
"mcp>=0.1.0",
3132
]
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
"""Backend-agnostic detached task spawning.
2+
3+
``Query`` manages background tasks (the read loop, ``stream_input``,
4+
control-request handlers) that must be cancellable from any task context
5+
— including async-generator finalizers, which Python may run in a
6+
different task than the one that called ``start()``. anyio's
7+
``TaskGroup`` cannot be used for this because its cancel scope has task
8+
affinity: exiting it from a different task either raises ``RuntimeError:
9+
Attempted to exit cancel scope in a different task than it was entered
10+
in`` or busy-spins in ``_deliver_cancellation`` on the asyncio backend.
11+
12+
Under asyncio this is solved with plain ``loop.create_task()``, but that
13+
raises ``RuntimeError: no running event loop`` under trio. This module
14+
provides ``spawn_detached()`` which dispatches via sniffio to the
15+
appropriate backend primitive, returning a uniform ``TaskHandle``.
16+
"""
17+
18+
from __future__ import annotations
19+
20+
from collections.abc import Callable, Coroutine
21+
from contextlib import suppress
22+
from typing import Any
23+
24+
import sniffio
25+
26+
27+
class TaskHandle:
28+
"""Backend-agnostic handle to a detached background task.
29+
30+
Safe to ``.cancel()`` from any task — no anyio cancel-scope task
31+
affinity.
32+
"""
33+
34+
def cancel(self) -> None:
35+
"""Request cancellation of the wrapped task."""
36+
raise NotImplementedError
37+
38+
def done(self) -> bool:
39+
"""Return True if the wrapped task has finished."""
40+
raise NotImplementedError
41+
42+
def add_done_callback(self, callback: Callable[[TaskHandle], None]) -> None:
43+
"""Register ``callback(self)`` to run when the task finishes."""
44+
raise NotImplementedError
45+
46+
async def wait(self) -> None:
47+
"""Wait for the task to finish.
48+
49+
Suppresses the backend's cancellation exception (the task was
50+
cancelled by us) but re-raises any other exception the task
51+
raised.
52+
"""
53+
raise NotImplementedError
54+
55+
56+
class _AsyncioTaskHandle(TaskHandle):
57+
"""Thin wrapper around ``asyncio.Task``."""
58+
59+
def __init__(self, task: Any) -> None:
60+
self._task = task
61+
62+
def cancel(self) -> None:
63+
self._task.cancel()
64+
65+
def done(self) -> bool:
66+
return bool(self._task.done())
67+
68+
def add_done_callback(self, callback: Callable[[TaskHandle], None]) -> None:
69+
self._task.add_done_callback(lambda _t: callback(self))
70+
71+
async def wait(self) -> None:
72+
import asyncio
73+
74+
with suppress(asyncio.CancelledError):
75+
await self._task
76+
77+
78+
class _TrioTaskHandle(TaskHandle):
79+
"""Wraps a trio system task with its own ``CancelScope``."""
80+
81+
def __init__(self) -> None:
82+
import trio
83+
84+
self._cancel_scope = trio.CancelScope()
85+
self._done_event = trio.Event()
86+
self._exception: BaseException | None = None
87+
self._callbacks: list[Callable[[TaskHandle], None]] = []
88+
89+
def cancel(self) -> None:
90+
# CancelScope.cancel() is sync and safe to call from any task.
91+
self._cancel_scope.cancel()
92+
93+
def done(self) -> bool:
94+
return self._done_event.is_set()
95+
96+
def add_done_callback(self, callback: Callable[[TaskHandle], None]) -> None:
97+
if self.done():
98+
callback(self)
99+
else:
100+
self._callbacks.append(callback)
101+
102+
def _mark_done(self, exc: BaseException | None) -> None:
103+
self._exception = exc
104+
self._done_event.set()
105+
for cb in self._callbacks:
106+
# Suppress BaseException so a misbehaving callback can never
107+
# propagate out of the system-task _runner (which would crash
108+
# trio with TrioInternalError). The actual callbacks used here
109+
# are set.discard / dict.pop, so this is purely defensive.
110+
with suppress(BaseException):
111+
cb(self)
112+
self._callbacks.clear()
113+
114+
async def wait(self) -> None:
115+
import trio
116+
117+
await self._done_event.wait()
118+
if self._exception is not None and not isinstance(
119+
self._exception, trio.Cancelled
120+
):
121+
raise self._exception
122+
123+
124+
def spawn_detached(coro: Coroutine[Any, Any, Any]) -> TaskHandle:
125+
"""Spawn ``coro`` as a detached background task on the current backend.
126+
127+
- **asyncio**: ``asyncio.get_running_loop().create_task(coro)``.
128+
- **trio**: ``trio.lowlevel.spawn_system_task`` wrapping ``coro`` in a
129+
per-task ``CancelScope`` so the handle supports ``.cancel()``.
130+
"""
131+
backend = sniffio.current_async_library()
132+
if backend == "asyncio":
133+
import asyncio
134+
135+
loop = asyncio.get_running_loop()
136+
return _AsyncioTaskHandle(loop.create_task(coro))
137+
if backend == "trio":
138+
import trio
139+
140+
handle = _TrioTaskHandle()
141+
142+
async def _runner() -> None:
143+
exc: BaseException | None = None
144+
try:
145+
with handle._cancel_scope:
146+
await coro
147+
except BaseException as e: # noqa: BLE001
148+
# System tasks must not raise (would crash trio). Store
149+
# the exception on the handle; ``.wait()`` re-raises it.
150+
exc = e
151+
finally:
152+
handle._mark_done(exc)
153+
154+
trio.lowlevel.spawn_system_task(_runner)
155+
return handle
156+
# Unsupported backend: close the coroutine so we don't leak a "coroutine
157+
# was never awaited" RuntimeWarning on top of the RuntimeError.
158+
coro.close()
159+
raise RuntimeError(
160+
f"Unsupported async backend: {backend!r}. "
161+
"claude_agent_sdk requires asyncio or trio."
162+
)

src/claude_agent_sdk/_internal/query.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
"""Query class for handling bidirectional control protocol."""
22

3-
import asyncio
43
import json
54
import logging
65
import os
76
import uuid
87
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
9-
from contextlib import suppress
108
from typing import TYPE_CHECKING, Any, Literal
119

1210
import anyio
@@ -26,6 +24,7 @@
2624
SDKHookCallbackRequest,
2725
ToolPermissionContext,
2826
)
27+
from ._task_compat import TaskHandle, spawn_detached
2928
from .transport import Transport
3029

3130
if TYPE_CHECKING:
@@ -119,9 +118,9 @@ def __init__(
119118
self._message_send, self._message_receive = anyio.create_memory_object_stream[
120119
dict[str, Any]
121120
](max_buffer_size=100)
122-
self._read_task: asyncio.Task[None] | None = None
123-
self._child_tasks: set[asyncio.Task[Any]] = set()
124-
self._inflight_requests: dict[str, asyncio.Task[Any]] = {}
121+
self._read_task: TaskHandle | None = None
122+
self._child_tasks: set[TaskHandle] = set()
123+
self._inflight_requests: dict[str, TaskHandle] = {}
125124
self._initialized = False
126125
self._closed = False
127126
self._initialization_result: dict[str, Any] | None = None
@@ -217,13 +216,11 @@ async def initialize(self) -> dict[str, Any] | None:
217216
async def start(self) -> None:
218217
"""Start reading messages from transport."""
219218
if self._read_task is None:
220-
loop = asyncio.get_running_loop()
221-
self._read_task = loop.create_task(self._read_messages())
219+
self._read_task = spawn_detached(self._read_messages())
222220

223-
def spawn_task(self, coro: Any) -> asyncio.Task[Any]:
221+
def spawn_task(self, coro: Any) -> TaskHandle:
224222
"""Spawn a child task that will be cancelled on close()."""
225-
loop = asyncio.get_running_loop()
226-
task = loop.create_task(coro)
223+
task = spawn_detached(coro)
227224
self._child_tasks.add(task)
228225
task.add_done_callback(self._child_tasks.discard)
229226
return task
@@ -234,7 +231,7 @@ def _spawn_control_request_handler(self, request: SDKControlRequest) -> None:
234231
task = self.spawn_task(self._handle_control_request(request))
235232
self._inflight_requests[req_id] = task
236233

237-
def _done(_t: asyncio.Task[Any]) -> None:
234+
def _done(_t: TaskHandle) -> None:
238235
self._inflight_requests.pop(req_id, None)
239236

240237
task.add_done_callback(_done)
@@ -426,7 +423,7 @@ async def _handle_control_request(self, request: SDKControlRequest) -> None:
426423
}
427424
await self.transport.write(json.dumps(success_response) + "\n")
428425

429-
except asyncio.CancelledError:
426+
except anyio.get_cancelled_exc_class():
430427
# Request was cancelled via control_cancel_request; the CLI has
431428
# already abandoned this request, so don't write a response.
432429
raise
@@ -808,8 +805,7 @@ async def close(self) -> None:
808805
task.cancel()
809806
if self._read_task is not None and not self._read_task.done():
810807
self._read_task.cancel()
811-
with suppress(asyncio.CancelledError):
812-
await self._read_task
808+
await self._read_task.wait()
813809
self._read_task = None
814810
await self.transport.close()
815811

tests/test_client.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,3 +255,69 @@ async def mock_receive():
255255
)
256256

257257
anyio.run(_test)
258+
259+
260+
class TestClaudeSDKClientTrioBackend:
261+
"""Regression test: ClaudeSDKClient must work under trio.
262+
263+
``Query.start``/``spawn_task`` must not call ``asyncio.get_running_loop()``
264+
(raises ``RuntimeError: no running event loop`` under trio). This test
265+
drives connect()/disconnect() end-to-end on the trio backend with a mock
266+
transport that uses only anyio primitives.
267+
"""
268+
269+
def test_client_connect_under_trio(self):
270+
import json
271+
272+
from claude_agent_sdk import ClaudeSDKClient
273+
274+
def _make_trio_safe_transport():
275+
"""Mock transport using anyio.sleep so it runs under trio."""
276+
mock_transport = AsyncMock()
277+
mock_transport.connect = AsyncMock()
278+
mock_transport.close = AsyncMock()
279+
mock_transport.end_input = AsyncMock()
280+
mock_transport.is_ready = Mock(return_value=True)
281+
282+
written: list[str] = []
283+
284+
async def mock_write(data):
285+
written.append(data)
286+
287+
mock_transport.write = AsyncMock(side_effect=mock_write)
288+
289+
async def read_messages():
290+
# Respond to the initialize control_request so connect()
291+
# doesn't block on the 60s timeout.
292+
for _ in range(200):
293+
for msg_str in written:
294+
try:
295+
msg = json.loads(msg_str.strip())
296+
except (json.JSONDecodeError, AttributeError):
297+
continue
298+
if (
299+
msg.get("type") == "control_request"
300+
and msg.get("request", {}).get("subtype") == "initialize"
301+
):
302+
yield {
303+
"type": "control_response",
304+
"response": {
305+
"request_id": msg.get("request_id"),
306+
"subtype": "success",
307+
"response": {},
308+
},
309+
}
310+
return
311+
await anyio.sleep(0.01)
312+
313+
mock_transport.read_messages = read_messages
314+
return mock_transport
315+
316+
async def _test():
317+
mock_transport = _make_trio_safe_transport()
318+
async with ClaudeSDKClient(transport=mock_transport) as client:
319+
assert client._transport is mock_transport
320+
mock_transport.connect.assert_called_once()
321+
mock_transport.close.assert_called_once()
322+
323+
anyio.run(_test, backend="trio")

tests/test_query.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,79 @@ async def _test():
597597
anyio.run(_test)
598598

599599

600+
class TestQueryTrioBackend:
601+
"""Regression tests for trio compatibility.
602+
603+
``Query`` uses detached background tasks rather than an anyio
604+
``TaskGroup`` (whose cancel scope has task affinity). The asyncio
605+
implementation of that (``loop.create_task()``) raises ``RuntimeError``
606+
under trio; these tests run start/spawn_task/close on the trio backend
607+
to guard the sniffio-dispatch path.
608+
"""
609+
610+
def test_start_and_close_under_trio(self):
611+
"""start() + close() under trio must not raise."""
612+
613+
async def _test():
614+
mock_transport = _make_mock_transport(messages=[])
615+
q = Query(transport=mock_transport, is_streaming_mode=True)
616+
617+
await q.start()
618+
await q.close()
619+
620+
assert q._read_task is None
621+
mock_transport.close.assert_called_once()
622+
623+
anyio.run(_test, backend="trio")
624+
625+
def test_spawn_task_and_cancel_under_trio(self):
626+
"""spawn_task() under trio tracks and cancels child tasks on close()."""
627+
628+
async def _test():
629+
mock_transport = _make_mock_transport(messages=[])
630+
q = Query(transport=mock_transport, is_streaming_mode=True)
631+
632+
await q.start()
633+
634+
async def _slow():
635+
await anyio.sleep(10)
636+
637+
handle = q.spawn_task(_slow())
638+
assert handle in q._child_tasks
639+
640+
await q.close()
641+
# close() cancels child tasks; give the system task a tick to
642+
# fire its done callback that removes it from the set.
643+
await anyio.sleep(0)
644+
assert len(q._child_tasks) == 0
645+
646+
anyio.run(_test, backend="trio")
647+
648+
def test_close_from_different_task_under_trio(self):
649+
"""close() from a different task than start() must not raise (trio)."""
650+
651+
async def _test():
652+
mock_transport = _make_mock_transport(messages=[])
653+
q = Query(transport=mock_transport, is_streaming_mode=True)
654+
655+
await q.start()
656+
657+
close_error = []
658+
659+
async def close_in_other_task():
660+
try:
661+
await q.close()
662+
except Exception as e:
663+
close_error.append(e)
664+
665+
async with anyio.create_task_group() as tg:
666+
tg.start_soon(close_in_other_task)
667+
668+
assert close_error == [], f"close() raised: {close_error}"
669+
670+
anyio.run(_test, backend="trio")
671+
672+
600673
class TestControlCancelRequest:
601674
"""Tests for control_cancel_request handling (issue #739).
602675

0 commit comments

Comments
 (0)