Skip to content

Commit f0dc39a

Browse files
committed
fix: preserve deerflow done hook and simplify runner lifecycle
1 parent b4f7262 commit f0dc39a

2 files changed

Lines changed: 196 additions & 134 deletions

File tree

astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,14 @@ async def close(self) -> None:
146146
if isinstance(api_client, DeerFlowAPIClient) and not api_client.is_closed:
147147
await api_client.close()
148148

149+
async def _notify_agent_done_hook(self) -> None:
150+
if not self.final_llm_resp:
151+
return
152+
try:
153+
await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp)
154+
except Exception as e:
155+
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
156+
149157
def _parse_runner_config(self, provider_config: dict) -> _RunnerConfig:
150158
api_base = provider_config.get("deerflow_api_base", "http://127.0.0.1:2026")
151159
if not isinstance(api_base, str) or not api_base.startswith(
@@ -295,6 +303,7 @@ async def step(self):
295303
completion_text=f"DeerFlow request failed: {err_msg}",
296304
result_chain=err_chain,
297305
)
306+
await self._notify_agent_done_hook()
298307
yield AgentResponse(
299308
type="err",
300309
data=AgentResponseData(
@@ -778,11 +787,7 @@ async def _execute_deerflow_request(self):
778787

779788
self.final_llm_resp = LLMResponse(role=role, result_chain=final_chain)
780789
self._transition_state(AgentState.DONE)
781-
782-
try:
783-
await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp)
784-
except Exception as e:
785-
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
790+
await self._notify_agent_done_hook()
786791

787792
yield AgentResponse(
788793
type="llm_result",

astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py

Lines changed: 186 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,72 @@ async def _close_runner_if_supported(runner: "BaseAgentRunner") -> None:
150150
logger.warning(f"Failed to close third-party runner cleanly: {e}")
151151

152152

153+
class _RunnerLifecycle:
154+
def __init__(self, runner: "BaseAgentRunner") -> None:
155+
self._runner = runner
156+
self._closed = False
157+
self._stream_started = False
158+
self._stream_consumed = False
159+
self._idle_task: asyncio.Task[None] | None = None
160+
161+
async def reset(
162+
self,
163+
*,
164+
req: ProviderRequest,
165+
astr_agent_ctx: AstrAgentContext,
166+
provider_cfg: dict,
167+
streaming: bool,
168+
) -> None:
169+
await self._runner.reset(
170+
request=req,
171+
run_context=AgentContextWrapper(
172+
context=astr_agent_ctx,
173+
tool_call_timeout=60,
174+
),
175+
agent_hooks=MAIN_AGENT_HOOKS,
176+
provider_config=provider_cfg,
177+
streaming=streaming,
178+
)
179+
180+
async def close_once(self) -> None:
181+
if self._closed:
182+
return
183+
self._closed = True
184+
await _close_runner_if_supported(self._runner)
185+
186+
def mark_stream_started(self) -> None:
187+
self._stream_started = True
188+
self._idle_task = asyncio.create_task(self._close_if_never_consumed())
189+
190+
def mark_stream_consumed(self) -> None:
191+
self._stream_consumed = True
192+
if self._idle_task and not self._idle_task.done():
193+
self._idle_task.cancel()
194+
195+
async def finalize(self) -> None:
196+
if (
197+
self._idle_task
198+
and not self._idle_task.done()
199+
and (not self._stream_started or self._stream_consumed or self._closed)
200+
):
201+
self._idle_task.cancel()
202+
203+
if not self._stream_started:
204+
await self.close_once()
205+
206+
async def _close_if_never_consumed(self) -> None:
207+
try:
208+
await asyncio.sleep(STREAM_CONSUMPTION_CLOSE_TIMEOUT_SEC)
209+
except asyncio.CancelledError:
210+
return
211+
212+
if not self._stream_consumed:
213+
logger.warning(
214+
"Third-party runner stream was never consumed; closing runner to avoid resource leak.",
215+
)
216+
await self.close_once()
217+
218+
153219
class ThirdPartyAgentSubStage(Stage):
154220
async def initialize(self, ctx: PipelineContext) -> None:
155221
self.ctx = ctx
@@ -183,6 +249,109 @@ async def _resolve_persona_custom_error_message(
183249
logger.debug("Failed to resolve persona custom error message: %s", e)
184250
return None
185251

252+
async def _handle_streaming_response(
253+
self,
254+
*,
255+
lifecycle: _RunnerLifecycle,
256+
runner: "BaseAgentRunner",
257+
event: AstrMessageEvent,
258+
custom_error_message: str | None,
259+
) -> AsyncGenerator[None, None]:
260+
stream_has_runner_error = False
261+
262+
async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]:
263+
nonlocal stream_has_runner_error
264+
lifecycle.mark_stream_consumed()
265+
try:
266+
async for runner_output in run_third_party_agent(
267+
runner,
268+
stream_to_general=False,
269+
custom_error_message=custom_error_message,
270+
):
271+
if runner_output.is_error:
272+
stream_has_runner_error = True
273+
_set_runner_error_extra(event, True)
274+
yield runner_output.chain
275+
finally:
276+
# Streaming runner cleanup must happen after consumer
277+
# finishes iterating to avoid tearing down active streams.
278+
await lifecycle.close_once()
279+
280+
event.set_result(
281+
MessageEventResult()
282+
.set_result_content_type(ResultContentType.STREAMING_RESULT)
283+
.set_async_stream(_stream_runner_chain()),
284+
)
285+
lifecycle.mark_stream_started()
286+
yield
287+
288+
if runner.done():
289+
final_resp = runner.get_final_llm_resp()
290+
if final_resp and final_resp.result_chain:
291+
(
292+
final_chain,
293+
is_runner_error,
294+
_,
295+
) = _resolve_final_result(
296+
merged_chain=[],
297+
final_resp=final_resp,
298+
has_intermediate_error=stream_has_runner_error,
299+
)
300+
_set_runner_error_extra(event, is_runner_error)
301+
event.set_result(
302+
MessageEventResult(
303+
chain=final_chain,
304+
result_content_type=ResultContentType.STREAMING_FINISH,
305+
),
306+
)
307+
308+
async def _handle_non_streaming_response(
309+
self,
310+
*,
311+
runner: "BaseAgentRunner",
312+
event: AstrMessageEvent,
313+
stream_to_general: bool,
314+
custom_error_message: str | None,
315+
) -> AsyncGenerator[None, None]:
316+
merged_chain: list = []
317+
has_intermediate_error = False
318+
async for output in run_third_party_agent(
319+
runner,
320+
stream_to_general=stream_to_general,
321+
custom_error_message=custom_error_message,
322+
):
323+
merged_chain.extend(output.chain.chain or [])
324+
if output.is_error:
325+
has_intermediate_error = True
326+
yield
327+
328+
final_resp = runner.get_final_llm_resp()
329+
if not final_resp or not final_resp.result_chain:
330+
if merged_chain:
331+
logger.warning(
332+
"Agent Runner returned no final response, fallback to streamed error/result chain."
333+
)
334+
else:
335+
logger.warning("Agent Runner 未返回最终结果。")
336+
337+
(
338+
final_chain,
339+
is_runner_error,
340+
result_content_type,
341+
) = _resolve_final_result(
342+
merged_chain=merged_chain,
343+
final_resp=final_resp,
344+
has_intermediate_error=has_intermediate_error,
345+
)
346+
_set_runner_error_extra(event, is_runner_error)
347+
event.set_result(
348+
MessageEventResult(
349+
chain=final_chain,
350+
result_content_type=result_content_type,
351+
),
352+
)
353+
yield
354+
186355
async def process(
187356
self, event: AstrMessageEvent, provider_wake_prefix: str
188357
) -> AsyncGenerator[None, None]:
@@ -252,145 +421,33 @@ async def process(
252421
and not event.platform_meta.support_streaming_message
253422
)
254423

255-
runner_closed = False
256-
streaming_started = False
257-
stream_consumption_started = False
258-
stream_idle_close_task: asyncio.Task[None] | None = None
259-
260-
async def close_runner_once() -> None:
261-
nonlocal runner_closed
262-
if runner_closed:
263-
return
264-
runner_closed = True
265-
await _close_runner_if_supported(runner)
266-
267-
async def close_if_stream_never_consumed() -> None:
268-
try:
269-
await asyncio.sleep(STREAM_CONSUMPTION_CLOSE_TIMEOUT_SEC)
270-
except asyncio.CancelledError:
271-
return
272-
if not stream_consumption_started:
273-
logger.warning(
274-
"Third-party runner stream was never consumed; closing runner to avoid resource leak.",
275-
)
276-
await close_runner_once()
424+
lifecycle = _RunnerLifecycle(runner)
277425

278426
try:
279-
await runner.reset(
280-
request=req,
281-
run_context=AgentContextWrapper(
282-
context=astr_agent_ctx,
283-
tool_call_timeout=60,
284-
),
285-
agent_hooks=MAIN_AGENT_HOOKS,
286-
provider_config=self.prov_cfg,
427+
await lifecycle.reset(
428+
req=req,
429+
astr_agent_ctx=astr_agent_ctx,
430+
provider_cfg=self.prov_cfg,
287431
streaming=streaming_response,
288432
)
289-
290433
if streaming_response and not stream_to_general:
291-
stream_has_runner_error = False
292-
293-
async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]:
294-
nonlocal stream_has_runner_error, stream_consumption_started
295-
stream_consumption_started = True
296-
if stream_idle_close_task and not stream_idle_close_task.done():
297-
stream_idle_close_task.cancel()
298-
try:
299-
async for runner_output in run_third_party_agent(
300-
runner,
301-
stream_to_general=False,
302-
custom_error_message=custom_error_message,
303-
):
304-
if runner_output.is_error:
305-
stream_has_runner_error = True
306-
_set_runner_error_extra(event, True)
307-
yield runner_output.chain
308-
finally:
309-
# Streaming runner cleanup must happen after consumer
310-
# finishes iterating to avoid tearing down active streams.
311-
await close_runner_once()
312-
313-
event.set_result(
314-
MessageEventResult()
315-
.set_result_content_type(ResultContentType.STREAMING_RESULT)
316-
.set_async_stream(_stream_runner_chain()),
317-
)
318-
stream_idle_close_task = asyncio.create_task(
319-
close_if_stream_never_consumed(),
320-
)
321-
streaming_started = True
322-
yield
323-
324-
if runner.done():
325-
final_resp = runner.get_final_llm_resp()
326-
if final_resp and final_resp.result_chain:
327-
(
328-
final_chain,
329-
is_runner_error,
330-
_,
331-
) = _resolve_final_result(
332-
merged_chain=[],
333-
final_resp=final_resp,
334-
has_intermediate_error=stream_has_runner_error,
335-
)
336-
_set_runner_error_extra(event, is_runner_error)
337-
event.set_result(
338-
MessageEventResult(
339-
chain=final_chain,
340-
result_content_type=ResultContentType.STREAMING_FINISH,
341-
),
342-
)
434+
async for _ in self._handle_streaming_response(
435+
lifecycle=lifecycle,
436+
runner=runner,
437+
event=event,
438+
custom_error_message=custom_error_message,
439+
):
440+
yield
343441
else:
344-
output_stream = run_third_party_agent(
345-
runner,
442+
async for _ in self._handle_non_streaming_response(
443+
runner=runner,
444+
event=event,
346445
stream_to_general=stream_to_general,
347446
custom_error_message=custom_error_message,
348-
)
349-
merged_chain: list = []
350-
has_intermediate_error = False
351-
async for output in output_stream:
352-
merged_chain.extend(output.chain.chain or [])
353-
if output.is_error:
354-
has_intermediate_error = True
447+
):
355448
yield
356-
357-
final_resp = runner.get_final_llm_resp()
358-
if not final_resp or not final_resp.result_chain:
359-
if merged_chain:
360-
logger.warning(
361-
"Agent Runner returned no final response, fallback to streamed error/result chain."
362-
)
363-
else:
364-
logger.warning("Agent Runner 未返回最终结果。")
365-
366-
(
367-
final_chain,
368-
is_runner_error,
369-
result_content_type,
370-
) = _resolve_final_result(
371-
merged_chain=merged_chain,
372-
final_resp=final_resp,
373-
has_intermediate_error=has_intermediate_error,
374-
)
375-
_set_runner_error_extra(event, is_runner_error)
376-
event.set_result(
377-
MessageEventResult(
378-
chain=final_chain,
379-
result_content_type=result_content_type,
380-
),
381-
)
382-
yield
383449
finally:
384-
if (
385-
stream_idle_close_task
386-
and not stream_idle_close_task.done()
387-
and (
388-
not streaming_started or stream_consumption_started or runner_closed
389-
)
390-
):
391-
stream_idle_close_task.cancel()
392-
if not streaming_started:
393-
await close_runner_once()
450+
await lifecycle.finalize()
394451

395452
asyncio.create_task(
396453
Metric.upload(

0 commit comments

Comments
 (0)