From 4c93ae7211c2898a1cac951efa70abb2a4365747 Mon Sep 17 00:00:00 2001
From: Max Flanagan
Date: Sun, 5 Apr 2026 15:09:25 -0400
Subject: [PATCH] fix(client): suppress stale task notifications at the start of receive_response()
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

When a background task (spawned via run_in_background=True) completed
between turns, its TaskNotificationMessage sat in the message buffer and
was the first thing yielded by the next receive_response() call. This
caused the notification to appear before the actual Turn N+1 response —
and in some cases caused the model to respond to the stale task context
instead of the new user prompt.

Fix: ClaudeSDKClient now tracks which turn each background task was
started in (_task_turn_map). receive_response() defers any task lifecycle
events that arrive before the first non-task message of the current turn.
When the first substantive message arrives, deferred events are flushed —
unless the event is a TaskNotificationMessage for a task started in an
earlier turn, in which case it is discarded as stale cross-turn noise.

Notifications that arrive AFTER the first non-task message of the turn
(mid-turn) are still yielded normally. Notifications for tasks with no
recorded start (unknown task_id) are yielded as current-turn (safe
default). Map entries are cleaned up when a notification is processed to
prevent unbounded growth on long-lived clients.

The raw receive_messages() stream is unchanged: callers who need every
event regardless of turn boundaries should use that method instead.

