Skip to content

Commit f25b889

Browse files
authored
Merge branch 'AstrBotDevs:master' into feat/sdk-integration
2 parents 654acd8 + cd4e999 commit f25b889

7 files changed

Lines changed: 838 additions & 35 deletions

File tree

astrbot/core/provider/sources/anthropic_source.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -515,7 +515,7 @@ async def text_chat(
515515

516516
model = model or self.get_model()
517517

518-
payloads = {"messages": new_messages, "model": model}
518+
payloads = {**kwargs, "messages": new_messages, "model": model}
519519

520520
# Anthropic has a different way of handling system prompts
521521
if system_prompt:
@@ -571,7 +571,7 @@ async def text_chat_stream(
571571

572572
model = model or self.get_model()
573573

574-
payloads = {"messages": new_messages, "model": model}
574+
payloads = {**kwargs, "messages": new_messages, "model": model}
575575

576576
# Anthropic has a different way of handling system prompts
577577
if system_prompt:

astrbot/core/provider/sources/gemini_source.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -757,7 +757,7 @@ async def text_chat(
757757

758758
model = model or self.get_model()
759759

760-
payloads = {"messages": context_query, "model": model}
760+
payloads = {**kwargs, "messages": context_query, "model": model}
761761

762762
retry = 10
763763
keys = self.api_keys.copy()
@@ -812,7 +812,7 @@ async def text_chat_stream(
812812

813813
model = model or self.get_model()
814814

815-
payloads = {"messages": context_query, "model": model}
815+
payloads = {**kwargs, "messages": context_query, "model": model}
816816

817817
retry = 10
818818
keys = self.api_keys.copy()

astrbot/core/provider/sources/openai_source.py

Lines changed: 212 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
import asyncio
22
import base64
3+
import copy
34
import inspect
45
import json
56
import random
67
import re
78
from collections.abc import AsyncGenerator
8-
from typing import Any
9+
from io import BytesIO
10+
from pathlib import Path
11+
from typing import Any, Literal
12+
from urllib.parse import unquote, urlparse
913

1014
import httpx
1115
from openai import AsyncAzureOpenAI, AsyncOpenAI
@@ -14,6 +18,8 @@
1418
from openai.types.chat.chat_completion import ChatCompletion
1519
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
1620
from openai.types.completion_usage import CompletionUsage
21+
from PIL import Image as PILImage
22+
from PIL import UnidentifiedImageError
1723

1824
import astrbot.core.message.components as Comp
1925
from astrbot import logger
@@ -133,6 +139,186 @@ def _context_contains_image(contexts: list[dict]) -> bool:
133139
return True
134140
return False
135141

142+
def _is_invalid_attachment_error(self, error: Exception) -> bool:
143+
body = getattr(error, "body", None)
144+
code: str | None = None
145+
message: str | None = None
146+
if isinstance(body, dict):
147+
err_obj = body.get("error")
148+
if isinstance(err_obj, dict):
149+
raw_code = err_obj.get("code")
150+
raw_message = err_obj.get("message")
151+
code = raw_code.lower() if isinstance(raw_code, str) else None
152+
message = raw_message.lower() if isinstance(raw_message, str) else None
153+
154+
if code == "invalid_attachment":
155+
return True
156+
157+
text_sources: list[str] = []
158+
if message:
159+
text_sources.append(message)
160+
if code:
161+
text_sources.append(code)
162+
text_sources.extend(map(str, self._extract_error_text_candidates(error)))
163+
164+
error_text = " ".join(text.lower() for text in text_sources if text)
165+
if "invalid_attachment" in error_text:
166+
return True
167+
if "download attachment" in error_text and "404" in error_text:
168+
return True
169+
return False
170+
171+
@classmethod
172+
def _encode_image_file_to_data_url(
173+
cls,
174+
image_path: str,
175+
*,
176+
mode: Literal["safe", "strict"],
177+
) -> str | None:
178+
try:
179+
image_bytes = Path(image_path).read_bytes()
180+
except OSError:
181+
if mode == "strict":
182+
raise
183+
return None
184+
185+
try:
186+
with PILImage.open(BytesIO(image_bytes)) as image:
187+
image.verify()
188+
image_format = str(image.format or "").upper()
189+
except (OSError, UnidentifiedImageError):
190+
if mode == "strict":
191+
raise ValueError(f"Invalid image file: {image_path}")
192+
return None
193+
194+
mime_type = {
195+
"JPEG": "image/jpeg",
196+
"PNG": "image/png",
197+
"GIF": "image/gif",
198+
"WEBP": "image/webp",
199+
"BMP": "image/bmp",
200+
}.get(image_format, "image/jpeg")
201+
image_bs64 = base64.b64encode(image_bytes).decode("utf-8")
202+
return f"data:{mime_type};base64,{image_bs64}"
203+
204+
@staticmethod
205+
def _file_uri_to_path(file_uri: str) -> str:
206+
"""Normalize file URIs to paths.
207+
208+
`file://localhost/...` and drive-letter forms are treated as local paths.
209+
Other non-empty hosts are preserved as UNC-style paths.
210+
"""
211+
parsed = urlparse(file_uri)
212+
if parsed.scheme != "file":
213+
return file_uri
214+
215+
netloc = unquote(parsed.netloc or "")
216+
path = unquote(parsed.path or "")
217+
if re.fullmatch(r"[A-Za-z]:", netloc):
218+
return str(Path(f"{netloc}{path}"))
219+
if re.match(r"^/[A-Za-z]:/", path):
220+
path = path[1:]
221+
if netloc and netloc != "localhost":
222+
path = f"//{netloc}{path}"
223+
return str(Path(path))
224+
225+
async def _image_ref_to_data_url(
226+
self,
227+
image_ref: str,
228+
*,
229+
mode: Literal["safe", "strict"] = "safe",
230+
) -> str | None:
231+
if image_ref.startswith("base64://"):
232+
return image_ref.replace("base64://", "data:image/jpeg;base64,")
233+
234+
if image_ref.startswith("http"):
235+
image_path = await download_image_by_url(image_ref)
236+
elif image_ref.startswith("file://"):
237+
image_path = self._file_uri_to_path(image_ref)
238+
else:
239+
image_path = image_ref
240+
241+
return self._encode_image_file_to_data_url(
242+
image_path,
243+
mode=mode,
244+
)
245+
246+
async def _resolve_image_part(
247+
self,
248+
image_url: str,
249+
*,
250+
image_detail: str | None = None,
251+
) -> dict | None:
252+
if image_url.startswith("data:"):
253+
image_payload = {"url": image_url}
254+
else:
255+
image_data = await self._image_ref_to_data_url(image_url, mode="safe")
256+
if not image_data:
257+
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
258+
return None
259+
image_payload = {"url": image_data}
260+
261+
if image_detail:
262+
image_payload["detail"] = image_detail
263+
return {
264+
"type": "image_url",
265+
"image_url": image_payload,
266+
}
267+
268+
def _extract_image_part_info(self, part: dict) -> tuple[str | None, str | None]:
269+
if not isinstance(part, dict) or part.get("type") != "image_url":
270+
return None, None
271+
272+
image_url_data = part.get("image_url")
273+
if not isinstance(image_url_data, dict):
274+
logger.warning("图片内容块格式无效,将保留原始内容。")
275+
return None, None
276+
277+
url = image_url_data.get("url")
278+
if not isinstance(url, str) or not url:
279+
logger.warning("图片内容块缺少有效 URL,将保留原始内容。")
280+
return None, None
281+
282+
image_detail = image_url_data.get("detail")
283+
if not isinstance(image_detail, str):
284+
image_detail = None
285+
return url, image_detail
286+
287+
async def _transform_content_part(self, part: dict) -> dict:
    """Return ``part`` with its image reference resolved, when possible.

    Parts for which ``_extract_image_part_info`` yields no URL are returned
    untouched. If resolving the image raises, the failure is logged and the
    original part is returned; the same happens when resolution produces a
    falsy result.
    """
    url, image_detail = self._extract_image_part_info(part)
    if not url:
        return part

    try:
        replacement = await self._resolve_image_part(url, image_detail=image_detail)
    except Exception as exc:
        # Best-effort preprocessing: keep the original part on any failure.
        logger.warning(
            "图片 %s 预处理失败，将保留原始内容。错误: %s",
            url,
            exc,
        )
        return part

    return replacement if replacement else part
305+
306+
async def _materialize_message_image_parts(self, message: dict) -> dict:
307+
content = message.get("content")
308+
if not isinstance(content, list):
309+
return {**message}
310+
311+
new_content = [await self._transform_content_part(part) for part in content]
312+
return {**message, "content": new_content}
313+
314+
async def _materialize_context_image_parts(
315+
self, context_query: list[dict]
316+
) -> list[dict]:
317+
return [
318+
await self._materialize_message_image_parts(message)
319+
for message in context_query
320+
]
321+
136322
async def _fallback_to_text_only_and_retry(
137323
self,
138324
payloads: dict,
@@ -604,7 +790,7 @@ async def _prepare_chat_payload(
604790
new_record = await self.assemble_context(
605791
prompt, image_urls, extra_user_content_parts
606792
)
607-
context_query = self._ensure_message_to_dicts(contexts)
793+
context_query = copy.deepcopy(self._ensure_message_to_dicts(contexts))
608794
if new_record:
609795
context_query.append(new_record)
610796
if system_prompt:
@@ -622,9 +808,11 @@ async def _prepare_chat_payload(
622808
for tcr in tool_calls_result:
623809
context_query.extend(tcr.to_openai_messages())
624810

625-
model = model or self.get_model()
811+
if self._context_contains_image(context_query):
812+
context_query = await self._materialize_context_image_parts(context_query)
626813

627-
payloads = {"messages": context_query, "model": model}
814+
model = model or self.get_model()
815+
payloads = {**kwargs, "messages": context_query, "model": model}
628816

629817
self._finally_convert_payload(payloads)
630818

@@ -722,6 +910,18 @@ async def _handle_api_error(
722910
"image_content_moderated",
723911
image_fallback_used=True,
724912
)
913+
if self._is_invalid_attachment_error(e):
914+
if image_fallback_used or not self._context_contains_image(context_query):
915+
raise e
916+
return await self._fallback_to_text_only_and_retry(
917+
payloads,
918+
context_query,
919+
chosen_key,
920+
available_api_keys,
921+
func_tool,
922+
"invalid_attachment",
923+
image_fallback_used=True,
924+
)
725925

726926
if (
727927
"Function calling is not enabled" in str(e)
@@ -923,23 +1123,6 @@ async def assemble_context(
9231123
) -> dict:
9241124
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""
9251125

926-
async def resolve_image_part(image_url: str) -> dict | None:
927-
if image_url.startswith("http"):
928-
image_path = await download_image_by_url(image_url)
929-
image_data = await self.encode_image_bs64(image_path)
930-
elif image_url.startswith("file:///"):
931-
image_path = image_url.replace("file:///", "")
932-
image_data = await self.encode_image_bs64(image_path)
933-
else:
934-
image_data = await self.encode_image_bs64(image_url)
935-
if not image_data:
936-
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
937-
return None
938-
return {
939-
"type": "image_url",
940-
"image_url": {"url": image_data},
941-
}
942-
9431126
# 构建内容块列表
9441127
content_blocks = []
9451128

@@ -959,7 +1142,9 @@ async def resolve_image_part(image_url: str) -> dict | None:
9591142
if isinstance(part, TextPart):
9601143
content_blocks.append({"type": "text", "text": part.text})
9611144
elif isinstance(part, ImageURLPart):
962-
image_part = await resolve_image_part(part.image_url.url)
1145+
image_part = await self._resolve_image_part(
1146+
part.image_url.url,
1147+
)
9631148
if image_part:
9641149
content_blocks.append(image_part)
9651150
else:
@@ -968,7 +1153,7 @@ async def resolve_image_part(image_url: str) -> dict | None:
9681153
# 3. 图片内容
9691154
if image_urls:
9701155
for image_url in image_urls:
971-
image_part = await resolve_image_part(image_url)
1156+
image_part = await self._resolve_image_part(image_url)
9721157
if image_part:
9731158
content_blocks.append(image_part)
9741159

@@ -987,11 +1172,10 @@ async def resolve_image_part(image_url: str) -> dict | None:
9871172

9881173
async def encode_image_bs64(self, image_url: str) -> str:
9891174
"""将图片转换为 base64"""
990-
if image_url.startswith("base64://"):
991-
return image_url.replace("base64://", "data:image/jpeg;base64,")
992-
with open(image_url, "rb") as f:
993-
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
994-
return "data:image/jpeg;base64," + image_bs64
1175+
image_data = await self._image_ref_to_data_url(image_url, mode="strict")
1176+
if image_data is None:
1177+
raise RuntimeError(f"Failed to encode image data: {image_url}")
1178+
return image_data
9951179

9961180
async def terminate(self):
9971181
if self.client:

dashboard/src/scss/_override.scss

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ html {
1717
flex: unset;
1818
}
1919

20+
.v-overlay.v-snackbar {
21+
--v-layout-left: 0px !important;
22+
--v-layout-right: 0px !important;
23+
}
24+
2025
.customizer-btn .icon {
2126
animation: progress-circular-rotate 1.4s linear infinite;
2227
transform-origin: center center;
@@ -34,3 +39,10 @@ html {
3439
transform: rotate(270deg);
3540
}
3641
}
42+
43+
pre, code, .markdown pre, .markdown code, .release-notes pre, .release-notes code {
44+
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, "Roboto Mono", "Helvetica Neue", monospace;
45+
color: var(--astrbot-code-color);
46+
}
47+
48+

dashboard/src/scss/_variables.scss

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@ $font-size-root: 1rem;
1111
$border-radius-root: 8px;
1212
$cjk-sans-fallback: 'PingFang SC', 'Hiragino Sans GB', 'Noto Sans CJK SC', 'Microsoft YaHei' !default;
1313
$cjk-mono-fallback: 'PingFang SC', 'PingFang TC', 'Hiragino Sans GB', 'Noto Sans CJK SC', 'Microsoft YaHei' !default;
14+
$code-text-color: #111827 !default;
1415

1516
:root {
1617
--astrbot-font-cjk-sans: #{$cjk-sans-fallback};
1718
--astrbot-font-cjk-mono: #{$cjk-mono-fallback};
19+
--astrbot-code-color: #{$code-text-color};
1820
}
1921

2022
$body-font-family: 'Roboto', $cjk-sans-fallback, sans-serif !default;

0 commit comments

Comments
 (0)