Skip to content

Commit 6119fd4

Browse files
authored
Fix string prompt closing stdin before MCP server init completes (#630)
1 parent de7ecca commit 6119fd4

3 files changed

Lines changed: 462 additions & 22 deletions

File tree

src/claude_agent_sdk/_internal/client.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Internal client implementation."""
22

3+
import json
34
from collections.abc import AsyncIterable, AsyncIterator
45
from dataclasses import asdict, replace
56
from typing import Any
@@ -122,16 +123,14 @@ async def process_query(
122123
if isinstance(prompt, str):
123124
# For string prompts, write user message to stdin after initialize
124125
# (matching TypeScript SDK behavior)
125-
import json
126-
127126
user_message = {
128127
"type": "user",
129128
"session_id": "",
130129
"message": {"role": "user", "content": prompt},
131130
"parent_tool_use_id": None,
132131
}
133132
await chosen_transport.write(json.dumps(user_message) + "\n")
134-
await chosen_transport.end_input()
133+
await query.wait_for_result_and_end_input()
135134
elif isinstance(prompt, AsyncIterable) and query._tg:
136135
# Stream input in background for async iterables
137136
query._tg.start_soon(query.stream_input, prompt)

src/claude_agent_sdk/_internal/query.py

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,9 @@ async def _read_messages(self) -> None:
227227
# Put error in stream so iterators can handle it
228228
await self._message_send.send({"type": "error", "error": str(e)})
229229
finally:
230+
# Unblock any waiters (e.g. string-prompt path waiting for first
231+
# result) so they don't stall for the full timeout on early exit.
232+
self._first_result_event.set()
230233
# Always signal end of stream
231234
await self._message_send.send({"type": "end"})
232235

@@ -608,6 +611,24 @@ async def stop_task(self, task_id: str) -> None:
608611
}
609612
)
610613

614+
async def wait_for_result_and_end_input(self) -> None:
615+
"""Wait for the first result (if needed) then close stdin.
616+
617+
If SDK MCP servers or hooks require bidirectional communication,
618+
keeps stdin open until the first result arrives (or timeout).
619+
Otherwise closes stdin immediately.
620+
"""
621+
if self.sdk_mcp_servers or self.hooks:
622+
logger.debug(
623+
"Waiting for first result before closing stdin "
624+
f"(sdk_mcp_servers={len(self.sdk_mcp_servers)}, "
625+
f"has_hooks={bool(self.hooks)})"
626+
)
627+
with anyio.move_on_after(self._stream_close_timeout):
628+
await self._first_result_event.wait()
629+
630+
await self.transport.end_input()
631+
611632
async def stream_input(self, stream: AsyncIterable[dict[str, Any]]) -> None:
612633
"""Stream input messages to transport.
613634
@@ -620,25 +641,7 @@ async def stream_input(self, stream: AsyncIterable[dict[str, Any]]) -> None:
620641
break
621642
await self.transport.write(json.dumps(message) + "\n")
622643

623-
# If we have SDK MCP servers or hooks that need bidirectional communication,
624-
# wait for first result before closing the channel
625-
has_hooks = bool(self.hooks)
626-
if self.sdk_mcp_servers or has_hooks:
627-
logger.debug(
628-
f"Waiting for first result before closing stdin "
629-
f"(sdk_mcp_servers={len(self.sdk_mcp_servers)}, has_hooks={has_hooks})"
630-
)
631-
try:
632-
with anyio.move_on_after(self._stream_close_timeout):
633-
await self._first_result_event.wait()
634-
logger.debug("Received first result, closing input stream")
635-
except Exception:
636-
logger.debug(
637-
"Timed out waiting for first result, closing input stream"
638-
)
639-
640-
# After all messages sent (and result received if needed), end input
641-
await self.transport.end_input()
644+
await self.wait_for_result_and_end_input()
642645
except Exception as e:
643646
logger.debug(f"Error streaming input: {e}")
644647

0 commit comments

Comments
 (0)