Skip to content

Commit 62c8bfd

Browse files
authored
fix(query): restore trio compatibility via sniffio dispatch (#870)
## Problem PR #746 (v0.1.51+) replaced `anyio.TaskGroup` with `asyncio.create_task()` in `Query` to fix #378 (100% CPU spin in `_deliver_cancellation`) and #454 (cross-task cancel-scope `RuntimeError`). However, `asyncio.get_running_loop()` raises `RuntimeError: no running event loop` under trio, breaking `ClaudeSDKClient.connect()` for trio users since v0.1.51: ```python import trio from claude_agent_sdk import ClaudeSDKClient async def main(): async with ClaudeSDKClient() as c: # RuntimeError: no running event loop ... trio.run(main) ``` ## Approach **sniffio dispatch.** Adds `_internal/_task_compat.py` with a `TaskHandle` abstraction and `spawn_detached(coro)`: - **asyncio** → `loop.create_task()` wrapped in `_AsyncioTaskHandle`. Behaviorally identical to PR #746 — `cancel()`, `done()`, `add_done_callback()`, and `wait()` are thin pass-throughs to `asyncio.Task`. **#378/#454 stay fixed.** - **trio** → `trio.lowlevel.spawn_system_task` with a per-task `CancelScope` wrapped in `_TrioTaskHandle`. `CancelScope.cancel()` is sync and has no task affinity, so `close()` from any task is safe (the #454 invariant holds for trio too). `Query.start/spawn_task/close` and `_spawn_control_request_handler` now use `TaskHandle`. `query.py` no longer imports `asyncio`; the two cancellation-exception sites use `anyio.get_cancelled_exc_class()`. The full anyio-TaskGroup restructure was previously attempted in #364 and proved tricky; this change keeps the asyncio path untouched to minimize regression risk. ## Why `trio.lowlevel.spawn_system_task`? trio has no `create_task()` equivalent by design (structured concurrency). `spawn_system_task` is the documented escape hatch for detached tasks. Each spawned coro is wrapped in `try/except BaseException` so a failure can never propagate as `TrioInternalError`; the exception is stored on the handle and re-raised by `wait()`. ## Out of scope (follow-ups) `_internal/session_resume.py`, `_internal/transcript_mirror_batcher.py`, and `_internal/sessions.py` also have direct `asyncio` usage. These are opt-in features gated behind `options.session_store` and were never trio-compatible — not regressions from #746. Tracked separately. ## Testing - New `tests/test_task_compat.py` — 9 unit tests, both backends (spawn/wait, cancel, done-callback, exception propagation, cross-task cancel) - New `TestQueryTrioBackend` (3 tests) — `start`/`close`/`spawn_task` under `anyio.run(..., backend="trio")` - New `TestClaudeSDKClientTrioBackend::test_client_connect_under_trio` — the repro above as a unit test - Existing `TestQueryCrossTaskCleanup` (#454 guard) and `TestControlCancelRequest` (#751 guard) still pass - 748 passed, 3 skipped; ruff + mypy clean - Manual e2e: real query under `trio.run()` against live CLI returns `AssistantMessage` + `ResultMessage(success)` ## Deps Adds `sniffio>=1.0.0` to runtime deps (already a transitive dep of `anyio>=4.0.0`; just made explicit).
1 parent c1eb34e commit 62c8bfd

6 files changed

Lines changed: 715 additions & 17 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ classifiers = [
2626
keywords = ["claude", "ai", "sdk", "anthropic"]
2727
dependencies = [
2828
"anyio>=4.0.0",
29+
"sniffio>=1.0.0",
2930
"typing_extensions>=4.0.0; python_version<'3.11'",
3031
"mcp>=0.1.0",
3132
]
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
"""Backend-agnostic detached task spawning.
2+
3+
``Query`` manages background tasks (the read loop, ``stream_input``,
4+
control-request handlers) that must be cancellable from any task context
5+
— including async-generator finalizers, which Python may run in a
6+
different task than the one that called ``start()``. anyio's
7+
``TaskGroup`` cannot be used for this because its cancel scope has task
8+
affinity: exiting it from a different task either raises ``RuntimeError:
9+
Attempted to exit cancel scope in a different task than it was entered
10+
in`` or busy-spins in ``_deliver_cancellation`` on the asyncio backend.
11+
12+
Under asyncio this is solved with plain ``loop.create_task()``, but that
13+
raises ``RuntimeError: no running event loop`` under trio. This module
14+
provides ``spawn_detached()`` which dispatches via sniffio to the
15+
appropriate backend primitive, returning a uniform ``TaskHandle``.
16+
"""
17+
18+
from __future__ import annotations
19+
20+
import contextvars
21+
import logging
22+
from collections.abc import Callable, Coroutine
23+
from contextlib import suppress
24+
from typing import Any
25+
26+
import sniffio
27+
28+
logger = logging.getLogger(__name__)
29+
30+
31+
class TaskHandle:
32+
"""Backend-agnostic handle to a detached background task.
33+
34+
Safe to ``.cancel()`` from any task — no anyio cancel-scope task
35+
affinity.
36+
"""
37+
38+
def cancel(self) -> None:
39+
"""Request cancellation of the wrapped task."""
40+
raise NotImplementedError
41+
42+
def done(self) -> bool:
43+
"""Return True if the wrapped task has finished."""
44+
raise NotImplementedError
45+
46+
def add_done_callback(self, callback: Callable[[TaskHandle], None]) -> None:
47+
"""Register ``callback(self)`` to run when the task finishes."""
48+
raise NotImplementedError
49+
50+
async def wait(self) -> None:
51+
"""Wait for the task to finish.
52+
53+
Suppresses the backend's cancellation exception (the task was
54+
cancelled by us) but re-raises any other exception the task
55+
raised.
56+
"""
57+
raise NotImplementedError
58+
59+
60+
class _AsyncioTaskHandle(TaskHandle):
61+
"""Thin wrapper around ``asyncio.Task``."""
62+
63+
def __init__(self, task: Any) -> None:
64+
self._task = task
65+
66+
def cancel(self) -> None:
67+
self._task.cancel()
68+
69+
def done(self) -> bool:
70+
return bool(self._task.done())
71+
72+
def add_done_callback(self, callback: Callable[[TaskHandle], None]) -> None:
73+
self._task.add_done_callback(lambda _t: callback(self))
74+
75+
async def wait(self) -> None:
76+
import asyncio
77+
78+
with suppress(asyncio.CancelledError):
79+
await self._task
80+
81+
82+
class _TrioTaskHandle(TaskHandle):
83+
"""Wraps a trio system task with its own ``CancelScope``."""
84+
85+
def __init__(self) -> None:
86+
import trio
87+
88+
self._cancel_scope = trio.CancelScope()
89+
self._done_event = trio.Event()
90+
self._exception: BaseException | None = None
91+
self._callbacks: list[Callable[[TaskHandle], None]] = []
92+
93+
def cancel(self) -> None:
94+
# CancelScope.cancel() is sync and safe to call from any task.
95+
self._cancel_scope.cancel()
96+
97+
def done(self) -> bool:
98+
return self._done_event.is_set()
99+
100+
def add_done_callback(self, callback: Callable[[TaskHandle], None]) -> None:
101+
if self.done():
102+
callback(self)
103+
else:
104+
self._callbacks.append(callback)
105+
106+
def _mark_done(self, exc: BaseException | None) -> None:
107+
import trio
108+
109+
# Parity with asyncio's "Task exception was never retrieved":
110+
# close() only .cancel()s child tasks (never .wait()s them), so a
111+
# non-Cancelled exception would otherwise be silently dropped.
112+
if exc is not None and not isinstance(exc, trio.Cancelled):
113+
logger.warning("Unhandled exception in detached trio task", exc_info=exc)
114+
self._exception = exc
115+
self._done_event.set()
116+
for cb in self._callbacks:
117+
# Suppress BaseException so a misbehaving callback can never
118+
# propagate out of the system-task _runner (which would crash
119+
# trio with TrioInternalError). The actual callbacks used here
120+
# are set.discard / dict.pop, so this is purely defensive.
121+
with suppress(BaseException):
122+
cb(self)
123+
self._callbacks.clear()
124+
125+
async def wait(self) -> None:
126+
import trio
127+
128+
await self._done_event.wait()
129+
if self._exception is not None and not isinstance(
130+
self._exception, trio.Cancelled
131+
):
132+
raise self._exception
133+
134+
135+
def spawn_detached(coro: Coroutine[Any, Any, Any]) -> TaskHandle:
136+
"""Spawn ``coro`` as a detached background task on the current backend.
137+
138+
- **asyncio**: ``asyncio.get_running_loop().create_task(coro)``.
139+
- **trio**: ``trio.lowlevel.spawn_system_task`` wrapping ``coro`` in a
140+
per-task ``CancelScope`` so the handle supports ``.cancel()``.
141+
"""
142+
backend = sniffio.current_async_library()
143+
if backend == "asyncio":
144+
import asyncio
145+
146+
loop = asyncio.get_running_loop()
147+
return _AsyncioTaskHandle(loop.create_task(coro))
148+
if backend == "trio":
149+
import trio
150+
151+
handle = _TrioTaskHandle()
152+
153+
async def _runner() -> None:
154+
exc: BaseException | None = None
155+
try:
156+
with handle._cancel_scope:
157+
await coro
158+
except BaseException as e: # noqa: BLE001
159+
# System tasks must not raise (would crash trio). Store
160+
# the exception on the handle; ``.wait()`` re-raises it.
161+
exc = e
162+
finally:
163+
handle._mark_done(exc)
164+
165+
# Pass context= so trio system tasks inherit the caller's
166+
# contextvars (asyncio's loop.create_task() does this implicitly;
167+
# spawn_system_task does not).
168+
trio.lowlevel.spawn_system_task(_runner, context=contextvars.copy_context())
169+
return handle
170+
# Unsupported backend: close the coroutine so we don't leak a "coroutine
171+
# was never awaited" RuntimeWarning on top of the RuntimeError.
172+
coro.close()
173+
raise RuntimeError(
174+
f"Unsupported async backend: {backend!r}. "
175+
"claude_agent_sdk requires asyncio or trio."
176+
)

src/claude_agent_sdk/_internal/query.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Query class for handling bidirectional control protocol."""
22

3-
import asyncio
43
import json
54
import logging
65
import os
@@ -26,6 +25,7 @@
2625
SDKHookCallbackRequest,
2726
ToolPermissionContext,
2827
)
28+
from ._task_compat import TaskHandle, spawn_detached
2929
from .transport import Transport
3030

3131
if TYPE_CHECKING:
@@ -119,9 +119,9 @@ def __init__(
119119
self._message_send, self._message_receive = anyio.create_memory_object_stream[
120120
dict[str, Any]
121121
](max_buffer_size=100)
122-
self._read_task: asyncio.Task[None] | None = None
123-
self._child_tasks: set[asyncio.Task[Any]] = set()
124-
self._inflight_requests: dict[str, asyncio.Task[Any]] = {}
122+
self._read_task: TaskHandle | None = None
123+
self._child_tasks: set[TaskHandle] = set()
124+
self._inflight_requests: dict[str, TaskHandle] = {}
125125
self._initialized = False
126126
self._closed = False
127127
self._initialization_result: dict[str, Any] | None = None
@@ -217,13 +217,11 @@ async def initialize(self) -> dict[str, Any] | None:
217217
async def start(self) -> None:
218218
"""Start reading messages from transport."""
219219
if self._read_task is None:
220-
loop = asyncio.get_running_loop()
221-
self._read_task = loop.create_task(self._read_messages())
220+
self._read_task = spawn_detached(self._read_messages())
222221

223-
def spawn_task(self, coro: Any) -> asyncio.Task[Any]:
222+
def spawn_task(self, coro: Any) -> TaskHandle:
224223
"""Spawn a child task that will be cancelled on close()."""
225-
loop = asyncio.get_running_loop()
226-
task = loop.create_task(coro)
224+
task = spawn_detached(coro)
227225
self._child_tasks.add(task)
228226
task.add_done_callback(self._child_tasks.discard)
229227
return task
@@ -234,7 +232,7 @@ def _spawn_control_request_handler(self, request: SDKControlRequest) -> None:
234232
task = self.spawn_task(self._handle_control_request(request))
235233
self._inflight_requests[req_id] = task
236234

237-
def _done(_t: asyncio.Task[Any]) -> None:
235+
def _done(_t: TaskHandle) -> None:
238236
self._inflight_requests.pop(req_id, None)
239237

240238
task.add_done_callback(_done)
@@ -316,14 +314,23 @@ async def _read_messages(self) -> None:
316314
finally:
317315
# Flush any remaining transcript mirror entries before closing so
318316
# an early stdout EOF or transport error doesn't drop entries
319-
# batched this turn. flush() never raises.
317+
# batched this turn. flush() never raises. Shielded so the await
318+
# still runs when this finally is reached via cancellation.
320319
if self._transcript_mirror_batcher is not None:
321-
await self._transcript_mirror_batcher.flush()
320+
with anyio.CancelScope(shield=True):
321+
await self._transcript_mirror_batcher.flush()
322322
# Unblock any waiters (e.g. string-prompt path waiting for first
323323
# result) so they don't stall for the full timeout on early exit.
324324
self._first_result_event.set()
325-
# Always signal end of stream
326-
await self._message_send.send({"type": "end"})
325+
# Always signal end of stream. send_nowait: trio's level-triggered
326+
# cancellation would re-raise Cancelled at an await checkpoint
327+
# here, dropping the sentinel and leaving receive_messages() hung.
328+
# close() is the fallback for the buffer-full case where
329+
# send_nowait raises WouldBlock — receivers then exit on
330+
# EndOfStream after draining.
331+
with suppress(anyio.WouldBlock):
332+
self._message_send.send_nowait({"type": "end"})
333+
self._message_send.close()
327334

328335
async def _handle_control_request(self, request: SDKControlRequest) -> None:
329336
"""Handle incoming control request from CLI."""
@@ -426,7 +433,7 @@ async def _handle_control_request(self, request: SDKControlRequest) -> None:
426433
}
427434
await self.transport.write(json.dumps(success_response) + "\n")
428435

429-
except asyncio.CancelledError:
436+
except anyio.get_cancelled_exc_class():
430437
# Request was cancelled via control_cancel_request; the CLI has
431438
# already abandoned this request, so don't write a response.
432439
raise
@@ -808,9 +815,16 @@ async def close(self) -> None:
808815
task.cancel()
809816
if self._read_task is not None and not self._read_task.done():
810817
self._read_task.cancel()
811-
with suppress(asyncio.CancelledError):
812-
await self._read_task
818+
await self._read_task.wait()
813819
self._read_task = None
820+
# The read task's finally closed the send side; repeat here for the
821+
# case where start() was never called. Do NOT close the receive
822+
# side — it belongs to the consumer, and anyio's receive_nowait()
823+
# checks _closed before the buffer, so closing it here would make a
824+
# non-parked consumer drop buffered messages with
825+
# ClosedResourceError. _message_send.close() alone yields
826+
# EndOfStream after the buffer drains.
827+
self._message_send.close()
814828
await self.transport.close()
815829

816830
# Make Query an async iterator

tests/test_client.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from unittest.mock import AsyncMock, Mock, patch
55

66
import anyio
7+
import pytest
78

89
from claude_agent_sdk import AssistantMessage, ClaudeAgentOptions, query
910
from claude_agent_sdk.types import TextBlock
@@ -255,3 +256,72 @@ async def mock_receive():
255256
)
256257

257258
anyio.run(_test)
259+
260+
261+
@pytest.mark.filterwarnings(
262+
"ignore:Unclosed <MemoryObjectReceiveStream:ResourceWarning"
263+
)
264+
class TestClaudeSDKClientTrioBackend:
265+
"""Regression test: ClaudeSDKClient must work under trio.
266+
267+
``Query.start``/``spawn_task`` must not call ``asyncio.get_running_loop()``
268+
(raises ``RuntimeError: no running event loop`` under trio). This test
269+
drives connect()/disconnect() end-to-end on the trio backend with a mock
270+
transport that uses only anyio primitives.
271+
"""
272+
273+
def test_client_connect_under_trio(self):
274+
import json
275+
276+
from claude_agent_sdk import ClaudeSDKClient
277+
278+
def _make_trio_safe_transport():
279+
"""Mock transport using anyio.sleep so it runs under trio."""
280+
mock_transport = AsyncMock()
281+
mock_transport.connect = AsyncMock()
282+
mock_transport.close = AsyncMock()
283+
mock_transport.end_input = AsyncMock()
284+
mock_transport.is_ready = Mock(return_value=True)
285+
286+
written: list[str] = []
287+
288+
async def mock_write(data):
289+
written.append(data)
290+
291+
mock_transport.write = AsyncMock(side_effect=mock_write)
292+
293+
async def read_messages():
294+
# Respond to the initialize control_request so connect()
295+
# doesn't block on the 60s timeout.
296+
for _ in range(200):
297+
for msg_str in written:
298+
try:
299+
msg = json.loads(msg_str.strip())
300+
except (json.JSONDecodeError, AttributeError):
301+
continue
302+
if (
303+
msg.get("type") == "control_request"
304+
and msg.get("request", {}).get("subtype") == "initialize"
305+
):
306+
yield {
307+
"type": "control_response",
308+
"response": {
309+
"request_id": msg.get("request_id"),
310+
"subtype": "success",
311+
"response": {},
312+
},
313+
}
314+
return
315+
await anyio.sleep(0.01)
316+
317+
mock_transport.read_messages = read_messages
318+
return mock_transport
319+
320+
async def _test():
321+
mock_transport = _make_trio_safe_transport()
322+
async with ClaudeSDKClient(transport=mock_transport) as client:
323+
assert client._transport is mock_transport
324+
mock_transport.connect.assert_called_once()
325+
mock_transport.close.assert_called_once()
326+
327+
anyio.run(_test, backend="trio")

0 commit comments

Comments
 (0)