Skip to content

Commit cd4e999

Browse files
authored
fix: harden OpenAI attachment recovery (#7004)
* fix: harden OpenAI attachment recovery * fix: refine OpenAI image loading * fix: restore OpenAI image encoding errors * refactor: streamline OpenAI image helpers * refactor: simplify OpenAI attachment helpers * refactor: simplify OpenAI helper flow * refactor: clarify OpenAI image modes * refactor: reduce OpenAI materialization copies
1 parent 6db9aef commit cd4e999

2 files changed

Lines changed: 816 additions & 28 deletions

File tree

astrbot/core/provider/sources/openai_source.py

Lines changed: 211 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
import asyncio
22
import base64
3+
import copy
34
import inspect
45
import json
56
import random
67
import re
78
from collections.abc import AsyncGenerator
8-
from typing import Any
9+
from io import BytesIO
10+
from pathlib import Path
11+
from typing import Any, Literal
12+
from urllib.parse import unquote, urlparse
913

1014
import httpx
1115
from openai import AsyncAzureOpenAI, AsyncOpenAI
@@ -14,6 +18,8 @@
1418
from openai.types.chat.chat_completion import ChatCompletion
1519
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
1620
from openai.types.completion_usage import CompletionUsage
21+
from PIL import Image as PILImage
22+
from PIL import UnidentifiedImageError
1723

1824
import astrbot.core.message.components as Comp
1925
from astrbot import logger
@@ -133,6 +139,186 @@ def _context_contains_image(contexts: list[dict]) -> bool:
133139
return True
134140
return False
135141

142+
def _is_invalid_attachment_error(self, error: Exception) -> bool:
143+
body = getattr(error, "body", None)
144+
code: str | None = None
145+
message: str | None = None
146+
if isinstance(body, dict):
147+
err_obj = body.get("error")
148+
if isinstance(err_obj, dict):
149+
raw_code = err_obj.get("code")
150+
raw_message = err_obj.get("message")
151+
code = raw_code.lower() if isinstance(raw_code, str) else None
152+
message = raw_message.lower() if isinstance(raw_message, str) else None
153+
154+
if code == "invalid_attachment":
155+
return True
156+
157+
text_sources: list[str] = []
158+
if message:
159+
text_sources.append(message)
160+
if code:
161+
text_sources.append(code)
162+
text_sources.extend(map(str, self._extract_error_text_candidates(error)))
163+
164+
error_text = " ".join(text.lower() for text in text_sources if text)
165+
if "invalid_attachment" in error_text:
166+
return True
167+
if "download attachment" in error_text and "404" in error_text:
168+
return True
169+
return False
170+
171+
@classmethod
172+
def _encode_image_file_to_data_url(
173+
cls,
174+
image_path: str,
175+
*,
176+
mode: Literal["safe", "strict"],
177+
) -> str | None:
178+
try:
179+
image_bytes = Path(image_path).read_bytes()
180+
except OSError:
181+
if mode == "strict":
182+
raise
183+
return None
184+
185+
try:
186+
with PILImage.open(BytesIO(image_bytes)) as image:
187+
image.verify()
188+
image_format = str(image.format or "").upper()
189+
except (OSError, UnidentifiedImageError):
190+
if mode == "strict":
191+
raise ValueError(f"Invalid image file: {image_path}")
192+
return None
193+
194+
mime_type = {
195+
"JPEG": "image/jpeg",
196+
"PNG": "image/png",
197+
"GIF": "image/gif",
198+
"WEBP": "image/webp",
199+
"BMP": "image/bmp",
200+
}.get(image_format, "image/jpeg")
201+
image_bs64 = base64.b64encode(image_bytes).decode("utf-8")
202+
return f"data:{mime_type};base64,{image_bs64}"
203+
204+
@staticmethod
205+
def _file_uri_to_path(file_uri: str) -> str:
206+
"""Normalize file URIs to paths.
207+
208+
`file://localhost/...` and drive-letter forms are treated as local paths.
209+
Other non-empty hosts are preserved as UNC-style paths.
210+
"""
211+
parsed = urlparse(file_uri)
212+
if parsed.scheme != "file":
213+
return file_uri
214+
215+
netloc = unquote(parsed.netloc or "")
216+
path = unquote(parsed.path or "")
217+
if re.fullmatch(r"[A-Za-z]:", netloc):
218+
return str(Path(f"{netloc}{path}"))
219+
if re.match(r"^/[A-Za-z]:/", path):
220+
path = path[1:]
221+
if netloc and netloc != "localhost":
222+
path = f"//{netloc}{path}"
223+
return str(Path(path))
224+
225+
async def _image_ref_to_data_url(
226+
self,
227+
image_ref: str,
228+
*,
229+
mode: Literal["safe", "strict"] = "safe",
230+
) -> str | None:
231+
if image_ref.startswith("base64://"):
232+
return image_ref.replace("base64://", "data:image/jpeg;base64,")
233+
234+
if image_ref.startswith("http"):
235+
image_path = await download_image_by_url(image_ref)
236+
elif image_ref.startswith("file://"):
237+
image_path = self._file_uri_to_path(image_ref)
238+
else:
239+
image_path = image_ref
240+
241+
return self._encode_image_file_to_data_url(
242+
image_path,
243+
mode=mode,
244+
)
245+
246+
async def _resolve_image_part(
247+
self,
248+
image_url: str,
249+
*,
250+
image_detail: str | None = None,
251+
) -> dict | None:
252+
if image_url.startswith("data:"):
253+
image_payload = {"url": image_url}
254+
else:
255+
image_data = await self._image_ref_to_data_url(image_url, mode="safe")
256+
if not image_data:
257+
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
258+
return None
259+
image_payload = {"url": image_data}
260+
261+
if image_detail:
262+
image_payload["detail"] = image_detail
263+
return {
264+
"type": "image_url",
265+
"image_url": image_payload,
266+
}
267+
268+
def _extract_image_part_info(self, part: dict) -> tuple[str | None, str | None]:
269+
if not isinstance(part, dict) or part.get("type") != "image_url":
270+
return None, None
271+
272+
image_url_data = part.get("image_url")
273+
if not isinstance(image_url_data, dict):
274+
logger.warning("图片内容块格式无效,将保留原始内容。")
275+
return None, None
276+
277+
url = image_url_data.get("url")
278+
if not isinstance(url, str) or not url:
279+
logger.warning("图片内容块缺少有效 URL,将保留原始内容。")
280+
return None, None
281+
282+
image_detail = image_url_data.get("detail")
283+
if not isinstance(image_detail, str):
284+
image_detail = None
285+
return url, image_detail
286+
287+
async def _transform_content_part(self, part: dict) -> dict:
288+
url, image_detail = self._extract_image_part_info(part)
289+
if not url:
290+
return part
291+
292+
try:
293+
resolved_part = await self._resolve_image_part(
294+
url, image_detail=image_detail
295+
)
296+
except Exception as exc:
297+
logger.warning(
298+
"图片 %s 预处理失败,将保留原始内容。错误: %s",
299+
url,
300+
exc,
301+
)
302+
return part
303+
304+
return resolved_part or part
305+
306+
async def _materialize_message_image_parts(self, message: dict) -> dict:
307+
content = message.get("content")
308+
if not isinstance(content, list):
309+
return {**message}
310+
311+
new_content = [await self._transform_content_part(part) for part in content]
312+
return {**message, "content": new_content}
313+
314+
async def _materialize_context_image_parts(
315+
self, context_query: list[dict]
316+
) -> list[dict]:
317+
return [
318+
await self._materialize_message_image_parts(message)
319+
for message in context_query
320+
]
321+
136322
async def _fallback_to_text_only_and_retry(
137323
self,
138324
payloads: dict,
@@ -604,7 +790,7 @@ async def _prepare_chat_payload(
604790
new_record = await self.assemble_context(
605791
prompt, image_urls, extra_user_content_parts
606792
)
607-
context_query = self._ensure_message_to_dicts(contexts)
793+
context_query = copy.deepcopy(self._ensure_message_to_dicts(contexts))
608794
if new_record:
609795
context_query.append(new_record)
610796
if system_prompt:
@@ -622,6 +808,9 @@ async def _prepare_chat_payload(
622808
for tcr in tool_calls_result:
623809
context_query.extend(tcr.to_openai_messages())
624810

811+
if self._context_contains_image(context_query):
812+
context_query = await self._materialize_context_image_parts(context_query)
813+
625814
model = model or self.get_model()
626815
payloads = {**kwargs, "messages": context_query, "model": model}
627816

@@ -721,6 +910,18 @@ async def _handle_api_error(
721910
"image_content_moderated",
722911
image_fallback_used=True,
723912
)
913+
if self._is_invalid_attachment_error(e):
914+
if image_fallback_used or not self._context_contains_image(context_query):
915+
raise e
916+
return await self._fallback_to_text_only_and_retry(
917+
payloads,
918+
context_query,
919+
chosen_key,
920+
available_api_keys,
921+
func_tool,
922+
"invalid_attachment",
923+
image_fallback_used=True,
924+
)
724925

725926
if (
726927
"Function calling is not enabled" in str(e)
@@ -922,23 +1123,6 @@ async def assemble_context(
9221123
) -> dict:
9231124
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""
9241125

925-
async def resolve_image_part(image_url: str) -> dict | None:
926-
if image_url.startswith("http"):
927-
image_path = await download_image_by_url(image_url)
928-
image_data = await self.encode_image_bs64(image_path)
929-
elif image_url.startswith("file:///"):
930-
image_path = image_url.replace("file:///", "")
931-
image_data = await self.encode_image_bs64(image_path)
932-
else:
933-
image_data = await self.encode_image_bs64(image_url)
934-
if not image_data:
935-
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
936-
return None
937-
return {
938-
"type": "image_url",
939-
"image_url": {"url": image_data},
940-
}
941-
9421126
# 构建内容块列表
9431127
content_blocks = []
9441128

@@ -958,7 +1142,9 @@ async def resolve_image_part(image_url: str) -> dict | None:
9581142
if isinstance(part, TextPart):
9591143
content_blocks.append({"type": "text", "text": part.text})
9601144
elif isinstance(part, ImageURLPart):
961-
image_part = await resolve_image_part(part.image_url.url)
1145+
image_part = await self._resolve_image_part(
1146+
part.image_url.url,
1147+
)
9621148
if image_part:
9631149
content_blocks.append(image_part)
9641150
else:
@@ -967,7 +1153,7 @@ async def resolve_image_part(image_url: str) -> dict | None:
9671153
# 3. 图片内容
9681154
if image_urls:
9691155
for image_url in image_urls:
970-
image_part = await resolve_image_part(image_url)
1156+
image_part = await self._resolve_image_part(image_url)
9711157
if image_part:
9721158
content_blocks.append(image_part)
9731159

@@ -986,11 +1172,10 @@ async def resolve_image_part(image_url: str) -> dict | None:
9861172

9871173
async def encode_image_bs64(self, image_url: str) -> str:
9881174
"""将图片转换为 base64"""
989-
if image_url.startswith("base64://"):
990-
return image_url.replace("base64://", "data:image/jpeg;base64,")
991-
with open(image_url, "rb") as f:
992-
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
993-
return "data:image/jpeg;base64," + image_bs64
1175+
image_data = await self._image_ref_to_data_url(image_url, mode="strict")
1176+
if image_data is None:
1177+
raise RuntimeError(f"Failed to encode image data: {image_url}")
1178+
return image_data
9941179

9951180
async def terminate(self):
9961181
if self.client:

0 commit comments

Comments
 (0)