Skip to content

Commit ce1a79c

Browse files
committed
refactor: simplify third-party runner aggregation and lifecycle closing
1 parent f0dc39a commit ce1a79c

1 file changed

Lines changed: 84 additions & 38 deletions

File tree

  • astrbot/core/pipeline/process_stage/method/agent_sub_stages

astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py

Lines changed: 84 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import asyncio
22
import inspect
33
from collections.abc import AsyncGenerator
4-
from dataclasses import dataclass
54
from typing import TYPE_CHECKING
65

76
from astrbot.core import astrbot_config, logger
@@ -50,6 +49,16 @@
5049
STREAM_CONSUMPTION_CLOSE_TIMEOUT_SEC = 30
5150

5251

52+
def _coerce_positive_int(value: object, default: int) -> int:
53+
if isinstance(value, bool):
54+
return default
55+
try:
56+
coerced = int(value)
57+
except (TypeError, ValueError):
58+
return default
59+
return coerced if coerced > 0 else default
60+
61+
5362
def _set_runner_error_extra(event: "AstrMessageEvent", is_error: bool) -> None:
    """Store the third-party runner error flag in the event's extras.

    Args:
        event: Message event whose ``set_extra`` is called.
        is_error: Flag value written under ``THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY``.
    """
    extra_key = THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY
    event.set_extra(extra_key, is_error)
5564

