Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions src/claude_agent_sdk/_internal/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,15 +614,16 @@ async def stop_task(self, task_id: str) -> None:
async def wait_for_result_and_end_input(self) -> None:
"""Wait for the first result (if needed) then close stdin.

If SDK MCP servers or hooks require bidirectional communication,
keeps stdin open until the first result arrives (or timeout).
Otherwise closes stdin immediately.
If SDK MCP servers, hooks, or can_use_tool require bidirectional
communication, keeps stdin open until the first result arrives
(or timeout). Otherwise closes stdin immediately.
"""
if self.sdk_mcp_servers or self.hooks:
if self.sdk_mcp_servers or self.hooks or self.can_use_tool:
logger.debug(
"Waiting for first result before closing stdin "
f"(sdk_mcp_servers={len(self.sdk_mcp_servers)}, "
f"has_hooks={bool(self.hooks)})"
f"has_hooks={bool(self.hooks)}, "
f"has_can_use_tool={self.can_use_tool is not None})"
)
with anyio.move_on_after(self._stream_close_timeout):
await self._first_result_event.wait()
Expand All @@ -632,8 +633,9 @@ async def wait_for_result_and_end_input(self) -> None:
async def stream_input(self, stream: AsyncIterable[dict[str, Any]]) -> None:
"""Stream input messages to transport.

If SDK MCP servers or hooks are present, waits for the first result
before closing stdin to allow bidirectional control protocol communication.
If SDK MCP servers, hooks, or can_use_tool are present, waits for the
first result before closing stdin to allow bidirectional control
protocol communication.
"""
try:
async for message in stream:
Expand Down
148 changes: 143 additions & 5 deletions tests/test_query.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Tests for query() stdin lifecycle with SDK MCP servers and hooks.
"""Tests for query() stdin lifecycle with SDK MCP servers, hooks, and tool callbacks.

The SDK communicates with the CLI subprocess over stdin/stdout. When SDK MCP
servers or hooks are configured, the CLI sends control_request messages back
to the SDK *after* the prompt is written. The SDK must keep stdin open long
enough to respond to these requests. These tests verify that both the string
prompt and AsyncIterable prompt paths defer closing stdin until the CLI's
servers, hooks, or can_use_tool are configured, the CLI sends control_request
messages back to the SDK *after* the prompt is written. The SDK must keep stdin
open long enough to respond to these requests. These tests verify that both the
string prompt and AsyncIterable prompt paths defer closing stdin until the CLI's
first result arrives.
"""

Expand All @@ -16,6 +16,7 @@
from claude_agent_sdk import (
AssistantMessage,
ClaudeAgentOptions,
PermissionResultAllow,
ResultMessage,
create_sdk_mcp_server,
query,
Expand Down Expand Up @@ -103,6 +104,20 @@ async def mock_receive():
]


_CAN_USE_TOOL_CONTROL_REQUESTS = [
{
"type": "control_request",
"request_id": "perm_1",
"request": {
"subtype": "can_use_tool",
"tool_name": "Read",
"input": {"file_path": "foo.txt"},
"permission_suggestions": [],
},
},
]


def _make_greet_server():
@tool("greet", "Greet a user", {"name": str})
async def greet_tool(args):
Expand Down Expand Up @@ -370,6 +385,129 @@ async def prompt_stream():

anyio.run(_test)

def test_async_iterable_with_can_use_tool_waits_for_result(self):
"""AsyncIterable prompt path should wait for first result before
closing stdin when can_use_tool is configured."""

async def _test():
mock_transport = _make_mock_transport(messages=_ASSISTANT_AND_RESULT)

call_order = []
original_write = mock_transport.write

async def tracking_write(data):
call_order.append(("write", data))
return await original_write(data)

async def tracking_end_input():
call_order.append(("end_input",))

mock_transport.write = tracking_write
mock_transport.end_input = tracking_end_input

async def allow_callback(tool_name, tool_input, context):
return PermissionResultAllow()

async def prompt_stream():
yield {
"type": "user",
"message": {"role": "user", "content": "Hello"},
}

with (
patch(
"claude_agent_sdk._internal.client.SubprocessCLITransport"
) as mock_cls,
patch(
"claude_agent_sdk._internal.query.Query.initialize",
new_callable=AsyncMock,
),
):
mock_cls.return_value = mock_transport

messages = []
async for msg in query(
prompt=prompt_stream(),
options=ClaudeAgentOptions(
can_use_tool=allow_callback,
),
):
messages.append(msg)

assert len(messages) == 2
assert isinstance(messages[0], AssistantMessage)
assert isinstance(messages[1], ResultMessage)
assert any(c[0] == "end_input" for c in call_order)

anyio.run(_test)

def test_async_iterable_can_use_tool_control_requests_succeed(self):
"""can_use_tool control requests should be handled correctly when using
AsyncIterable prompts."""

async def _test():
mock_transport = AsyncMock()
writes = []

async def tracking_write(data):
writes.append(data)

mock_transport.write = tracking_write
mock_transport.connect = AsyncMock()
mock_transport.close = AsyncMock()
mock_transport.end_input = AsyncMock()
mock_transport.is_ready = Mock(return_value=True)

async def mock_receive():
for req in _CAN_USE_TOOL_CONTROL_REQUESTS:
yield req
for msg in _ASSISTANT_AND_RESULT:
yield msg

mock_transport.read_messages = mock_receive

async def allow_callback(tool_name, tool_input, context):
return PermissionResultAllow()

async def prompt_stream():
yield {
"type": "user",
"message": {"role": "user", "content": "Read foo.txt"},
}

with (
patch(
"claude_agent_sdk._internal.client.SubprocessCLITransport"
) as mock_cls,
patch(
"claude_agent_sdk._internal.query.Query.initialize",
new_callable=AsyncMock,
),
):
mock_cls.return_value = mock_transport

messages = []
async for msg in query(
prompt=prompt_stream(),
options=ClaudeAgentOptions(
can_use_tool=allow_callback,
),
):
messages.append(msg)

assert len(messages) == 2
assert isinstance(messages[0], AssistantMessage)
assert isinstance(messages[1], ResultMessage)

control_responses = [
json.loads(w.rstrip("\n")) for w in writes if "control_response" in w
]
assert len(control_responses) == 1
assert control_responses[0]["response"]["subtype"] == "success"
assert control_responses[0]["response"]["response"]["behavior"] == "allow"

anyio.run(_test)

def test_async_iterable_mcp_control_requests_succeed(self):
"""MCP control requests should be handled correctly when using
AsyncIterable prompts with SDK MCP servers."""
Expand Down