diff --git a/pyproject.toml b/pyproject.toml index 8b9589a0..016b1984 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ classifiers = [ keywords = ["claude", "ai", "sdk", "anthropic"] dependencies = [ "anyio>=4.0.0", + "sniffio>=1.0.0", "typing_extensions>=4.0.0; python_version<'3.11'", "mcp>=0.1.0", ] diff --git a/src/claude_agent_sdk/_internal/_task_compat.py b/src/claude_agent_sdk/_internal/_task_compat.py new file mode 100644 index 00000000..5d339bb5 --- /dev/null +++ b/src/claude_agent_sdk/_internal/_task_compat.py @@ -0,0 +1,176 @@ +"""Backend-agnostic detached task spawning. + +``Query`` manages background tasks (the read loop, ``stream_input``, +control-request handlers) that must be cancellable from any task context +— including async-generator finalizers, which Python may run in a +different task than the one that called ``start()``. anyio's +``TaskGroup`` cannot be used for this because its cancel scope has task +affinity: exiting it from a different task either raises ``RuntimeError: +Attempted to exit cancel scope in a different task than it was entered +in`` or busy-spins in ``_deliver_cancellation`` on the asyncio backend. + +Under asyncio this is solved with plain ``loop.create_task()``, but that +raises ``RuntimeError: no running event loop`` under trio. This module +provides ``spawn_detached()`` which dispatches via sniffio to the +appropriate backend primitive, returning a uniform ``TaskHandle``. +""" + +from __future__ import annotations + +import contextvars +import logging +from collections.abc import Callable, Coroutine +from contextlib import suppress +from typing import Any + +import sniffio + +logger = logging.getLogger(__name__) + + +class TaskHandle: + """Backend-agnostic handle to a detached background task. + + Safe to ``.cancel()`` from any task — no anyio cancel-scope task + affinity. + """ + + def cancel(self) -> None: + """Request cancellation of the wrapped task.""" + raise NotImplementedError + + def done(self) -> bool: + """Return True if the wrapped task has finished.""" + raise NotImplementedError + + def add_done_callback(self, callback: Callable[[TaskHandle], None]) -> None: + """Register ``callback(self)`` to run when the task finishes.""" + raise NotImplementedError + + async def wait(self) -> None: + """Wait for the task to finish. + + Suppresses the backend's cancellation exception (the task was + cancelled by us) but re-raises any other exception the task + raised. + """ + raise NotImplementedError + + +class _AsyncioTaskHandle(TaskHandle): + """Thin wrapper around ``asyncio.Task``.""" + + def __init__(self, task: Any) -> None: + self._task = task + + def cancel(self) -> None: + self._task.cancel() + + def done(self) -> bool: + return bool(self._task.done()) + + def add_done_callback(self, callback: Callable[[TaskHandle], None]) -> None: + self._task.add_done_callback(lambda _t: callback(self)) + + async def wait(self) -> None: + import asyncio + + with suppress(asyncio.CancelledError): + await self._task + + +class _TrioTaskHandle(TaskHandle): + """Wraps a trio system task with its own ``CancelScope``.""" + + def __init__(self) -> None: + import trio + + self._cancel_scope = trio.CancelScope() + self._done_event = trio.Event() + self._exception: BaseException | None = None + self._callbacks: list[Callable[[TaskHandle], None]] = [] + + def cancel(self) -> None: + # CancelScope.cancel() is sync and safe to call from any task. + self._cancel_scope.cancel() + + def done(self) -> bool: + return self._done_event.is_set() + + def add_done_callback(self, callback: Callable[[TaskHandle], None]) -> None: + if self.done(): + callback(self) + else: + self._callbacks.append(callback) + + def _mark_done(self, exc: BaseException | None) -> None: + import trio + + # Parity with asyncio's "Task exception was never retrieved": + # close() only .cancel()s child tasks (never .wait()s them), so a + # non-Cancelled exception would otherwise be silently dropped. + if exc is not None and not isinstance(exc, trio.Cancelled): + logger.warning("Unhandled exception in detached trio task", exc_info=exc) + self._exception = exc + self._done_event.set() + for cb in self._callbacks: + # Suppress BaseException so a misbehaving callback can never + # propagate out of the system-task _runner (which would crash + # trio with TrioInternalError). The actual callbacks used here + # are set.discard / dict.pop, so this is purely defensive. + with suppress(BaseException): + cb(self) + self._callbacks.clear() + + async def wait(self) -> None: + import trio + + await self._done_event.wait() + if self._exception is not None and not isinstance( + self._exception, trio.Cancelled + ): + raise self._exception + + +def spawn_detached(coro: Coroutine[Any, Any, Any]) -> TaskHandle: + """Spawn ``coro`` as a detached background task on the current backend. + + - **asyncio**: ``asyncio.get_running_loop().create_task(coro)``. + - **trio**: ``trio.lowlevel.spawn_system_task`` wrapping ``coro`` in a + per-task ``CancelScope`` so the handle supports ``.cancel()``. + """ + backend = sniffio.current_async_library() + if backend == "asyncio": + import asyncio + + loop = asyncio.get_running_loop() + return _AsyncioTaskHandle(loop.create_task(coro)) + if backend == "trio": + import trio + + handle = _TrioTaskHandle() + + async def _runner() -> None: + exc: BaseException | None = None + try: + with handle._cancel_scope: + await coro + except BaseException as e: # noqa: BLE001 + # System tasks must not raise (would crash trio). Store + # the exception on the handle; ``.wait()`` re-raises it. + exc = e + finally: + handle._mark_done(exc) + + # Pass context= so trio system tasks inherit the caller's + # contextvars (asyncio's loop.create_task() does this implicitly; + # spawn_system_task does not). + trio.lowlevel.spawn_system_task(_runner, context=contextvars.copy_context()) + return handle + # Unsupported backend: close the coroutine so we don't leak a "coroutine + # was never awaited" RuntimeWarning on top of the RuntimeError. + coro.close() + raise RuntimeError( + f"Unsupported async backend: {backend!r}. " + "claude_agent_sdk requires asyncio or trio." + ) diff --git a/src/claude_agent_sdk/_internal/query.py b/src/claude_agent_sdk/_internal/query.py index d0241cef..0843453e 100644 --- a/src/claude_agent_sdk/_internal/query.py +++ b/src/claude_agent_sdk/_internal/query.py @@ -1,6 +1,5 @@ """Query class for handling bidirectional control protocol.""" -import asyncio import json import logging import os @@ -26,6 +25,7 @@ SDKHookCallbackRequest, ToolPermissionContext, ) +from ._task_compat import TaskHandle, spawn_detached from .transport import Transport if TYPE_CHECKING: @@ -119,9 +119,9 @@ def __init__( self._message_send, self._message_receive = anyio.create_memory_object_stream[ dict[str, Any] ](max_buffer_size=100) - self._read_task: asyncio.Task[None] | None = None - self._child_tasks: set[asyncio.Task[Any]] = set() - self._inflight_requests: dict[str, asyncio.Task[Any]] = {} + self._read_task: TaskHandle | None = None + self._child_tasks: set[TaskHandle] = set() + self._inflight_requests: dict[str, TaskHandle] = {} self._initialized = False self._closed = False self._initialization_result: dict[str, Any] | None = None @@ -217,13 +217,11 @@ async def initialize(self) -> dict[str, Any] | None: async def start(self) -> None: """Start reading messages from transport.""" if self._read_task is None: - loop = asyncio.get_running_loop() - self._read_task = loop.create_task(self._read_messages()) + self._read_task = spawn_detached(self._read_messages()) - def spawn_task(self, coro: Any) -> asyncio.Task[Any]: + def spawn_task(self, coro: Any) -> TaskHandle: """Spawn a child task that will be cancelled on close().""" - loop = asyncio.get_running_loop() - task = loop.create_task(coro) + task = spawn_detached(coro) self._child_tasks.add(task) task.add_done_callback(self._child_tasks.discard) return task @@ -234,7 +232,7 @@ def _spawn_control_request_handler(self, request: SDKControlRequest) -> None: task = self.spawn_task(self._handle_control_request(request)) self._inflight_requests[req_id] = task - def _done(_t: asyncio.Task[Any]) -> None: + def _done(_t: TaskHandle) -> None: self._inflight_requests.pop(req_id, None) task.add_done_callback(_done) @@ -316,14 +314,23 @@ async def _read_messages(self) -> None: finally: # Flush any remaining transcript mirror entries before closing so # an early stdout EOF or transport error doesn't drop entries - # batched this turn. flush() never raises. + # batched this turn. flush() never raises. Shielded so the await + # still runs when this finally is reached via cancellation. if self._transcript_mirror_batcher is not None: - await self._transcript_mirror_batcher.flush() + with anyio.CancelScope(shield=True): + await self._transcript_mirror_batcher.flush() # Unblock any waiters (e.g. string-prompt path waiting for first # result) so they don't stall for the full timeout on early exit. self._first_result_event.set() - # Always signal end of stream - await self._message_send.send({"type": "end"}) + # Always signal end of stream. send_nowait: trio's level-triggered + # cancellation would re-raise Cancelled at an await checkpoint + # here, dropping the sentinel and leaving receive_messages() hung. + # close() is the fallback for the buffer-full case where + # send_nowait raises WouldBlock — receivers then exit on + # EndOfStream after draining. + with suppress(anyio.WouldBlock): + self._message_send.send_nowait({"type": "end"}) + self._message_send.close() async def _handle_control_request(self, request: SDKControlRequest) -> None: """Handle incoming control request from CLI.""" @@ -426,7 +433,7 @@ async def _handle_control_request(self, request: SDKControlRequest) -> None: } await self.transport.write(json.dumps(success_response) + "\n") - except asyncio.CancelledError: + except anyio.get_cancelled_exc_class(): # Request was cancelled via control_cancel_request; the CLI has # already abandoned this request, so don't write a response. raise @@ -808,9 +815,16 @@ async def close(self) -> None: task.cancel() if self._read_task is not None and not self._read_task.done(): self._read_task.cancel() - with suppress(asyncio.CancelledError): - await self._read_task + await self._read_task.wait() self._read_task = None + # The read task's finally closed the send side; repeat here for the + # case where start() was never called. Do NOT close the receive + # side — it belongs to the consumer, and anyio's receive_nowait() + # checks _closed before the buffer, so closing it here would make a + # non-parked consumer drop buffered messages with + # ClosedResourceError. _message_send.close() alone yields + # EndOfStream after the buffer drains. + self._message_send.close() await self.transport.close() # Make Query an async iterator diff --git a/tests/test_client.py b/tests/test_client.py index f9065557..dbf15a82 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, Mock, patch import anyio +import pytest from claude_agent_sdk import AssistantMessage, ClaudeAgentOptions, query from claude_agent_sdk.types import TextBlock @@ -255,3 +256,72 @@ async def mock_receive(): ) anyio.run(_test) + + +@pytest.mark.filterwarnings( + "ignore:Unclosed None: + async def _test(): + with anyio.fail_after(5.0): + mock_transport = self._make_blocking_transport() + q = Query(transport=mock_transport, is_streaming_mode=True) + await q.start() + + # Buffer 3 messages directly (bypassing the read task, + # which is blocked on the transport). + for i in range(3): + q._message_send.send_nowait({"type": "user", "i": i}) + + consumed: list[dict] = [] + consumer_error: list[BaseException] = [] + got_first = anyio.Event() + in_user_code = anyio.Event() + + async def consumer(): + try: + async for msg in q.receive_messages(): + consumed.append(msg) + if len(consumed) == 1: + got_first.set() + # Stay in user code (NOT parked in + # receive()) while close() runs. + await in_user_code.wait() + except BaseException as e: # noqa: BLE001 + consumer_error.append(e) + + async with anyio.create_task_group() as tg: + tg.start_soon(consumer) + await got_first.wait() + # Consumer is now awaiting in_user_code (user code), + # with 2 messages still buffered. + await q.close() + in_user_code.set() + + assert consumer_error == [], ( + f"[{backend}] consumer raised: {consumer_error}" + ) + assert len(consumed) == 3, ( + f"[{backend}] expected 3 messages, got {len(consumed)}: {consumed}" + ) + + anyio.run(_test, backend=backend) + + def test_buffered_messages_drain_after_close_asyncio(self): + """Consumer in user code when close() runs must drain the buffer. + + anyio's ``receive_nowait()`` checks ``_closed`` before the buffer, + so closing ``_message_receive`` from ``close()`` would make a + non-parked consumer hit ``ClosedResourceError`` and drop buffered + messages. ``_message_send.close()`` alone yields ``EndOfStream`` + only after the buffer drains. + """ + self._run_buffered_drain_after_close("asyncio") + + def test_buffered_messages_drain_after_close_trio(self): + """trio parity for the buffered-drain-after-close test above.""" + self._run_buffered_drain_after_close("trio") + + class TestControlCancelRequest: """Tests for control_cancel_request handling (issue #739). diff --git a/tests/test_task_compat.py b/tests/test_task_compat.py new file mode 100644 index 00000000..eadad4ef --- /dev/null +++ b/tests/test_task_compat.py @@ -0,0 +1,214 @@ +"""Tests for the backend-agnostic detached task spawner.""" + +import anyio +import pytest + +from claude_agent_sdk._internal._task_compat import spawn_detached + + +class TestSpawnAndWait: + def test_spawn_and_wait_asyncio(self): + async def _test(): + flag = {"set": False} + + async def coro(): + flag["set"] = True + + handle = spawn_detached(coro()) + await handle.wait() + assert flag["set"] is True + assert handle.done() is True + + anyio.run(_test, backend="asyncio") + + def test_spawn_and_wait_trio(self): + async def _test(): + flag = {"set": False} + + async def coro(): + flag["set"] = True + + handle = spawn_detached(coro()) + await handle.wait() + assert flag["set"] is True + assert handle.done() is True + + anyio.run(_test, backend="trio") + + +class TestCancel: + def test_cancel_asyncio(self): + async def _test(): + async def coro(): + await anyio.sleep(3600) + + handle = spawn_detached(coro()) + await anyio.sleep(0) # let it start + handle.cancel() + await handle.wait() + assert handle.done() is True + + anyio.run(_test, backend="asyncio") + + def test_cancel_trio(self): + async def _test(): + async def coro(): + await anyio.sleep(3600) + + handle = spawn_detached(coro()) + await anyio.sleep(0) # let it start + handle.cancel() + await handle.wait() + assert handle.done() is True + + anyio.run(_test, backend="trio") + + +class TestDoneCallback: + def test_done_callback_asyncio(self): + async def _test(): + fired_with = [] + + async def coro(): + pass + + handle = spawn_detached(coro()) + handle.add_done_callback(fired_with.append) + await handle.wait() + await anyio.sleep(0) # let callback fire + assert fired_with == [handle] + + anyio.run(_test, backend="asyncio") + + def test_done_callback_trio(self): + async def _test(): + fired_with = [] + + async def coro(): + pass + + handle = spawn_detached(coro()) + handle.add_done_callback(fired_with.append) + await handle.wait() + assert fired_with == [handle] + + anyio.run(_test, backend="trio") + + +class TestExceptionPropagation: + def test_exception_propagates_via_wait_asyncio(self): + async def _test(): + async def coro(): + raise ValueError("boom") + + handle = spawn_detached(coro()) + with pytest.raises(ValueError, match="boom"): + await handle.wait() + + anyio.run(_test, backend="asyncio") + + def test_exception_propagates_via_wait_trio(self): + async def _test(): + async def coro(): + raise ValueError("boom") + + handle = spawn_detached(coro()) + with pytest.raises(ValueError, match="boom"): + await handle.wait() + + anyio.run(_test, backend="trio") + + def test_unhandled_exception_logged_under_trio(self, caplog): + """A non-Cancelled exception with no .wait() must still be logged. + + Parity with asyncio's "Task exception was never retrieved": child + tasks that are only ``.cancel()``ed (never ``.wait()``ed) would + otherwise drop the exception silently under trio. + """ + import logging + + async def _test(): + async def coro(): + raise ValueError("boom") + + spawn_detached(coro()) # NB: no .wait() + await anyio.sleep(0) # let it run + + with caplog.at_level( + logging.WARNING, logger="claude_agent_sdk._internal._task_compat" + ): + anyio.run(_test, backend="trio") + + assert any( + "Unhandled exception in detached trio task" in r.message + for r in caplog.records + ), f"expected warning log, got: {[r.message for r in caplog.records]}" + assert any( + r.exc_info and isinstance(r.exc_info[1], ValueError) for r in caplog.records + ) + + +class TestContextVarPropagation: + """Spawned tasks must see the caller's contextvars on both backends. + + asyncio's ``loop.create_task()`` copies the current context implicitly; + trio's ``spawn_system_task`` does not unless ``context=`` is passed. + """ + + @staticmethod + def _run(backend: str) -> str: + import contextvars + + cv: contextvars.ContextVar[str] = contextvars.ContextVar( + "cv", default="DEFAULT" + ) + seen: list[str] = [] + + async def _test(): + cv.set("PARENT") + + async def coro(): + seen.append(cv.get()) + + handle = spawn_detached(coro()) + await handle.wait() + + anyio.run(_test, backend=backend) + return seen[0] + + def test_contextvar_propagates_asyncio(self): + assert self._run("asyncio") == "PARENT" + + def test_contextvar_propagates_trio(self): + assert self._run("trio") == "PARENT" + + +class TestCrossTaskCancel: + def test_cancel_from_different_task_trio(self): + """Cancelling from a different task than the spawner must not raise. + + This is the trio-side equivalent of the cross-task-cancel invariant. + """ + + async def _test(): + async def coro(): + await anyio.sleep(3600) + + handle = spawn_detached(coro()) + await anyio.sleep(0) + cancel_error = [] + + async def cancel_in_other_task(): + try: + handle.cancel() + await handle.wait() + except Exception as e: # pragma: no cover - failure path + cancel_error.append(e) + + async with anyio.create_task_group() as tg: + tg.start_soon(cancel_in_other_task) + + assert cancel_error == [], f"cancel raised: {cancel_error}" + assert handle.done() is True + + anyio.run(_test, backend="trio")