@@ -91,7 +100,7 @@ async def run_third_party_agent(
91100
runner: "BaseAgentRunner",
92101
stream_to_general: bool = False,
93102
custom_error_message: str | None = None,
94-
) -> AsyncGenerator["_ThirdPartyRunnerOutput", None]:
103+
) -> AsyncGenerator[tuple[MessageChain, bool], None]:
95104
"""
96105
运行第三方 agent runner 并转换响应格式
97106
类似于 run_agent 函数,但专门处理第三方 agent runner
@@ -101,21 +110,12 @@ async def run_third_party_agent(
101110
if resp.type == "streaming_delta":
102111
if stream_to_general:
103112
continue
104-
yield _ThirdPartyRunnerOutput(
105-
chain=resp.data["chain"],
106-
is_error=False,
107-
)
113+
yield resp.data["chain"], False
108114
elif resp.type == "llm_result":
109115
if stream_to_general:
110-
yield _ThirdPartyRunnerOutput(
111-
chain=resp.data["chain"],
112-
is_error=False,
113-
)
116+
yield resp.data["chain"], False
114117
elif resp.type == "err":
115-
yield _ThirdPartyRunnerOutput(
116-
chain=resp.data["chain"],
117-
is_error=True,
118-
)
118+
yield resp.data["chain"], True
119119
except Exception as e:
120120
logger.error(f"Third party agent runner error: {e}")
121121
err_msg = custom_error_message
@@ -125,16 +125,26 @@ async def run_third_party_agent(
125125
f"Error Type: {type(e).__name__} (3rd party)\n"
126126
f"Error Message: {str(e)}"
127127
)
128-
yield _ThirdPartyRunnerOutput(
129-
chain=MessageChain().message(err_msg),
130-
is_error=True,
131-
)
128+
yield MessageChain().message(err_msg), True
132129

133130

134-
@dataclass
135-
class _ThirdPartyRunnerOutput:
136-
chain: MessageChain
137-
is_error: bool = False
131+
async def _consume_runner_and_aggregate(
    runner: "BaseAgentRunner",
    *,
    stream_to_general: bool,
    custom_error_message: str | None,
) -> AsyncGenerator[tuple[MessageChain, bool, list, bool], None]:
    """Iterate ``run_third_party_agent`` while keeping running aggregates.

    Args:
        runner: Runner forwarded to ``run_third_party_agent``.
        stream_to_general: Forwarded unchanged to ``run_third_party_agent``.
        custom_error_message: Forwarded unchanged to ``run_third_party_agent``.

    Yields:
        A 4-tuple ``(chain, is_error, merged_chain, has_intermediate_error)``
        for every item produced by ``run_third_party_agent``:

        * ``chain`` / ``is_error`` -- the item exactly as produced.
        * ``merged_chain`` -- list of all ``chain.chain`` entries seen so far.
          The same list object is yielded every time and is extended in
          place, so callers holding a reference observe later additions.
        * ``has_intermediate_error`` -- ``True`` once any item so far had
          ``is_error`` set.
    """
    aggregated: list = []
    saw_error = False

    outputs = run_third_party_agent(
        runner,
        stream_to_general=stream_to_general,
        custom_error_message=custom_error_message,
    )
    async for chain, is_error in outputs:
        components = chain.chain
        # A missing/empty component list contributes nothing to the aggregate.
        if components:
            aggregated.extend(components)
        # The error flag is sticky: once raised it stays raised.
        saw_error = saw_error or bool(is_error)
        yield chain, is_error, aggregated, saw_error
138148

139149

140150
async def _close_runner_if_supported(runner: "BaseAgentRunner") -> None:
@@ -151,8 +161,15 @@ async def _close_runner_if_supported(runner: "BaseAgentRunner") -> None:
151161

152162

153163
class _RunnerLifecycle:
154-
def __init__(self, runner: "BaseAgentRunner") -> None:
164+
def __init__(
165+
self,
166+
runner: "BaseAgentRunner",
167+
stream_consumption_close_timeout_sec: int,
168+
) -> None:
155169
self._runner = runner
170+
self._stream_consumption_close_timeout_sec = (
171+
stream_consumption_close_timeout_sec
172+
)
156173
self._closed = False
157174
self._stream_started = False
158175
self._stream_consumed = False
@@ -196,22 +213,34 @@ async def finalize(self) -> None:
196213
if (
197214
self._idle_task
198215
and not self._idle_task.done()
199-
and (not self._stream_started or self._stream_consumed or self._closed)
216+
and (self._stream_consumed or self._closed)
200217
):
201218
self._idle_task.cancel()
202219

203-
if not self._stream_started:
220+
defer_close_to_watchdog = (
221+
self._stream_started
222+
and not self._stream_consumed
223+
and self._idle_task is not None
224+
and not self._idle_task.done()
225+
and not self._closed
226+
)
227+
if defer_close_to_watchdog:
228+
return
229+
230+
if not self._closed:
204231
await self.close_once()
205232

206233
async def _close_if_never_consumed(self) -> None:
207234
try:
208-
await asyncio.sleep(STREAM_CONSUMPTION_CLOSE_TIMEOUT_SEC)
235+
await asyncio.sleep(self._stream_consumption_close_timeout_sec)
209236
except asyncio.CancelledError:
210237
return
211238

212239
if not self._stream_consumed:
213240
logger.warning(
214-
"Third-party runner stream was never consumed; closing runner to avoid resource leak.",
241+
"Third-party runner stream was never consumed in %ss; closing runner "
242+
"to avoid resource leak.",
243+
self._stream_consumption_close_timeout_sec,
215244
)
216245
await self.close_once()
217246

@@ -230,6 +259,13 @@ async def initialize(self, ctx: PipelineContext) -> None:
230259
self.unsupported_streaming_strategy: str = settings[
231260
"unsupported_streaming_strategy"
232261
]
262+
self.stream_consumption_close_timeout_sec: int = _coerce_positive_int(
263+
settings.get(
264+
"third_party_stream_consumption_close_timeout_sec",
265+
STREAM_CONSUMPTION_CLOSE_TIMEOUT_SEC,
266+
),
267+
STREAM_CONSUMPTION_CLOSE_TIMEOUT_SEC,
268+
)
233269

234270
async def _resolve_persona_custom_error_message(
235271
self, event: AstrMessageEvent
@@ -258,20 +294,25 @@ async def _handle_streaming_response(
258294
custom_error_message: str | None,
259295
) -> AsyncGenerator[None, None]:
260296
stream_has_runner_error = False
297+
merged_chain: list = []
261298

262299
async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]:
263-
nonlocal stream_has_runner_error
300+
nonlocal merged_chain, stream_has_runner_error
264301
lifecycle.mark_stream_consumed()
265302
try:
266-
async for runner_output in run_third_party_agent(
303+
async for (
304+
chain,
305+
is_error,
306+
merged_chain,
307+
stream_has_runner_error,
308+
) in _consume_runner_and_aggregate(
267309
runner,
268310
stream_to_general=False,
269311
custom_error_message=custom_error_message,
270312
):
271-
if runner_output.is_error:
272-
stream_has_runner_error = True
313+
if is_error:
273314
_set_runner_error_extra(event, True)
274-
yield runner_output.chain
315+
yield chain
275316
finally:
276317
# Streaming runner cleanup must happen after consumer
277318
# finishes iterating to avoid tearing down active streams.
@@ -293,7 +334,7 @@ async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]:
293334
is_runner_error,
294335
_,
295336
) = _resolve_final_result(
296-
merged_chain=[],
337+
merged_chain=merged_chain,
297338
final_resp=final_resp,
298339
has_intermediate_error=stream_has_runner_error,
299340
)
@@ -315,14 +356,16 @@ async def _handle_non_streaming_response(
315356
) -> AsyncGenerator[None, None]:
316357
merged_chain: list = []
317358
has_intermediate_error = False
318-
async for output in run_third_party_agent(
359+
async for (
360+
_,
361+
_,
362+
merged_chain,
363+
has_intermediate_error,
364+
) in _consume_runner_and_aggregate(
319365
runner,
320366
stream_to_general=stream_to_general,
321367
custom_error_message=custom_error_message,
322368
):
323-
merged_chain.extend(output.chain.chain or [])
324-
if output.is_error:
325-
has_intermediate_error = True
326369
yield
327370

328371
final_resp = runner.get_final_llm_resp()
@@ -421,7 +464,10 @@ async def process(
421464
and not event.platform_meta.support_streaming_message
422465
)
423466

424-
lifecycle = _RunnerLifecycle(runner)
467+
lifecycle = _RunnerLifecycle(
468+
runner,
469+
stream_consumption_close_timeout_sec=self.stream_consumption_close_timeout_sec,
470+
)
425471

426472
try:
427473
await lifecycle.reset(

0 commit comments

Comments
 (0)