Skip to content

Commit 3ae43a2

Browse files
committed
fix: defer streaming runner cleanup and unify error mapping
1 parent 352ab50 commit 3ae43a2

1 file changed

Lines changed: 64 additions & 42 deletions

File tree

  • astrbot/core/pipeline/process_stage/method/agent_sub_stages

astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py

Lines changed: 64 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,32 @@
4848
THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY = "_third_party_runner_error"
4949

5050

51+
def _set_runner_error_extra(event: "AstrMessageEvent", is_error: bool) -> None:
52+
event.set_extra(THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY, is_error)
53+
54+
55+
def _runner_result_content_type(is_error: bool) -> ResultContentType:
56+
return (
57+
ResultContentType.AGENT_RUNNER_ERROR
58+
if is_error
59+
else ResultContentType.LLM_RESULT
60+
)
61+
62+
63+
def _set_non_stream_runner_result(
64+
event: "AstrMessageEvent",
65+
chain: list,
66+
is_error: bool,
67+
) -> None:
68+
_set_runner_error_extra(event, is_error)
69+
event.set_result(
70+
MessageEventResult(
71+
chain=chain,
72+
result_content_type=_runner_result_content_type(is_error),
73+
),
74+
)
75+
76+
5177
async def run_third_party_agent(
5278
runner: "BaseAgentRunner",
5379
stream_to_general: bool = False,
@@ -213,6 +239,16 @@ async def process(
213239
and not event.platform_meta.support_streaming_message
214240
)
215241

242+
runner_closed = False
243+
defer_runner_close_to_stream = False
244+
245+
async def _close_runner_once() -> None:
246+
nonlocal runner_closed
247+
if runner_closed:
248+
return
249+
runner_closed = True
250+
await _close_runner_if_supported(runner)
251+
216252
try:
217253
await runner.reset(
218254
request=req,
@@ -231,32 +267,35 @@ async def process(
231267

232268
async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]:
233269
nonlocal stream_has_runner_error
234-
async for runner_output in run_third_party_agent(
235-
runner,
236-
stream_to_general=False,
237-
custom_error_message=custom_error_message,
238-
):
239-
if runner_output.is_error:
240-
stream_has_runner_error = True
241-
event.set_extra(THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY, True)
242-
yield runner_output.chain
270+
try:
271+
async for runner_output in run_third_party_agent(
272+
runner,
273+
stream_to_general=False,
274+
custom_error_message=custom_error_message,
275+
):
276+
if runner_output.is_error:
277+
stream_has_runner_error = True
278+
_set_runner_error_extra(event, True)
279+
yield runner_output.chain
280+
finally:
281+
# Streaming runner cleanup must happen after consumer
282+
# finishes iterating to avoid tearing down active streams.
283+
await _close_runner_once()
243284

244285
event.set_result(
245286
MessageEventResult()
246287
.set_result_content_type(ResultContentType.STREAMING_RESULT)
247288
.set_async_stream(_stream_runner_chain()),
248289
)
290+
defer_runner_close_to_stream = True
249291
yield
250292
if runner.done():
251293
final_resp = runner.get_final_llm_resp()
252294
if final_resp and final_resp.result_chain:
253295
is_runner_error = (
254296
stream_has_runner_error or final_resp.role == "err"
255297
)
256-
event.set_extra(
257-
THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY,
258-
is_runner_error,
259-
)
298+
_set_runner_error_extra(event, is_runner_error)
260299
event.set_result(
261300
MessageEventResult(
262301
chain=final_resp.result_chain.chain or [],
@@ -284,44 +323,27 @@ async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]:
284323
logger.warning(
285324
"Agent Runner returned no final response, fallback to streamed error/result chain."
286325
)
287-
content_type = (
288-
ResultContentType.AGENT_RUNNER_ERROR
289-
if fallback_is_error
290-
else ResultContentType.LLM_RESULT
291-
)
292-
event.set_extra(
293-
THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY,
294-
fallback_is_error,
295-
)
296-
event.set_result(
297-
MessageEventResult(
298-
chain=merged_chain,
299-
result_content_type=content_type,
300-
),
326+
_set_non_stream_runner_result(
327+
event=event,
328+
chain=merged_chain,
329+
is_error=fallback_is_error,
301330
)
302331
yield
303332
return
304333
logger.warning("Agent Runner 未返回最终结果。")
305334
return
306335

307-
content_type = (
308-
ResultContentType.AGENT_RUNNER_ERROR
309-
if final_resp.role == "err"
310-
else ResultContentType.LLM_RESULT
311-
)
312-
event.set_extra(
313-
THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY,
314-
final_resp.role == "err",
315-
)
316-
event.set_result(
317-
MessageEventResult(
318-
chain=final_resp.result_chain.chain or [],
319-
result_content_type=content_type,
320-
),
336+
# Preserve intermediate error signals even if final role is assistant.
337+
is_runner_error = fallback_is_error or final_resp.role == "err"
338+
_set_non_stream_runner_result(
339+
event=event,
340+
chain=final_resp.result_chain.chain or [],
341+
is_error=is_runner_error,
321342
)
322343
yield
323344
finally:
324-
await _close_runner_if_supported(runner)
345+
if not defer_runner_close_to_stream:
346+
await _close_runner_once()
325347

326348
asyncio.create_task(
327349
Metric.upload(

0 commit comments

Comments
 (0)