Closes #788 --- src/claude_agent_sdk/client.py | 62 ++++++++++ tests/test_streaming_client.py | 210 +++++++++++++++++++++++++++++++++ 2 files changed, 272 insertions(+) diff --git a/src/claude_agent_sdk/client.py b/src/claude_agent_sdk/client.py index c6ad1171..ba4b287c 100644 --- a/src/claude_agent_sdk/client.py +++ b/src/claude_agent_sdk/client.py @@ -17,6 +17,8 @@ Message, PermissionMode, ResultMessage, + TaskNotificationMessage, + TaskStartedMessage, ) @@ -74,6 +76,14 @@ def __init__( self._transport: Transport | None = None self._query: Any | None = None + # Turn tracking for background-task notification hygiene. + # Each ResultMessage increments _current_turn. When a TaskStartedMessage + # arrives, we record task_id → turn so that receive_response() can + # suppress TaskNotificationMessages whose task was started in a prior + # turn and would otherwise leak into the next turn's response stream. + self._current_turn: int = 0 + self._task_turn_map: dict[str, int] = {} + def _convert_hooks_to_internal_format( self, hooks: dict[HookEvent, list[HookMatcher]] ) -> dict[str, list[dict[str, Any]]]: @@ -524,10 +534,62 @@ async def receive_response(self) -> AsyncIterator[Message]: Note: To collect all messages: `messages = [msg async for msg in client.receive_response()]` The final message in the list will always be a ResultMessage. + + Background task notifications: + If a background task (spawned via the Agent tool with run_in_background=True) + completes after a previous turn's ResultMessage but before this call returns, + its TaskNotificationMessage is suppressed from this iterator. The notification + arrived between turns and would otherwise appear before the first assistant + response, making it look like stale context from a prior conversation. + + Task completions that arrive *during* the current turn (after the first + assistant message) are still yielded normally. For the full unfiltered stream + including all task events, use receive_messages() instead. 
""" + if not self._query: + raise CLIConnectionError("Not connected. Call connect() first.") + + # We hold any task-lifecycle events that arrive before the first + # non-task message of this turn. Once a non-task message arrives we + # know the CLI is processing our latest query, so deferred events are + # re-yielded in order. Events for tasks started in a previous turn + # are discarded at that point because they are stale cross-turn noise. + deferred: list[Message] = [] + turn_started = False + async for message in self.receive_messages(): + # Track task IDs so we know which turn they were spawned in. + if isinstance(message, TaskStartedMessage): + self._task_turn_map[message.task_id] = self._current_turn + + if not turn_started: + if isinstance(message, (TaskStartedMessage, TaskNotificationMessage)): + # Arrival before the first non-task message: could be + # a stale notification from a previous turn. Defer. + deferred.append(message) + continue + + # First non-task message — we are now inside the current turn. + turn_started = True + for deferred_msg in deferred: + if isinstance(deferred_msg, TaskNotificationMessage): + task_turn = self._task_turn_map.get(deferred_msg.task_id) + # Clean up the map entry regardless of outcome. + self._task_turn_map.pop(deferred_msg.task_id, None) + if task_turn is not None and task_turn < self._current_turn: + # Stale: started in a previous turn, completed + # between turns. Drop it. + continue + yield deferred_msg + deferred.clear() + + # Clean up map entries when a notification is yielded mid-turn. 
+ if isinstance(message, TaskNotificationMessage): + self._task_turn_map.pop(message.task_id, None) + yield message if isinstance(message, ResultMessage): + self._current_turn += 1 return async def disconnect(self) -> None: diff --git a/tests/test_streaming_client.py b/tests/test_streaming_client.py index 1c2b6980..89d60ebf 100644 --- a/tests/test_streaming_client.py +++ b/tests/test_streaming_client.py @@ -20,6 +20,7 @@ UserMessage, query, ) +from claude_agent_sdk.types import TaskNotificationMessage, TaskStartedMessage from claude_agent_sdk._internal.transport.subprocess_cli import SubprocessCLITransport @@ -1312,3 +1313,212 @@ async def mock_receive(): assert isinstance(messages[-1], ResultMessage) anyio.run(_test) + + +# --------------------------------------------------------------------------- +# Task notification hygiene tests (issue #788) +# --------------------------------------------------------------------------- + +def _make_assistant_msg(text: str = "4") -> dict: + return { + "type": "assistant", + "message": { + "role": "assistant", + "content": [{"type": "text", "text": text}], + "model": "claude-sonnet-4-5", + }, + } + + +def _make_result_msg() -> dict: + return { + "type": "result", + "subtype": "success", + "duration_ms": 100, + "duration_api_ms": 80, + "is_error": False, + "num_turns": 1, + "session_id": "test", + "total_cost_usd": 0.001, + } + + +def _make_task_started(task_id: str = "task-1") -> dict: + return { + "type": "system", + "subtype": "task_started", + "task_id": task_id, + "description": "background work", + "uuid": f"uuid-{task_id}", + "session_id": "test", + } + + +def _make_task_notification(task_id: str = "task-1") -> dict: + return { + "type": "system", + "subtype": "task_notification", + "task_id": task_id, + "status": "completed", + "output_file": "/tmp/out.md", + "summary": "done", + "uuid": f"notif-{task_id}", + "session_id": "test", + } + + +class TestReceiveResponseTaskNotificationHygiene: + """receive_response() must not 
leak between-turn task notifications (issue #788).""" + + def _make_transport_with_messages(self, messages: list[dict]): + """Build a mock transport that yields the given messages after init.""" + mock_transport = AsyncMock() + mock_transport.connect = AsyncMock() + mock_transport.close = AsyncMock() + mock_transport.end_input = AsyncMock() + mock_transport.is_ready = Mock(return_value=True) + + written_messages: list[str] = [] + + async def mock_write(data): + written_messages.append(data) + + mock_transport.write = AsyncMock(side_effect=mock_write) + + async def msg_gen(): + # Respond to initialize request first + await asyncio.sleep(0.01) + for msg_str in written_messages: + try: + msg = json.loads(msg_str.strip()) + if ( + msg.get("type") == "control_request" + and msg.get("request", {}).get("subtype") == "initialize" + ): + yield { + "type": "control_response", + "response": { + "request_id": msg.get("request_id"), + "subtype": "success", + "commands": [], + }, + } + break + except (json.JSONDecodeError, KeyError): + pass + for m in messages: + yield m + + mock_transport.read_messages = msg_gen + return mock_transport + + def test_stale_notification_before_turn2_is_suppressed(self): + """TaskNotificationMessage buffered before Turn 2 starts is NOT yielded.""" + + async def _test(): + with patch( + "claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_cls: + # Stream: Turn 1 (task starts + result), then stale notification, + # then Turn 2 (assistant + result). 
+ msgs = [ + _make_task_started("t1"), + _make_assistant_msg("Spawning"), + _make_result_msg(), # End Turn 1 + _make_task_notification("t1"), # Stale: between turns + _make_assistant_msg("4"), + _make_result_msg(), # End Turn 2 + ] + mock_cls.return_value = self._make_transport_with_messages(msgs) + + async with ClaudeSDKClient() as client: + # Consume Turn 1 + turn1 = [m async for m in client.receive_response()] + assert any(isinstance(m, TaskStartedMessage) for m in turn1) + assert isinstance(turn1[-1], ResultMessage) + + # Turn 2: stale notification must not appear + turn2 = [m async for m in client.receive_response()] + assert not any( + isinstance(m, TaskNotificationMessage) for m in turn2 + ), "Stale TaskNotificationMessage leaked into Turn 2" + assert any(isinstance(m, AssistantMessage) for m in turn2) + assert isinstance(turn2[-1], ResultMessage) + + anyio.run(_test) + + def test_notification_arriving_mid_turn_is_yielded(self): + """TaskNotificationMessage that arrives after the first AssistantMessage IS yielded.""" + + async def _test(): + with patch( + "claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_cls: + # Stream: Turn 1 starts, notification arrives after first assistant msg. 
+ msgs = [ + _make_task_started("t2"), + _make_assistant_msg("thinking..."), + _make_task_notification("t2"), # Arrives mid-turn — should show + _make_assistant_msg("done"), + _make_result_msg(), + ] + mock_cls.return_value = self._make_transport_with_messages(msgs) + + async with ClaudeSDKClient() as client: + turn1 = [m async for m in client.receive_response()] + + notifications = [m for m in turn1 if isinstance(m, TaskNotificationMessage)] + assert len(notifications) == 1, ( + "TaskNotificationMessage that arrived mid-turn should be yielded" + ) + + anyio.run(_test) + + def test_turn_counter_increments_and_cleans_map(self): + """_current_turn increments per result; _task_turn_map is cleaned up.""" + + async def _test(): + with patch( + "claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_cls: + msgs = [ + _make_task_started("t3"), + _make_task_notification("t3"), # Completes in Turn 1 + _make_assistant_msg("hi"), + _make_result_msg(), + ] + mock_cls.return_value = self._make_transport_with_messages(msgs) + + async with ClaudeSDKClient() as client: + assert client._current_turn == 0 + _ = [m async for m in client.receive_response()] + assert client._current_turn == 1 + # Map entry cleaned up after notification was processed + assert "t3" not in client._task_turn_map + + anyio.run(_test) + + def test_unknown_task_id_notification_is_yielded(self): + """Notification for an unknown task_id (no TaskStartedMessage seen) is yielded.""" + + async def _test(): + with patch( + "claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport" + ) as mock_cls: + # No task_started, just a notification — must not crash or suppress. 
+ msgs = [ + _make_task_notification("unknown-task"), + _make_assistant_msg("hi"), + _make_result_msg(), + ] + mock_cls.return_value = self._make_transport_with_messages(msgs) + + async with ClaudeSDKClient() as client: + turn1 = [m async for m in client.receive_response()] + + notifications = [m for m in turn1 if isinstance(m, TaskNotificationMessage)] + assert len(notifications) == 1, ( + "Notification for unknown task_id should be yielded as current-turn" + ) + + anyio.run(_test)