Skip to content

Commit 98a4840

Browse files
committed
fix: tighten deerflow image validation and runner lifecycle
1 parent 613559a commit 98a4840

2 files changed

Lines changed: 128 additions & 195 deletions

File tree

astrbot/core/agent/runners/deerflow/deerflow_agent_runner.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import base64
23
import hashlib
34
import json
45
import sys
@@ -330,13 +331,19 @@ def _is_likely_base64_image(self, value: str) -> bool:
330331
return False
331332

332333
compact = value.replace("\n", "").replace("\r", "")
333-
if not compact or len(compact) % 4 != 0:
334+
if not compact or len(compact) < 32 or len(compact) % 4 != 0:
334335
return False
335336

336337
base64_chars = (
337338
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/="
338339
)
339-
return all(ch in base64_chars for ch in compact)
340+
if any(ch not in base64_chars for ch in compact):
341+
return False
342+
try:
343+
base64.b64decode(compact, validate=True)
344+
except Exception:
345+
return False
346+
return True
340347

341348
def _build_user_content(self, prompt: str, image_urls: list[str]) -> T.Any:
342349
if not image_urls:
@@ -360,10 +367,11 @@ def _build_user_content(self, prompt: str, image_urls: list[str]) -> T.Any:
360367
if not self._is_likely_base64_image(url):
361368
skipped_invalid_images += 1
362369
continue
370+
compact_base64 = url.replace("\n", "").replace("\r", "")
363371
content.append(
364372
{
365373
"type": "image_url",
366-
"image_url": {"url": f"data:image/png;base64,{url}"},
374+
"image_url": {"url": f"data:image/png;base64,{compact_base64}"},
367375
},
368376
)
369377
if skipped_invalid_images:

astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py

Lines changed: 117 additions & 192 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import asyncio
22
import inspect
3-
import typing as T
43
from collections.abc import AsyncGenerator
5-
from contextlib import asynccontextmanager
64
from dataclasses import dataclass
75
from typing import TYPE_CHECKING
86

@@ -30,7 +28,6 @@
3028

