Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ classifiers = [
keywords = ["claude", "ai", "sdk", "anthropic"]
dependencies = [
"anyio>=4.0.0",
"sniffio>=1.0.0",
"typing_extensions>=4.0.0; python_version<'3.11'",
"mcp>=0.1.0",
]
Expand Down
166 changes: 166 additions & 0 deletions src/claude_agent_sdk/_internal/_task_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
"""Backend-agnostic detached task spawning.

``Query`` manages background tasks (the read loop, ``stream_input``,
control-request handlers) that must be cancellable from any task context
— including async-generator finalizers, which Python may run in a
different task than the one that called ``start()``. anyio's
``TaskGroup`` cannot be used for this because its cancel scope has task
affinity: exiting it from a different task either raises ``RuntimeError:
Attempted to exit cancel scope in a different task than it was entered
in`` or busy-spins in ``_deliver_cancellation`` on the asyncio backend.

Under asyncio this is solved with plain ``loop.create_task()``, but that
raises ``RuntimeError: no running event loop`` under trio. This module
provides ``spawn_detached()`` which dispatches via sniffio to the
appropriate backend primitive, returning a uniform ``TaskHandle``.
"""

from __future__ import annotations

import contextvars
from collections.abc import Callable, Coroutine
from contextlib import suppress
from typing import Any

import sniffio


class TaskHandle:
"""Backend-agnostic handle to a detached background task.

Safe to ``.cancel()`` from any task — no anyio cancel-scope task
affinity.
"""

def cancel(self) -> None:
"""Request cancellation of the wrapped task."""
raise NotImplementedError

def done(self) -> bool:
"""Return True if the wrapped task has finished."""
raise NotImplementedError

def add_done_callback(self, callback: Callable[[TaskHandle], None]) -> None:
"""Register ``callback(self)`` to run when the task finishes."""
raise NotImplementedError

async def wait(self) -> None:
"""Wait for the task to finish.

Suppresses the backend's cancellation exception (the task was
cancelled by us) but re-raises any other exception the task
raised.
"""
raise NotImplementedError


class _AsyncioTaskHandle(TaskHandle):
"""Thin wrapper around ``asyncio.Task``."""

def __init__(self, task: Any) -> None:
self._task = task

def cancel(self) -> None:
self._task.cancel()

def done(self) -> bool:
return bool(self._task.done())

def add_done_callback(self, callback: Callable[[TaskHandle], None]) -> None:
self._task.add_done_callback(lambda _t: callback(self))

async def wait(self) -> None:
import asyncio

with suppress(asyncio.CancelledError):
await self._task


class _TrioTaskHandle(TaskHandle):
"""Wraps a trio system task with its own ``CancelScope``."""

def __init__(self) -> None:
import trio

self._cancel_scope = trio.CancelScope()
self._done_event = trio.Event()
self._exception: BaseException | None = None
self._callbacks: list[Callable[[TaskHandle], None]] = []

def cancel(self) -> None:
# CancelScope.cancel() is sync and safe to call from any task.
self._cancel_scope.cancel()

def done(self) -> bool:
return self._done_event.is_set()

def add_done_callback(self, callback: Callable[[TaskHandle], None]) -> None:
if self.done():
callback(self)
else:
self._callbacks.append(callback)

def _mark_done(self, exc: BaseException | None) -> None:
self._exception = exc
self._done_event.set()
for cb in self._callbacks:
# Suppress BaseException so a misbehaving callback can never
# propagate out of the system-task _runner (which would crash
# trio with TrioInternalError). The actual callbacks used here
# are set.discard / dict.pop, so this is purely defensive.
with suppress(BaseException):
cb(self)
self._callbacks.clear()

async def wait(self) -> None:
import trio

await self._done_event.wait()
if self._exception is not None and not isinstance(
self._exception, trio.Cancelled
):
raise self._exception


def spawn_detached(coro: Coroutine[Any, Any, Any]) -> TaskHandle:
"""Spawn ``coro`` as a detached background task on the current backend.

- **asyncio**: ``asyncio.get_running_loop().create_task(coro)``.
- **trio**: ``trio.lowlevel.spawn_system_task`` wrapping ``coro`` in a
per-task ``CancelScope`` so the handle supports ``.cancel()``.
"""
backend = sniffio.current_async_library()
if backend == "asyncio":
import asyncio

loop = asyncio.get_running_loop()
return _AsyncioTaskHandle(loop.create_task(coro))
if backend == "trio":
import trio

handle = _TrioTaskHandle()

async def _runner() -> None:
exc: BaseException | None = None
try:
with handle._cancel_scope:
await coro
except BaseException as e: # noqa: BLE001
# System tasks must not raise (would crash trio). Store
# the exception on the handle; ``.wait()`` re-raises it.
exc = e
finally:
handle._mark_done(exc)
Comment thread
claude[bot] marked this conversation as resolved.
Comment thread
claude[bot] marked this conversation as resolved.

# Pass context= so trio system tasks inherit the caller's
# contextvars (asyncio's loop.create_task() does this implicitly;
# spawn_system_task does not).
trio.lowlevel.spawn_system_task(_runner, context=contextvars.copy_context())
return handle
# Unsupported backend: close the coroutine so we don't leak a "coroutine
# was never awaited" RuntimeWarning on top of the RuntimeError.
coro.close()
raise RuntimeError(
f"Unsupported async backend: {backend!r}. "
"claude_agent_sdk requires asyncio or trio."
)
48 changes: 31 additions & 17 deletions src/claude_agent_sdk/_internal/query.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Query class for handling bidirectional control protocol."""

import asyncio
import json
import logging
import os
Expand All @@ -26,6 +25,7 @@
SDKHookCallbackRequest,
ToolPermissionContext,
)
from ._task_compat import TaskHandle, spawn_detached
from .transport import Transport

if TYPE_CHECKING:
Expand Down Expand Up @@ -119,9 +119,9 @@
self._message_send, self._message_receive = anyio.create_memory_object_stream[
dict[str, Any]
](max_buffer_size=100)
self._read_task: asyncio.Task[None] | None = None
self._child_tasks: set[asyncio.Task[Any]] = set()
self._inflight_requests: dict[str, asyncio.Task[Any]] = {}
self._read_task: TaskHandle | None = None
self._child_tasks: set[TaskHandle] = set()
self._inflight_requests: dict[str, TaskHandle] = {}
self._initialized = False
self._closed = False
self._initialization_result: dict[str, Any] | None = None
Expand Down Expand Up @@ -217,13 +217,11 @@
async def start(self) -> None:
"""Start reading messages from transport."""
if self._read_task is None:
loop = asyncio.get_running_loop()
self._read_task = loop.create_task(self._read_messages())
self._read_task = spawn_detached(self._read_messages())

def spawn_task(self, coro: Any) -> asyncio.Task[Any]:
def spawn_task(self, coro: Any) -> TaskHandle:
"""Spawn a child task that will be cancelled on close()."""
loop = asyncio.get_running_loop()
task = loop.create_task(coro)
task = spawn_detached(coro)
self._child_tasks.add(task)
task.add_done_callback(self._child_tasks.discard)
return task
Expand All @@ -234,7 +232,7 @@
task = self.spawn_task(self._handle_control_request(request))
self._inflight_requests[req_id] = task

def _done(_t: asyncio.Task[Any]) -> None:
def _done(_t: TaskHandle) -> None:
self._inflight_requests.pop(req_id, None)

task.add_done_callback(_done)
Expand Down Expand Up @@ -316,14 +314,23 @@
finally:
# Flush any remaining transcript mirror entries before closing so
# an early stdout EOF or transport error doesn't drop entries
# batched this turn. flush() never raises.
# batched this turn. flush() never raises. Shielded so the await
# still runs when this finally is reached via cancellation.
if self._transcript_mirror_batcher is not None:
await self._transcript_mirror_batcher.flush()
with anyio.CancelScope(shield=True):
await self._transcript_mirror_batcher.flush()
# Unblock any waiters (e.g. string-prompt path waiting for first
# result) so they don't stall for the full timeout on early exit.
self._first_result_event.set()
# Always signal end of stream
await self._message_send.send({"type": "end"})
# Always signal end of stream. send_nowait: trio's level-triggered
# cancellation would re-raise Cancelled at an await checkpoint
# here, dropping the sentinel and leaving receive_messages() hung.
# close() is the fallback for the buffer-full case where
# send_nowait raises WouldBlock — receivers then exit on
# EndOfStream after draining.
with suppress(anyio.WouldBlock):
self._message_send.send_nowait({"type": "end"})
Comment thread
claude[bot] marked this conversation as resolved.
self._message_send.close()

async def _handle_control_request(self, request: SDKControlRequest) -> None:
"""Handle incoming control request from CLI."""
Expand Down Expand Up @@ -426,7 +433,7 @@
}
await self.transport.write(json.dumps(success_response) + "\n")

except asyncio.CancelledError:
except anyio.get_cancelled_exc_class():
# Request was cancelled via control_cancel_request; the CLI has
# already abandoned this request, so don't write a response.
raise
Expand Down Expand Up @@ -808,9 +815,16 @@
task.cancel()
if self._read_task is not None and not self._read_task.done():
self._read_task.cancel()
with suppress(asyncio.CancelledError):
await self._read_task
await self._read_task.wait()
self._read_task = None
# The read task's finally closed the send side; close the receive
# side here so callers get EndOfStream and anyio doesn't emit
# ResourceWarning: Unclosed <MemoryObject*Stream> at GC time.
# _message_send.close() is repeated for the case where start() was
# never called (so _read_messages' finally never ran). Both are
# sync, idempotent, and checkpoint-free.
self._message_send.close()
self._message_receive.close()

Check failure on line 827 in src/claude_agent_sdk/_internal/query.py

View check run for this annotation

Claude / Claude Code Review

_message_receive.close() in Query.close() raises ClosedResourceError for concurrent consumers

Closing `_message_receive` here regresses the asyncio path: anyio's `MemoryObjectReceiveStream.receive_nowait()` checks `if self._closed: raise ClosedResourceError` *before* consulting the buffer, and `__anext__` doesn't translate `ClosedResourceError` into `StopAsyncIteration`. So if a task is iterating `receive_messages()` and is in user code (not parked in `receive()`) when another task calls `close()`/`disconnect()`, its next iteration raises `ClosedResourceError` and any buffered messages +
Comment thread
claude[bot] marked this conversation as resolved.
Outdated
await self.transport.close()

# Make Query an async iterator
Expand Down
66 changes: 66 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,3 +255,69 @@ async def mock_receive():
)

anyio.run(_test)


class TestClaudeSDKClientTrioBackend:
"""Regression test: ClaudeSDKClient must work under trio.

``Query.start``/``spawn_task`` must not call ``asyncio.get_running_loop()``
(raises ``RuntimeError: no running event loop`` under trio). This test
drives connect()/disconnect() end-to-end on the trio backend with a mock
transport that uses only anyio primitives.
"""

def test_client_connect_under_trio(self):
import json

from claude_agent_sdk import ClaudeSDKClient

def _make_trio_safe_transport():
"""Mock transport using anyio.sleep so it runs under trio."""
mock_transport = AsyncMock()
mock_transport.connect = AsyncMock()
mock_transport.close = AsyncMock()
mock_transport.end_input = AsyncMock()
mock_transport.is_ready = Mock(return_value=True)

written: list[str] = []

async def mock_write(data):
written.append(data)

mock_transport.write = AsyncMock(side_effect=mock_write)

async def read_messages():
# Respond to the initialize control_request so connect()
# doesn't block on the 60s timeout.
for _ in range(200):
for msg_str in written:
try:
msg = json.loads(msg_str.strip())
except (json.JSONDecodeError, AttributeError):
continue
if (
msg.get("type") == "control_request"
and msg.get("request", {}).get("subtype") == "initialize"
):
yield {
"type": "control_response",
"response": {
"request_id": msg.get("request_id"),
"subtype": "success",
"response": {},
},
}
return
await anyio.sleep(0.01)

mock_transport.read_messages = read_messages
return mock_transport

async def _test():
mock_transport = _make_trio_safe_transport()
async with ClaudeSDKClient(transport=mock_transport) as client:
assert client._transport is mock_transport
mock_transport.connect.assert_called_once()
mock_transport.close.assert_called_once()

anyio.run(_test, backend="trio")
Loading
Loading