5959from loguru import logger
6060
6161MessageHandler = Callable [[str ], Awaitable [None ]]
62+ STDIO_SUBPROCESS_STREAM_LIMIT = 8 * 1024 * 1024
6263
6364
6465def _get_aiohttp ():
@@ -73,15 +74,19 @@ def _get_web():
7374 return web
7475
7576
76- def _frame_stdio_payload (payload : str ) -> str :
77- body = payload
78- if body .endswith ("\r \n " ):
79- body = body [:- 2 ]
80- elif body .endswith (("\n " , "\r " )):
81- body = body [:- 1 ]
82- if "\n " in body or "\r " in body :
83- raise ValueError ("STDIO payload 不允许包含原始换行符" )
84- return f"{ body } \n "
77+ def _frame_stdio_payload (payload : str ) -> bytes :
78+ body = payload .encode ("utf-8" )
79+ return f"{ len (body )} \n " .encode ("ascii" ) + body
80+
81+
82+ def _parse_stdio_header (raw_header : bytes ) -> int :
83+ header = raw_header .decode ("ascii" ).strip ()
84+ if not header :
85+ raise ValueError ("STDIO frame header is empty" )
86+ try :
87+ return int (header )
88+ except ValueError as exc :
89+ raise ValueError (f"Invalid STDIO frame header: { header !r} " ) from exc
8590
8691
# TODO: is there a better solution?
@@ -169,6 +174,7 @@ async def _start_subprocess_with_retry(self) -> asyncio.subprocess.Process:
169174 stdin = asyncio .subprocess .PIPE ,
170175 stdout = asyncio .subprocess .PIPE ,
171176 stderr = sys .stderr ,
177+ limit = STDIO_SUBPROCESS_STREAM_LIMIT ,
172178 )
173179 except Exception as exc :
174180 last_error = exc
@@ -205,11 +211,11 @@ async def stop(self) -> None:
205211 self ._closed .set ()
206212
207213 async def send (self , payload : str ) -> None :
208- line = _frame_stdio_payload (payload )
214+ frame = _frame_stdio_payload (payload )
209215 if self ._process is not None :
210216 if self ._process .stdin is None :
211217 raise RuntimeError ("STDIO subprocess stdin 不可用" )
212- self ._process .stdin .write (line . encode ( "utf-8" ) )
218+ self ._process .stdin .write (frame )
213219 await self ._process .stdin .drain ()
214220 return
215221
@@ -218,8 +224,11 @@ async def send(self, payload: str) -> None:
218224
219225 def _write () -> None :
220226 assert self ._stdout is not None
221- self ._stdout .write (line )
222- self ._stdout .flush ()
227+ binary_stdout = getattr (self ._stdout , "buffer" , None )
228+ if binary_stdout is None :
229+ raise RuntimeError ("STDIO stdout 必须提供可写入 bytes 的 buffer" )
230+ binary_stdout .write (frame )
231+ binary_stdout .flush ()
223232
224233 await asyncio .to_thread (_write )
225234
@@ -228,21 +237,30 @@ async def _read_process_loop(self) -> None:
228237 assert self ._process .stdout is not None
229238 try :
230239 while True :
231- raw = await self ._process .stdout .readline ()
232- if not raw :
240+ raw_header = await self ._process .stdout .readline ()
241+ if not raw_header :
233242 break
234- await self ._dispatch (raw .decode ("utf-8" ).rstrip ("\r \n " ))
243+ payload_size = _parse_stdio_header (raw_header )
244+ raw = await self ._process .stdout .readexactly (payload_size )
245+ await self ._dispatch (raw .decode ("utf-8" ))
235246 finally :
236247 self ._closed .set ()
237248
async def _read_file_loop(self) -> None:
    """Read length-prefixed frames from ``self._stdin`` and dispatch each one.

    Each frame is a header line giving the payload size in bytes, followed
    by exactly that many bytes of UTF-8 text. Blocking reads run in a worker
    thread via ``asyncio.to_thread``. The loop ends when the header read
    returns no data (end of stream); ``self._closed`` is set on every exit
    path, including errors.

    Raises:
        RuntimeError: If ``self._stdin`` exposes no binary ``buffer``.
        ValueError: If a frame header is malformed or announces a negative
            size.
        EOFError: If the stream ends before a full payload was read.
    """
    assert self._stdin is not None
    try:
        # Resolve the binary view once instead of on every iteration.
        binary_stdin = getattr(self._stdin, "buffer", None)
        if binary_stdin is None:
            raise RuntimeError("STDIO stdin 必须提供可读取 bytes 的 buffer")
        while True:
            raw_header = await asyncio.to_thread(binary_stdin.readline)
            if not raw_header:
                break
            payload_size = _parse_stdio_header(raw_header)
            if payload_size < 0:
                # read(-1) means "read until EOF", which would block on a
                # live pipe instead of failing fast.
                raise ValueError(f"Invalid STDIO frame size: {payload_size}")
            raw = await asyncio.to_thread(binary_stdin.read, payload_size)
            if len(raw) != payload_size:
                raise EOFError("STDIO frame truncated before payload completed")
            await self._dispatch(raw.decode("utf-8"))
    finally:
        self._closed.set()
248266
0 commit comments