3129
if TYPE_CHECKING:
3230
from astrbot.core.agent.runners.base import BaseAgentRunner
33-
from astrbot.core.provider.entities import LLMResponse
3431
from astrbot.core.pipeline.stage import Stage
3532
from astrbot.core.platform.astr_message_event import AstrMessageEvent
3633
from astrbot.core.provider.entities import (
@@ -55,37 +52,6 @@ def _set_runner_error_extra(event: "AstrMessageEvent", is_error: bool) -> None:
5552
event.set_extra(THIRD_PARTY_RUNNER_ERROR_EXTRA_KEY, is_error)
5653

5754

58-
def _runner_result_content_type(is_error: bool) -> ResultContentType:
59-
return (
60-
ResultContentType.AGENT_RUNNER_ERROR
61-
if is_error
62-
else ResultContentType.LLM_RESULT
63-
)
64-
65-
66-
def _set_non_stream_runner_result(
67-
event: "AstrMessageEvent",
68-
chain: list,
69-
is_error: bool,
70-
) -> None:
71-
_set_runner_error_extra(event, is_error)
72-
event.set_result(
73-
MessageEventResult(
74-
chain=chain,
75-
result_content_type=_runner_result_content_type(is_error),
76-
),
77-
)
78-
79-
80-
def _aggregate_runner_error(
81-
has_intermediate_error: bool,
82-
final_resp: "LLMResponse | None",
83-
) -> bool:
84-
if not final_resp:
85-
return has_intermediate_error
86-
return has_intermediate_error or final_resp.role == "err"
87-
88-
8955
async def run_third_party_agent(
9056
runner: "BaseAgentRunner",
9157
stream_to_general: bool = False,
@@ -149,44 +115,6 @@ async def _close_runner_if_supported(runner: "BaseAgentRunner") -> None:
149115
logger.warning(f"Failed to close third-party runner cleanly: {e}")
150116

151117

152-
@asynccontextmanager
153-
async def _runner_session(
154-
runner: "BaseAgentRunner",
155-
*,
156-
request: ProviderRequest,
157-
run_context: AgentContextWrapper,
158-
agent_hooks: T.Any,
159-
provider_config: dict,
160-
streaming: bool,
161-
):
162-
runner_closed = False
163-
defer_close = False
164-
165-
async def close_runner_once() -> None:
166-
nonlocal runner_closed
167-
if runner_closed:
168-
return
169-
runner_closed = True
170-
await _close_runner_if_supported(runner)
171-
172-
def defer_runner_close() -> None:
173-
nonlocal defer_close
174-
defer_close = True
175-
176-
await runner.reset(
177-
request=request,
178-
run_context=run_context,
179-
agent_hooks=agent_hooks,
180-
provider_config=provider_config,
181-
streaming=streaming,
182-
)
183-
try:
184-
yield close_runner_once, defer_runner_close
185-
finally:
186-
if not defer_close:
187-
await close_runner_once()
188-
189-
190118
class ThirdPartyAgentSubStage(Stage):
191119
async def initialize(self, ctx: PipelineContext) -> None:
192120
self.ctx = ctx
@@ -220,101 +148,6 @@ async def _resolve_persona_custom_error_message(
220148
logger.debug("Failed to resolve persona custom error message: %s", e)
221149
return None
222150

223-
async def _handle_streaming_runner(
224-
self,
225-
runner: "BaseAgentRunner",
226-
event: AstrMessageEvent,
227-
custom_error_message: str | None,
228-
close_runner_once: T.Callable[[], T.Awaitable[None]],
229-
) -> AsyncGenerator[None, None]:
230-
stream_has_runner_error = False
231-
232-
async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]:
233-
nonlocal stream_has_runner_error
234-
try:
235-
async for runner_output in run_third_party_agent(
236-
runner,
237-
stream_to_general=False,
238-
custom_error_message=custom_error_message,
239-
):
240-
if runner_output.is_error:
241-
stream_has_runner_error = True
242-
_set_runner_error_extra(event, True)
243-
yield runner_output.chain
244-
finally:
245-
# Streaming runner cleanup must happen after consumer
246-
# finishes iterating to avoid tearing down active streams.
247-
await close_runner_once()
248-
249-
event.set_result(
250-
MessageEventResult()
251-
.set_result_content_type(ResultContentType.STREAMING_RESULT)
252-
.set_async_stream(_stream_runner_chain()),
253-
)
254-
yield
255-
256-
if runner.done():
257-
final_resp = runner.get_final_llm_resp()
258-
if final_resp and final_resp.result_chain:
259-
is_runner_error = _aggregate_runner_error(
260-
has_intermediate_error=stream_has_runner_error,
261-
final_resp=final_resp,
262-
)
263-
_set_runner_error_extra(event, is_runner_error)
264-
event.set_result(
265-
MessageEventResult(
266-
chain=final_resp.result_chain.chain or [],
267-
result_content_type=ResultContentType.STREAMING_FINISH,
268-
),
269-
)
270-
271-
async def _handle_non_streaming_runner(
272-
self,
273-
runner: "BaseAgentRunner",
274-
event: AstrMessageEvent,
275-
stream_to_general: bool,
276-
custom_error_message: str | None,
277-
) -> AsyncGenerator[None, None]:
278-
merged_chain: list = []
279-
has_intermediate_error = False
280-
async for output in run_third_party_agent(
281-
runner,
282-
stream_to_general=stream_to_general,
283-
custom_error_message=custom_error_message,
284-
):
285-
merged_chain.extend(output.chain.chain or [])
286-
if output.is_error:
287-
has_intermediate_error = True
288-
yield
289-
290-
final_resp = runner.get_final_llm_resp()
291-
292-
if not final_resp or not final_resp.result_chain:
293-
if merged_chain:
294-
logger.warning(
295-
"Agent Runner returned no final response, fallback to streamed error/result chain."
296-
)
297-
_set_non_stream_runner_result(
298-
event=event,
299-
chain=merged_chain,
300-
is_error=has_intermediate_error,
301-
)
302-
yield
303-
return
304-
logger.warning("Agent Runner 未返回最终结果。")
305-
return
306-
307-
is_runner_error = _aggregate_runner_error(
308-
has_intermediate_error=has_intermediate_error,
309-
final_resp=final_resp,
310-
)
311-
_set_non_stream_runner_result(
312-
event=event,
313-
chain=final_resp.result_chain.chain or [],
314-
is_error=is_runner_error,
315-
)
316-
yield
317-
318151
async def process(
319152
self, event: AstrMessageEvent, provider_wake_prefix: str
320153
) -> AsyncGenerator[None, None]:
@@ -384,37 +217,129 @@ async def process(
384217
and not event.platform_meta.support_streaming_message
385218
)
386219

387-
async with _runner_session(
388-
runner=runner,
389-
request=req,
390-
run_context=AgentContextWrapper(
391-
context=astr_agent_ctx,
392-
tool_call_timeout=60,
393-
),
394-
agent_hooks=MAIN_AGENT_HOOKS,
395-
provider_config=self.prov_cfg,
396-
streaming=streaming_response,
397-
) as (close_runner_once, defer_runner_close):
220+
runner_closed = False
221+
streaming_started = False
222+
223+
async def close_runner_once() -> None:
224+
nonlocal runner_closed
225+
if runner_closed:
226+
return
227+
runner_closed = True
228+
await _close_runner_if_supported(runner)
229+
230+
try:
231+
await runner.reset(
232+
request=req,
233+
run_context=AgentContextWrapper(
234+
context=astr_agent_ctx,
235+
tool_call_timeout=60,
236+
),
237+
agent_hooks=MAIN_AGENT_HOOKS,
238+
provider_config=self.prov_cfg,
239+
streaming=streaming_response,
240+
)
241+
398242
if streaming_response and not stream_to_general:
399-
stream_started = False
400-
async for _ in self._handle_streaming_runner(
401-
runner=runner,
402-
event=event,
403-
custom_error_message=custom_error_message,
404-
close_runner_once=close_runner_once,
405-
):
406-
if not stream_started:
407-
defer_runner_close()
408-
stream_started = True
409-
yield
243+
stream_has_runner_error = False
244+
245+
async def _stream_runner_chain() -> AsyncGenerator[MessageChain, None]:
246+
nonlocal stream_has_runner_error
247+
try:
248+
async for runner_output in run_third_party_agent(
249+
runner,
250+
stream_to_general=False,
251+
custom_error_message=custom_error_message,
252+
):
253+
if runner_output.is_error:
254+
stream_has_runner_error = True
255+
_set_runner_error_extra(event, True)
256+
yield runner_output.chain
257+
finally:
258+
# Streaming runner cleanup must happen after consumer
259+
# finishes iterating to avoid tearing down active streams.
260+
await close_runner_once()
261+
262+
event.set_result(
263+
MessageEventResult()
264+
.set_result_content_type(ResultContentType.STREAMING_RESULT)
265+
.set_async_stream(_stream_runner_chain()),
266+
)
267+
streaming_started = True
268+
yield
269+
270+
if runner.done():
271+
final_resp = runner.get_final_llm_resp()
272+
if final_resp and final_resp.result_chain:
273+
is_runner_error = (
274+
stream_has_runner_error or final_resp.role == "err"
275+
)
276+
_set_runner_error_extra(event, is_runner_error)
277+
event.set_result(
278+
MessageEventResult(
279+
chain=final_resp.result_chain.chain or [],
280+
result_content_type=ResultContentType.STREAMING_FINISH,
281+
),
282+
)
410283
else:
411-
async for _ in self._handle_non_streaming_runner(
412-
runner=runner,
413-
event=event,
284+
merged_chain: list = []
285+
has_intermediate_error = False
286+
async for output in run_third_party_agent(
287+
runner,
414288
stream_to_general=stream_to_general,
415289
custom_error_message=custom_error_message,
416290
):
291+
merged_chain.extend(output.chain.chain or [])
292+
if output.is_error:
293+
has_intermediate_error = True
294+
yield
295+
296+
final_resp = runner.get_final_llm_resp()
297+
if not final_resp or not final_resp.result_chain:
298+
if merged_chain:
299+
logger.warning(
300+
"Agent Runner returned no final response, fallback to streamed error/result chain."
301+
)
302+
_set_runner_error_extra(event, has_intermediate_error)
303+
event.set_result(
304+
MessageEventResult(
305+
chain=merged_chain,
306+
result_content_type=(
307+
ResultContentType.AGENT_RUNNER_ERROR
308+
if has_intermediate_error
309+
else ResultContentType.LLM_RESULT
310+
),
311+
),
312+
)
313+
else:
314+
logger.warning("Agent Runner 未返回最终结果。")
315+
fallback_error_chain = MessageChain().message(
316+
"Agent Runner did not return any result.",
317+
)
318+
_set_runner_error_extra(event, True)
319+
event.set_result(
320+
MessageEventResult(
321+
chain=fallback_error_chain.chain or [],
322+
result_content_type=ResultContentType.AGENT_RUNNER_ERROR,
323+
),
324+
)
417325
yield
326+
else:
327+
is_runner_error = has_intermediate_error or final_resp.role == "err"
328+
_set_runner_error_extra(event, is_runner_error)
329+
event.set_result(
330+
MessageEventResult(
331+
chain=final_resp.result_chain.chain or [],
332+
result_content_type=(
333+
ResultContentType.AGENT_RUNNER_ERROR
334+
if is_runner_error
335+
else ResultContentType.LLM_RESULT
336+
),
337+
),
338+
)
339+
yield
340+
finally:
341+
if not streaming_started:
342+
await close_runner_once()
418343

419344
asyncio.create_task(
420345
Metric.upload(

0 commit comments

Comments
 (0)