@@ -227,6 +227,9 @@ async def _read_messages(self) -> None:
227227 # Put error in stream so iterators can handle it
228228 await self ._message_send .send ({"type" : "error" , "error" : str (e )})
229229 finally :
230+ # Unblock any waiters (e.g. string-prompt path waiting for first
231+ # result) so they don't stall for the full timeout on early exit.
232+ self ._first_result_event .set ()
230233 # Always signal end of stream
231234 await self ._message_send .send ({"type" : "end" })
232235
@@ -608,6 +611,24 @@ async def stop_task(self, task_id: str) -> None:
608611 }
609612 )
610613
614+ async def wait_for_result_and_end_input (self ) -> None :
615+ """Wait for the first result (if needed) then close stdin.
616+
617+ If SDK MCP servers or hooks require bidirectional communication,
618+ keeps stdin open until the first result arrives (or timeout).
619+ Otherwise closes stdin immediately.
620+ """
621+ if self .sdk_mcp_servers or self .hooks :
622+ logger .debug (
623+ "Waiting for first result before closing stdin "
624+ f"(sdk_mcp_servers={ len (self .sdk_mcp_servers )} , "
625+ f"has_hooks={ bool (self .hooks )} )"
626+ )
627+ with anyio .move_on_after (self ._stream_close_timeout ):
628+ await self ._first_result_event .wait ()
629+
630+ await self .transport .end_input ()
631+
611632 async def stream_input (self , stream : AsyncIterable [dict [str , Any ]]) -> None :
612633 """Stream input messages to transport.
613634
@@ -620,25 +641,7 @@ async def stream_input(self, stream: AsyncIterable[dict[str, Any]]) -> None:
620641 break
621642 await self .transport .write (json .dumps (message ) + "\n " )
622643
623- # If we have SDK MCP servers or hooks that need bidirectional communication,
624- # wait for first result before closing the channel
625- has_hooks = bool (self .hooks )
626- if self .sdk_mcp_servers or has_hooks :
627- logger .debug (
628- f"Waiting for first result before closing stdin "
629- f"(sdk_mcp_servers={ len (self .sdk_mcp_servers )} , has_hooks={ has_hooks } )"
630- )
631- try :
632- with anyio .move_on_after (self ._stream_close_timeout ):
633- await self ._first_result_event .wait ()
634- logger .debug ("Received first result, closing input stream" )
635- except Exception :
636- logger .debug (
637- "Timed out waiting for first result, closing input stream"
638- )
639-
640- # After all messages sent (and result received if needed), end input
641- await self .transport .end_input ()
644+ await self .wait_for_result_and_end_input ()
642645 except Exception as e :
643646 logger .debug (f"Error streaming input: { e } " )
644647
0 commit comments