11import asyncio
22import inspect
3- import typing as T
43from collections .abc import AsyncGenerator
5- from contextlib import asynccontextmanager
64from dataclasses import dataclass
75from typing import TYPE_CHECKING
86
3028
3129if TYPE_CHECKING :
3230 from astrbot .core .agent .runners .base import BaseAgentRunner
33- from astrbot .core .provider .entities import LLMResponse
3431from astrbot .core .pipeline .stage import Stage
3532from astrbot .core .platform .astr_message_event import AstrMessageEvent
3633from astrbot .core .provider .entities import (
@@ -55,37 +52,6 @@ def _set_runner_error_extra(event: "AstrMessageEvent", is_error: bool) -> None:
5552 event .set_extra (THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY , is_error )
5653
5754
58- def _runner_result_content_type (is_error : bool ) -> ResultContentType :
59- return (
60- ResultContentType .AGENT_RUNNER_ERROR
61- if is_error
62- else ResultContentType .LLM_RESULT
63- )
64-
65-
66- def _set_non_stream_runner_result (
67- event : "AstrMessageEvent" ,
68- chain : list ,
69- is_error : bool ,
70- ) -> None :
71- _set_runner_error_extra (event , is_error )
72- event .set_result (
73- MessageEventResult (
74- chain = chain ,
75- result_content_type = _runner_result_content_type (is_error ),
76- ),
77- )
78-
79-
80- def _aggregate_runner_error (
81- has_intermediate_error : bool ,
82- final_resp : "LLMResponse | None" ,
83- ) -> bool :
84- if not final_resp :
85- return has_intermediate_error
86- return has_intermediate_error or final_resp .role == "err"
87-
88-
8955async def run_third_party_agent (
9056 runner : "BaseAgentRunner" ,
9157 stream_to_general : bool = False ,
@@ -149,44 +115,6 @@ async def _close_runner_if_supported(runner: "BaseAgentRunner") -> None:
149115 logger .warning (f"Failed to close third-party runner cleanly: { e } " )
150116
151117
152- @asynccontextmanager
153- async def _runner_session (
154- runner : "BaseAgentRunner" ,
155- * ,
156- request : ProviderRequest ,
157- run_context : AgentContextWrapper ,
158- agent_hooks : T .Any ,
159- provider_config : dict ,
160- streaming : bool ,
161- ):
162- runner_closed = False
163- defer_close = False
164-
165- async def close_runner_once () -> None :
166- nonlocal runner_closed
167- if runner_closed :
168- return
169- runner_closed = True
170- await _close_runner_if_supported (runner )
171-
172- def defer_runner_close () -> None :
173- nonlocal defer_close
174- defer_close = True
175-
176- await runner .reset (
177- request = request ,
178- run_context = run_context ,
179- agent_hooks = agent_hooks ,
180- provider_config = provider_config ,
181- streaming = streaming ,
182- )
183- try :
184- yield close_runner_once , defer_runner_close
185- finally :
186- if not defer_close :
187- await close_runner_once ()
188-
189-
190118class ThirdPartyAgentSubStage (Stage ):
191119 async def initialize (self , ctx : PipelineContext ) -> None :
192120 self .ctx = ctx
@@ -220,101 +148,6 @@ async def _resolve_persona_custom_error_message(
220148 logger .debug ("Failed to resolve persona custom error message: %s" , e )
221149 return None
222150
223- async def _handle_streaming_runner (
224- self ,
225- runner : "BaseAgentRunner" ,
226- event : AstrMessageEvent ,
227- custom_error_message : str | None ,
228- close_runner_once : T .Callable [[], T .Awaitable [None ]],
229- ) -> AsyncGenerator [None , None ]:
230- stream_has_runner_error = False
231-
232- async def _stream_runner_chain () -> AsyncGenerator [MessageChain , None ]:
233- nonlocal stream_has_runner_error
234- try :
235- async for runner_output in run_third_party_agent (
236- runner ,
237- stream_to_general = False ,
238- custom_error_message = custom_error_message ,
239- ):
240- if runner_output .is_error :
241- stream_has_runner_error = True
242- _set_runner_error_extra (event , True )
243- yield runner_output .chain
244- finally :
245- # Streaming runner cleanup must happen after consumer
246- # finishes iterating to avoid tearing down active streams.
247- await close_runner_once ()
248-
249- event .set_result (
250- MessageEventResult ()
251- .set_result_content_type (ResultContentType .STREAMING_RESULT )
252- .set_async_stream (_stream_runner_chain ()),
253- )
254- yield
255-
256- if runner .done ():
257- final_resp = runner .get_final_llm_resp ()
258- if final_resp and final_resp .result_chain :
259- is_runner_error = _aggregate_runner_error (
260- has_intermediate_error = stream_has_runner_error ,
261- final_resp = final_resp ,
262- )
263- _set_runner_error_extra (event , is_runner_error )
264- event .set_result (
265- MessageEventResult (
266- chain = final_resp .result_chain .chain or [],
267- result_content_type = ResultContentType .STREAMING_FINISH ,
268- ),
269- )
270-
271- async def _handle_non_streaming_runner (
272- self ,
273- runner : "BaseAgentRunner" ,
274- event : AstrMessageEvent ,
275- stream_to_general : bool ,
276- custom_error_message : str | None ,
277- ) -> AsyncGenerator [None , None ]:
278- merged_chain : list = []
279- has_intermediate_error = False
280- async for output in run_third_party_agent (
281- runner ,
282- stream_to_general = stream_to_general ,
283- custom_error_message = custom_error_message ,
284- ):
285- merged_chain .extend (output .chain .chain or [])
286- if output .is_error :
287- has_intermediate_error = True
288- yield
289-
290- final_resp = runner .get_final_llm_resp ()
291-
292- if not final_resp or not final_resp .result_chain :
293- if merged_chain :
294- logger .warning (
295- "Agent Runner returned no final response, fallback to streamed error/result chain."
296- )
297- _set_non_stream_runner_result (
298- event = event ,
299- chain = merged_chain ,
300- is_error = has_intermediate_error ,
301- )
302- yield
303- return
304- logger .warning ("Agent Runner 未返回最终结果。" )
305- return
306-
307- is_runner_error = _aggregate_runner_error (
308- has_intermediate_error = has_intermediate_error ,
309- final_resp = final_resp ,
310- )
311- _set_non_stream_runner_result (
312- event = event ,
313- chain = final_resp .result_chain .chain or [],
314- is_error = is_runner_error ,
315- )
316- yield
317-
318151 async def process (
319152 self , event : AstrMessageEvent , provider_wake_prefix : str
320153 ) -> AsyncGenerator [None , None ]:
@@ -384,37 +217,129 @@ async def process(
384217 and not event .platform_meta .support_streaming_message
385218 )
386219
387- async with _runner_session (
388- runner = runner ,
389- request = req ,
390- run_context = AgentContextWrapper (
391- context = astr_agent_ctx ,
392- tool_call_timeout = 60 ,
393- ),
394- agent_hooks = MAIN_AGENT_HOOKS ,
395- provider_config = self .prov_cfg ,
396- streaming = streaming_response ,
397- ) as (close_runner_once , defer_runner_close ):
220+ runner_closed = False
221+ streaming_started = False
222+
223+ async def close_runner_once () -> None :
224+ nonlocal runner_closed
225+ if runner_closed :
226+ return
227+ runner_closed = True
228+ await _close_runner_if_supported (runner )
229+
230+ try :
231+ await runner .reset (
232+ request = req ,
233+ run_context = AgentContextWrapper (
234+ context = astr_agent_ctx ,
235+ tool_call_timeout = 60 ,
236+ ),
237+ agent_hooks = MAIN_AGENT_HOOKS ,
238+ provider_config = self .prov_cfg ,
239+ streaming = streaming_response ,
240+ )
241+
398242 if streaming_response and not stream_to_general :
399- stream_started = False
400- async for _ in self ._handle_streaming_runner (
401- runner = runner ,
402- event = event ,
403- custom_error_message = custom_error_message ,
404- close_runner_once = close_runner_once ,
405- ):
406- if not stream_started :
407- defer_runner_close ()
408- stream_started = True
409- yield
243+ stream_has_runner_error = False
244+
245+ async def _stream_runner_chain () -> AsyncGenerator [MessageChain , None ]:
246+ nonlocal stream_has_runner_error
247+ try :
248+ async for runner_output in run_third_party_agent (
249+ runner ,
250+ stream_to_general = False ,
251+ custom_error_message = custom_error_message ,
252+ ):
253+ if runner_output .is_error :
254+ stream_has_runner_error = True
255+ _set_runner_error_extra (event , True )
256+ yield runner_output .chain
257+ finally :
258+ # Streaming runner cleanup must happen after consumer
259+ # finishes iterating to avoid tearing down active streams.
260+ await close_runner_once ()
261+
262+ event .set_result (
263+ MessageEventResult ()
264+ .set_result_content_type (ResultContentType .STREAMING_RESULT )
265+ .set_async_stream (_stream_runner_chain ()),
266+ )
267+ streaming_started = True
268+ yield
269+
270+ if runner .done ():
271+ final_resp = runner .get_final_llm_resp ()
272+ if final_resp and final_resp .result_chain :
273+ is_runner_error = (
274+ stream_has_runner_error or final_resp .role == "err"
275+ )
276+ _set_runner_error_extra (event , is_runner_error )
277+ event .set_result (
278+ MessageEventResult (
279+ chain = final_resp .result_chain .chain or [],
280+ result_content_type = ResultContentType .STREAMING_FINISH ,
281+ ),
282+ )
410283 else :
411- async for _ in self ._handle_non_streaming_runner (
412- runner = runner ,
413- event = event ,
284+ merged_chain : list = []
285+ has_intermediate_error = False
286+ async for output in run_third_party_agent (
287+ runner ,
414288 stream_to_general = stream_to_general ,
415289 custom_error_message = custom_error_message ,
416290 ):
291+ merged_chain .extend (output .chain .chain or [])
292+ if output .is_error :
293+ has_intermediate_error = True
294+ yield
295+
296+ final_resp = runner .get_final_llm_resp ()
297+ if not final_resp or not final_resp .result_chain :
298+ if merged_chain :
299+ logger .warning (
300+ "Agent Runner returned no final response, fallback to streamed error/result chain."
301+ )
302+ _set_runner_error_extra (event , has_intermediate_error )
303+ event .set_result (
304+ MessageEventResult (
305+ chain = merged_chain ,
306+ result_content_type = (
307+ ResultContentType .AGENT_RUNNER_ERROR
308+ if has_intermediate_error
309+ else ResultContentType .LLM_RESULT
310+ ),
311+ ),
312+ )
313+ else :
314+ logger .warning ("Agent Runner 未返回最终结果。" )
315+ fallback_error_chain = MessageChain ().message (
316+ "Agent Runner did not return any result." ,
317+ )
318+ _set_runner_error_extra (event , True )
319+ event .set_result (
320+ MessageEventResult (
321+ chain = fallback_error_chain .chain or [],
322+ result_content_type = ResultContentType .AGENT_RUNNER_ERROR ,
323+ ),
324+ )
417325 yield
326+ else :
327+ is_runner_error = has_intermediate_error or final_resp .role == "err"
328+ _set_runner_error_extra (event , is_runner_error )
329+ event .set_result (
330+ MessageEventResult (
331+ chain = final_resp .result_chain .chain or [],
332+ result_content_type = (
333+ ResultContentType .AGENT_RUNNER_ERROR
334+ if is_runner_error
335+ else ResultContentType .LLM_RESULT
336+ ),
337+ ),
338+ )
339+ yield
340+ finally :
341+ if not streaming_started :
342+ await close_runner_once ()
418343
419344 asyncio .create_task (
420345 Metric .upload (
0 commit comments