Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions src/claude_agent_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
Message,
PermissionMode,
ResultMessage,
TaskNotificationMessage,
TaskStartedMessage,
)


Expand Down Expand Up @@ -74,6 +76,14 @@ def __init__(
self._transport: Transport | None = None
self._query: Any | None = None

# Turn tracking for background-task notification hygiene.
# Each ResultMessage increments _current_turn. When a TaskStartedMessage
# arrives, we record task_id → turn so that receive_response() can
# suppress TaskNotificationMessages whose task was started in a prior
# turn and would otherwise leak into the next turn's response stream.
self._current_turn: int = 0
self._task_turn_map: dict[str, int] = {}

def _convert_hooks_to_internal_format(
self, hooks: dict[HookEvent, list[HookMatcher]]
) -> dict[str, list[dict[str, Any]]]:
Expand Down Expand Up @@ -524,10 +534,62 @@ async def receive_response(self) -> AsyncIterator[Message]:
Note:
To collect all messages: `messages = [msg async for msg in client.receive_response()]`
The final message in the list will always be a ResultMessage.

Background task notifications:
If a background task (spawned via the Agent tool with run_in_background=True)
completes after a previous turn's ResultMessage but before this call returns,
its TaskNotificationMessage is suppressed from this iterator. The notification
arrived between turns and would otherwise appear before the first assistant
response, making it look like stale context from a prior conversation.

Task completions that arrive *during* the current turn (after the first
assistant message) are still yielded normally. For the full unfiltered stream
including all task events, use receive_messages() instead.
"""
if not self._query:
raise CLIConnectionError("Not connected. Call connect() first.")

# We hold any task-lifecycle events that arrive before the first
# non-task message of this turn. Once a non-task message arrives we
# know the CLI is processing our latest query, so deferred events are
# re-yielded in order. Events for tasks started in a previous turn
# are discarded at that point because they are stale cross-turn noise.
deferred: list[Message] = []
turn_started = False

async for message in self.receive_messages():
# Track task IDs so we know which turn they were spawned in.
if isinstance(message, TaskStartedMessage):
self._task_turn_map[message.task_id] = self._current_turn

if not turn_started:
if isinstance(message, (TaskStartedMessage, TaskNotificationMessage)):
# Arrival before the first non-task message: could be
# a stale notification from a previous turn. Defer.
deferred.append(message)
continue

# First non-task message — we are now inside the current turn.
turn_started = True
for deferred_msg in deferred:
if isinstance(deferred_msg, TaskNotificationMessage):
task_turn = self._task_turn_map.get(deferred_msg.task_id)
# Clean up the map entry regardless of outcome.
self._task_turn_map.pop(deferred_msg.task_id, None)
if task_turn is not None and task_turn < self._current_turn:
# Stale: started in a previous turn, completed
# between turns. Drop it.
continue
yield deferred_msg
deferred.clear()

# Clean up map entries when a notification is yielded mid-turn.
if isinstance(message, TaskNotificationMessage):
self._task_turn_map.pop(message.task_id, None)

yield message
if isinstance(message, ResultMessage):
self._current_turn += 1
return

async def disconnect(self) -> None:
Expand Down
210 changes: 210 additions & 0 deletions tests/test_streaming_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
UserMessage,
query,
)
from claude_agent_sdk.types import TaskNotificationMessage, TaskStartedMessage
from claude_agent_sdk._internal.transport.subprocess_cli import SubprocessCLITransport


Expand Down Expand Up @@ -1312,3 +1313,212 @@ async def mock_receive():
assert isinstance(messages[-1], ResultMessage)

anyio.run(_test)


# ---------------------------------------------------------------------------
# Task notification hygiene tests (issue #788)
# ---------------------------------------------------------------------------

def _make_assistant_msg(text: str = "4") -> dict:
return {
"type": "assistant",
"message": {
"role": "assistant",
"content": [{"type": "text", "text": text}],
"model": "claude-sonnet-4-5",
},
}


def _make_result_msg() -> dict:
return {
"type": "result",
"subtype": "success",
"duration_ms": 100,
"duration_api_ms": 80,
"is_error": False,
"num_turns": 1,
"session_id": "test",
"total_cost_usd": 0.001,
}


def _make_task_started(task_id: str = "task-1") -> dict:
return {
"type": "system",
"subtype": "task_started",
"task_id": task_id,
"description": "background work",
"uuid": f"uuid-{task_id}",
"session_id": "test",
}


def _make_task_notification(task_id: str = "task-1") -> dict:
return {
"type": "system",
"subtype": "task_notification",
"task_id": task_id,
"status": "completed",
"output_file": "/tmp/out.md",
"summary": "done",
"uuid": f"notif-{task_id}",
"session_id": "test",
}


class TestReceiveResponseTaskNotificationHygiene:
"""receive_response() must not leak between-turn task notifications (issue #788)."""

def _make_transport_with_messages(self, messages: list[dict]):
"""Build a mock transport that yields the given messages after init."""
mock_transport = AsyncMock()
mock_transport.connect = AsyncMock()
mock_transport.close = AsyncMock()
mock_transport.end_input = AsyncMock()
mock_transport.is_ready = Mock(return_value=True)

written_messages: list[str] = []

async def mock_write(data):
written_messages.append(data)

mock_transport.write = AsyncMock(side_effect=mock_write)

async def msg_gen():
# Respond to initialize request first
await asyncio.sleep(0.01)
for msg_str in written_messages:
try:
msg = json.loads(msg_str.strip())
if (
msg.get("type") == "control_request"
and msg.get("request", {}).get("subtype") == "initialize"
):
yield {
"type": "control_response",
"response": {
"request_id": msg.get("request_id"),
"subtype": "success",
"commands": [],
},
}
break
except (json.JSONDecodeError, KeyError):
pass
for m in messages:
yield m

mock_transport.read_messages = msg_gen
return mock_transport

def test_stale_notification_before_turn2_is_suppressed(self):
"""TaskNotificationMessage buffered before Turn 2 starts is NOT yielded."""

async def _test():
with patch(
"claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_cls:
# Stream: Turn 1 (task starts + result), then stale notification,
# then Turn 2 (assistant + result).
msgs = [
_make_task_started("t1"),
_make_assistant_msg("Spawning"),
_make_result_msg(), # End Turn 1
_make_task_notification("t1"), # Stale: between turns
_make_assistant_msg("4"),
_make_result_msg(), # End Turn 2
]
mock_cls.return_value = self._make_transport_with_messages(msgs)

async with ClaudeSDKClient() as client:
# Consume Turn 1
turn1 = [m async for m in client.receive_response()]
assert any(isinstance(m, TaskStartedMessage) for m in turn1)
assert isinstance(turn1[-1], ResultMessage)

# Turn 2: stale notification must not appear
turn2 = [m async for m in client.receive_response()]
assert not any(
isinstance(m, TaskNotificationMessage) for m in turn2
), "Stale TaskNotificationMessage leaked into Turn 2"
assert any(isinstance(m, AssistantMessage) for m in turn2)
assert isinstance(turn2[-1], ResultMessage)

anyio.run(_test)

def test_notification_arriving_mid_turn_is_yielded(self):
"""TaskNotificationMessage that arrives after the first AssistantMessage IS yielded."""

async def _test():
with patch(
"claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_cls:
# Stream: Turn 1 starts, notification arrives after first assistant msg.
msgs = [
_make_task_started("t2"),
_make_assistant_msg("thinking..."),
_make_task_notification("t2"), # Arrives mid-turn — should show
_make_assistant_msg("done"),
_make_result_msg(),
]
mock_cls.return_value = self._make_transport_with_messages(msgs)

async with ClaudeSDKClient() as client:
turn1 = [m async for m in client.receive_response()]

notifications = [m for m in turn1 if isinstance(m, TaskNotificationMessage)]
assert len(notifications) == 1, (
"TaskNotificationMessage that arrived mid-turn should be yielded"
)

anyio.run(_test)

def test_turn_counter_increments_and_cleans_map(self):
"""_current_turn increments per result; _task_turn_map is cleaned up."""

async def _test():
with patch(
"claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_cls:
msgs = [
_make_task_started("t3"),
_make_task_notification("t3"), # Completes in Turn 1
_make_assistant_msg("hi"),
_make_result_msg(),
]
mock_cls.return_value = self._make_transport_with_messages(msgs)

async with ClaudeSDKClient() as client:
assert client._current_turn == 0
_ = [m async for m in client.receive_response()]
assert client._current_turn == 1
# Map entry cleaned up after notification was processed
assert "t3" not in client._task_turn_map

anyio.run(_test)

def test_unknown_task_id_notification_is_yielded(self):
"""Notification for an unknown task_id (no TaskStartedMessage seen) is yielded."""

async def _test():
with patch(
"claude_agent_sdk._internal.transport.subprocess_cli.SubprocessCLITransport"
) as mock_cls:
# No task_started, just a notification — must not crash or suppress.
msgs = [
_make_task_notification("unknown-task"),
_make_assistant_msg("hi"),
_make_result_msg(),
]
mock_cls.return_value = self._make_transport_with_messages(msgs)

async with ClaudeSDKClient() as client:
turn1 = [m async for m in client.receive_response()]

notifications = [m for m in turn1 if isinstance(m, TaskNotificationMessage)]
assert len(notifications) == 1, (
"Notification for unknown task_id should be yielded as current-turn"
)

anyio.run(_test)