@@ -150,6 +150,72 @@ async def _close_runner_if_supported(runner: "BaseAgentRunner") -> None:
150150 logger .warning (f"Failed to close third-party runner cleanly: { e } " )
151151
152152
153+ class _RunnerLifecycle :
154+ def __init__ (self , runner : "BaseAgentRunner" ) -> None :
155+ self ._runner = runner
156+ self ._closed = False
157+ self ._stream_started = False
158+ self ._stream_consumed = False
159+ self ._idle_task : asyncio .Task [None ] | None = None
160+
161+ async def reset (
162+ self ,
163+ * ,
164+ req : ProviderRequest ,
165+ astr_agent_ctx : AstrAgentContext ,
166+ provider_cfg : dict ,
167+ streaming : bool ,
168+ ) -> None :
169+ await self ._runner .reset (
170+ request = req ,
171+ run_context = AgentContextWrapper (
172+ context = astr_agent_ctx ,
173+ tool_call_timeout = 60 ,
174+ ),
175+ agent_hooks = MAIN_AGENT_HOOKS ,
176+ provider_config = provider_cfg ,
177+ streaming = streaming ,
178+ )
179+
180+ async def close_once (self ) -> None :
181+ if self ._closed :
182+ return
183+ self ._closed = True
184+ await _close_runner_if_supported (self ._runner )
185+
186+ def mark_stream_started (self ) -> None :
187+ self ._stream_started = True
188+ self ._idle_task = asyncio .create_task (self ._close_if_never_consumed ())
189+
190+ def mark_stream_consumed (self ) -> None :
191+ self ._stream_consumed = True
192+ if self ._idle_task and not self ._idle_task .done ():
193+ self ._idle_task .cancel ()
194+
195+ async def finalize (self ) -> None :
196+ if (
197+ self ._idle_task
198+ and not self ._idle_task .done ()
199+ and (not self ._stream_started or self ._stream_consumed or self ._closed )
200+ ):
201+ self ._idle_task .cancel ()
202+
203+ if not self ._stream_started :
204+ await self .close_once ()
205+
206+ async def _close_if_never_consumed (self ) -> None :
207+ try :
208+ await asyncio .sleep (STREAM_CONSUMPTION_CLOSE_TIMEOUT_SEC )
209+ except asyncio .CancelledError :
210+ return
211+
212+ if not self ._stream_consumed :
213+ logger .warning (
214+ "Third-party runner stream was never consumed; closing runner to avoid resource leak." ,
215+ )
216+ await self .close_once ()
217+
218+
153219class ThirdPartyAgentSubStage (Stage ):
154220 async def initialize (self , ctx : PipelineContext ) -> None :
155221 self .ctx = ctx
@@ -183,6 +249,109 @@ async def _resolve_persona_custom_error_message(
183249 logger .debug ("Failed to resolve persona custom error message: %s" , e )
184250 return None
185251
252+ async def _handle_streaming_response (
253+ self ,
254+ * ,
255+ lifecycle : _RunnerLifecycle ,
256+ runner : "BaseAgentRunner" ,
257+ event : AstrMessageEvent ,
258+ custom_error_message : str | None ,
259+ ) -> AsyncGenerator [None , None ]:
260+ stream_has_runner_error = False
261+
262+ async def _stream_runner_chain () -> AsyncGenerator [MessageChain , None ]:
263+ nonlocal stream_has_runner_error
264+ lifecycle .mark_stream_consumed ()
265+ try :
266+ async for runner_output in run_third_party_agent (
267+ runner ,
268+ stream_to_general = False ,
269+ custom_error_message = custom_error_message ,
270+ ):
271+ if runner_output .is_error :
272+ stream_has_runner_error = True
273+ _set_runner_error_extra (event , True )
274+ yield runner_output .chain
275+ finally :
276+ # Streaming runner cleanup must happen after consumer
277+ # finishes iterating to avoid tearing down active streams.
278+ await lifecycle .close_once ()
279+
280+ event .set_result (
281+ MessageEventResult ()
282+ .set_result_content_type (ResultContentType .STREAMING_RESULT )
283+ .set_async_stream (_stream_runner_chain ()),
284+ )
285+ lifecycle .mark_stream_started ()
286+ yield
287+
288+ if runner .done ():
289+ final_resp = runner .get_final_llm_resp ()
290+ if final_resp and final_resp .result_chain :
291+ (
292+ final_chain ,
293+ is_runner_error ,
294+ _ ,
295+ ) = _resolve_final_result (
296+ merged_chain = [],
297+ final_resp = final_resp ,
298+ has_intermediate_error = stream_has_runner_error ,
299+ )
300+ _set_runner_error_extra (event , is_runner_error )
301+ event .set_result (
302+ MessageEventResult (
303+ chain = final_chain ,
304+ result_content_type = ResultContentType .STREAMING_FINISH ,
305+ ),
306+ )
307+
308+ async def _handle_non_streaming_response (
309+ self ,
310+ * ,
311+ runner : "BaseAgentRunner" ,
312+ event : AstrMessageEvent ,
313+ stream_to_general : bool ,
314+ custom_error_message : str | None ,
315+ ) -> AsyncGenerator [None , None ]:
316+ merged_chain : list = []
317+ has_intermediate_error = False
318+ async for output in run_third_party_agent (
319+ runner ,
320+ stream_to_general = stream_to_general ,
321+ custom_error_message = custom_error_message ,
322+ ):
323+ merged_chain .extend (output .chain .chain or [])
324+ if output .is_error :
325+ has_intermediate_error = True
326+ yield
327+
328+ final_resp = runner .get_final_llm_resp ()
329+ if not final_resp or not final_resp .result_chain :
330+ if merged_chain :
331+ logger .warning (
332+ "Agent Runner returned no final response, fallback to streamed error/result chain."
333+ )
334+ else :
335+ logger .warning ("Agent Runner 未返回最终结果。" )
336+
337+ (
338+ final_chain ,
339+ is_runner_error ,
340+ result_content_type ,
341+ ) = _resolve_final_result (
342+ merged_chain = merged_chain ,
343+ final_resp = final_resp ,
344+ has_intermediate_error = has_intermediate_error ,
345+ )
346+ _set_runner_error_extra (event , is_runner_error )
347+ event .set_result (
348+ MessageEventResult (
349+ chain = final_chain ,
350+ result_content_type = result_content_type ,
351+ ),
352+ )
353+ yield
354+
186355 async def process (
187356 self , event : AstrMessageEvent , provider_wake_prefix : str
188357 ) -> AsyncGenerator [None , None ]:
@@ -252,145 +421,33 @@ async def process(
252421 and not event .platform_meta .support_streaming_message
253422 )
254423
255- runner_closed = False
256- streaming_started = False
257- stream_consumption_started = False
258- stream_idle_close_task : asyncio .Task [None ] | None = None
259-
260- async def close_runner_once () -> None :
261- nonlocal runner_closed
262- if runner_closed :
263- return
264- runner_closed = True
265- await _close_runner_if_supported (runner )
266-
267- async def close_if_stream_never_consumed () -> None :
268- try :
269- await asyncio .sleep (STREAM_CONSUMPTION_CLOSE_TIMEOUT_SEC )
270- except asyncio .CancelledError :
271- return
272- if not stream_consumption_started :
273- logger .warning (
274- "Third-party runner stream was never consumed; closing runner to avoid resource leak." ,
275- )
276- await close_runner_once ()
424+ lifecycle = _RunnerLifecycle (runner )
277425
278426 try :
279- await runner .reset (
280- request = req ,
281- run_context = AgentContextWrapper (
282- context = astr_agent_ctx ,
283- tool_call_timeout = 60 ,
284- ),
285- agent_hooks = MAIN_AGENT_HOOKS ,
286- provider_config = self .prov_cfg ,
427+ await lifecycle .reset (
428+ req = req ,
429+ astr_agent_ctx = astr_agent_ctx ,
430+ provider_cfg = self .prov_cfg ,
287431 streaming = streaming_response ,
288432 )
289-
290433 if streaming_response and not stream_to_general :
291- stream_has_runner_error = False
292-
293- async def _stream_runner_chain () -> AsyncGenerator [MessageChain , None ]:
294- nonlocal stream_has_runner_error , stream_consumption_started
295- stream_consumption_started = True
296- if stream_idle_close_task and not stream_idle_close_task .done ():
297- stream_idle_close_task .cancel ()
298- try :
299- async for runner_output in run_third_party_agent (
300- runner ,
301- stream_to_general = False ,
302- custom_error_message = custom_error_message ,
303- ):
304- if runner_output .is_error :
305- stream_has_runner_error = True
306- _set_runner_error_extra (event , True )
307- yield runner_output .chain
308- finally :
309- # Streaming runner cleanup must happen after consumer
310- # finishes iterating to avoid tearing down active streams.
311- await close_runner_once ()
312-
313- event .set_result (
314- MessageEventResult ()
315- .set_result_content_type (ResultContentType .STREAMING_RESULT )
316- .set_async_stream (_stream_runner_chain ()),
317- )
318- stream_idle_close_task = asyncio .create_task (
319- close_if_stream_never_consumed (),
320- )
321- streaming_started = True
322- yield
323-
324- if runner .done ():
325- final_resp = runner .get_final_llm_resp ()
326- if final_resp and final_resp .result_chain :
327- (
328- final_chain ,
329- is_runner_error ,
330- _ ,
331- ) = _resolve_final_result (
332- merged_chain = [],
333- final_resp = final_resp ,
334- has_intermediate_error = stream_has_runner_error ,
335- )
336- _set_runner_error_extra (event , is_runner_error )
337- event .set_result (
338- MessageEventResult (
339- chain = final_chain ,
340- result_content_type = ResultContentType .STREAMING_FINISH ,
341- ),
342- )
434+ async for _ in self ._handle_streaming_response (
435+ lifecycle = lifecycle ,
436+ runner = runner ,
437+ event = event ,
438+ custom_error_message = custom_error_message ,
439+ ):
440+ yield
343441 else :
344- output_stream = run_third_party_agent (
345- runner ,
442+ async for _ in self ._handle_non_streaming_response (
443+ runner = runner ,
444+ event = event ,
346445 stream_to_general = stream_to_general ,
347446 custom_error_message = custom_error_message ,
348- )
349- merged_chain : list = []
350- has_intermediate_error = False
351- async for output in output_stream :
352- merged_chain .extend (output .chain .chain or [])
353- if output .is_error :
354- has_intermediate_error = True
447+ ):
355448 yield
356-
357- final_resp = runner .get_final_llm_resp ()
358- if not final_resp or not final_resp .result_chain :
359- if merged_chain :
360- logger .warning (
361- "Agent Runner returned no final response, fallback to streamed error/result chain."
362- )
363- else :
364- logger .warning ("Agent Runner 未返回最终结果。" )
365-
366- (
367- final_chain ,
368- is_runner_error ,
369- result_content_type ,
370- ) = _resolve_final_result (
371- merged_chain = merged_chain ,
372- final_resp = final_resp ,
373- has_intermediate_error = has_intermediate_error ,
374- )
375- _set_runner_error_extra (event , is_runner_error )
376- event .set_result (
377- MessageEventResult (
378- chain = final_chain ,
379- result_content_type = result_content_type ,
380- ),
381- )
382- yield
383449 finally :
384- if (
385- stream_idle_close_task
386- and not stream_idle_close_task .done ()
387- and (
388- not streaming_started or stream_consumption_started or runner_closed
389- )
390- ):
391- stream_idle_close_task .cancel ()
392- if not streaming_started :
393- await close_runner_once ()
450+ await lifecycle .finalize ()
394451
395452 asyncio .create_task (
396453 Metric .upload (
0 commit comments