11import asyncio
22import inspect
33from collections .abc import AsyncGenerator
4- from dataclasses import dataclass
54from typing import TYPE_CHECKING
65
76from astrbot .core import astrbot_config , logger
5049STREAM_CONSUMPTION_CLOSE_TIMEOUT_SEC = 30
5150
5251
52+ def _coerce_positive_int (value : object , default : int ) -> int :
53+ if isinstance (value , bool ):
54+ return default
55+ try :
56+ coerced = int (value )
57+ except (TypeError , ValueError ):
58+ return default
59+ return coerced if coerced > 0 else default
60+
61+
5362def _set_runner_error_extra (event : "AstrMessageEvent" , is_error : bool ) -> None :
5463 event .set_extra (THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY , is_error )
5564
@@ -91,7 +100,7 @@ async def run_third_party_agent(
91100 runner : "BaseAgentRunner" ,
92101 stream_to_general : bool = False ,
93102 custom_error_message : str | None = None ,
94- ) -> AsyncGenerator ["_ThirdPartyRunnerOutput" , None ]:
103+ ) -> AsyncGenerator [tuple [ MessageChain , bool ] , None ]:
95104 """
96105 运行第三方 agent runner 并转换响应格式
97106 类似于 run_agent 函数,但专门处理第三方 agent runner
@@ -101,21 +110,12 @@ async def run_third_party_agent(
101110 if resp .type == "streaming_delta" :
102111 if stream_to_general :
103112 continue
104- yield _ThirdPartyRunnerOutput (
105- chain = resp .data ["chain" ],
106- is_error = False ,
107- )
113+ yield resp .data ["chain" ], False
108114 elif resp .type == "llm_result" :
109115 if stream_to_general :
110- yield _ThirdPartyRunnerOutput (
111- chain = resp .data ["chain" ],
112- is_error = False ,
113- )
116+ yield resp .data ["chain" ], False
114117 elif resp .type == "err" :
115- yield _ThirdPartyRunnerOutput (
116- chain = resp .data ["chain" ],
117- is_error = True ,
118- )
118+ yield resp .data ["chain" ], True
119119 except Exception as e :
120120 logger .error (f"Third party agent runner error: { e } " )
121121 err_msg = custom_error_message
@@ -125,16 +125,26 @@ async def run_third_party_agent(
125125 f"Error Type: { type (e ).__name__ } (3rd party)\n "
126126 f"Error Message: { str (e )} "
127127 )
128- yield _ThirdPartyRunnerOutput (
129- chain = MessageChain ().message (err_msg ),
130- is_error = True ,
131- )
128+ yield MessageChain ().message (err_msg ), True
132129
133130
134- @dataclass
135- class _ThirdPartyRunnerOutput :
136- chain : MessageChain
137- is_error : bool = False
131+ async def _consume_runner_and_aggregate (
132+ runner : "BaseAgentRunner" ,
133+ * ,
134+ stream_to_general : bool ,
135+ custom_error_message : str | None ,
136+ ) -> AsyncGenerator [tuple [MessageChain , bool , list , bool ], None ]:
137+ merged_chain : list = []
138+ has_intermediate_error = False
139+ async for chain , is_error in run_third_party_agent (
140+ runner ,
141+ stream_to_general = stream_to_general ,
142+ custom_error_message = custom_error_message ,
143+ ):
144+ merged_chain .extend (chain .chain or [])
145+ if is_error :
146+ has_intermediate_error = True
147+ yield chain , is_error , merged_chain , has_intermediate_error
138148
139149
140150async def _close_runner_if_supported (runner : "BaseAgentRunner" ) -> None :
@@ -151,8 +161,15 @@ async def _close_runner_if_supported(runner: "BaseAgentRunner") -> None:
151161
152162
153163class _RunnerLifecycle :
154- def __init__ (self , runner : "BaseAgentRunner" ) -> None :
164+ def __init__ (
165+ self ,
166+ runner : "BaseAgentRunner" ,
167+ stream_consumption_close_timeout_sec : int ,
168+ ) -> None :
155169 self ._runner = runner
170+ self ._stream_consumption_close_timeout_sec = (
171+ stream_consumption_close_timeout_sec
172+ )
156173 self ._closed = False
157174 self ._stream_started = False
158175 self ._stream_consumed = False
@@ -196,22 +213,34 @@ async def finalize(self) -> None:
196213 if (
197214 self ._idle_task
198215 and not self ._idle_task .done ()
199- and (not self . _stream_started or self ._stream_consumed or self ._closed )
216+ and (self ._stream_consumed or self ._closed )
200217 ):
201218 self ._idle_task .cancel ()
202219
203- if not self ._stream_started :
220+ defer_close_to_watchdog = (
221+ self ._stream_started
222+ and not self ._stream_consumed
223+ and self ._idle_task is not None
224+ and not self ._idle_task .done ()
225+ and not self ._closed
226+ )
227+ if defer_close_to_watchdog :
228+ return
229+
230+ if not self ._closed :
204231 await self .close_once ()
205232
206233 async def _close_if_never_consumed (self ) -> None :
207234 try :
208- await asyncio .sleep (STREAM_CONSUMPTION_CLOSE_TIMEOUT_SEC )
235+ await asyncio .sleep (self . _stream_consumption_close_timeout_sec )
209236 except asyncio .CancelledError :
210237 return
211238
212239 if not self ._stream_consumed :
213240 logger .warning (
214- "Third-party runner stream was never consumed; closing runner to avoid resource leak." ,
241+ "Third-party runner stream was never consumed in %ss; closing runner "
242+ "to avoid resource leak." ,
243+ self ._stream_consumption_close_timeout_sec ,
215244 )
216245 await self .close_once ()
217246
@@ -230,6 +259,13 @@ async def initialize(self, ctx: PipelineContext) -> None:
230259 self .unsupported_streaming_strategy : str = settings [
231260 "unsupported_streaming_strategy"
232261 ]
262+ self .stream_consumption_close_timeout_sec : int = _coerce_positive_int (
263+ settings .get (
264+ "third_party_stream_consumption_close_timeout_sec" ,
265+ STREAM_CONSUMPTION_CLOSE_TIMEOUT_SEC ,
266+ ),
267+ STREAM_CONSUMPTION_CLOSE_TIMEOUT_SEC ,
268+ )
233269
234270 async def _resolve_persona_custom_error_message (
235271 self , event : AstrMessageEvent
@@ -258,20 +294,25 @@ async def _handle_streaming_response(
258294 custom_error_message : str | None ,
259295 ) -> AsyncGenerator [None , None ]:
260296 stream_has_runner_error = False
297+ merged_chain : list = []
261298
262299 async def _stream_runner_chain () -> AsyncGenerator [MessageChain , None ]:
263- nonlocal stream_has_runner_error
300+ nonlocal merged_chain , stream_has_runner_error
264301 lifecycle .mark_stream_consumed ()
265302 try :
266- async for runner_output in run_third_party_agent (
303+ async for (
304+ chain ,
305+ is_error ,
306+ merged_chain ,
307+ stream_has_runner_error ,
308+ ) in _consume_runner_and_aggregate (
267309 runner ,
268310 stream_to_general = False ,
269311 custom_error_message = custom_error_message ,
270312 ):
271- if runner_output .is_error :
272- stream_has_runner_error = True
313+ if is_error :
273314 _set_runner_error_extra (event , True )
274- yield runner_output . chain
315+ yield chain
275316 finally :
276317 # Streaming runner cleanup must happen after consumer
277318 # finishes iterating to avoid tearing down active streams.
@@ -293,7 +334,7 @@ async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]:
293334 is_runner_error ,
294335 _ ,
295336 ) = _resolve_final_result (
296- merged_chain = [] ,
337+ merged_chain = merged_chain ,
297338 final_resp = final_resp ,
298339 has_intermediate_error = stream_has_runner_error ,
299340 )
@@ -315,14 +356,16 @@ async def _handle_non_streaming_response(
315356 ) -> AsyncGenerator [None , None ]:
316357 merged_chain : list = []
317358 has_intermediate_error = False
318- async for output in run_third_party_agent (
359+ async for (
360+ _ ,
361+ _ ,
362+ merged_chain ,
363+ has_intermediate_error ,
364+ ) in _consume_runner_and_aggregate (
319365 runner ,
320366 stream_to_general = stream_to_general ,
321367 custom_error_message = custom_error_message ,
322368 ):
323- merged_chain .extend (output .chain .chain or [])
324- if output .is_error :
325- has_intermediate_error = True
326369 yield
327370
328371 final_resp = runner .get_final_llm_resp ()
@@ -421,7 +464,10 @@ async def process(
421464 and not event .platform_meta .support_streaming_message
422465 )
423466
424- lifecycle = _RunnerLifecycle (runner )
467+ lifecycle = _RunnerLifecycle (
468+ runner ,
469+ stream_consumption_close_timeout_sec = self .stream_consumption_close_timeout_sec ,
470+ )
425471
426472 try :
427473 await lifecycle .reset (
0 commit comments