-
-
Notifications
You must be signed in to change notification settings - Fork 2.2k
fix: harden OpenAI attachment recovery #7004
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
c469166
0762ff0
07ca792
69a4231
3df5603
d21b35f
3280a1d
950a214
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,11 +1,15 @@ | ||
| import asyncio | ||
| import base64 | ||
| import copy | ||
| import inspect | ||
| import json | ||
| import random | ||
| import re | ||
| from collections.abc import AsyncGenerator | ||
| from io import BytesIO | ||
| from pathlib import Path | ||
| from typing import Any | ||
| from urllib.parse import unquote, urlparse | ||
|
|
||
| import httpx | ||
| from openai import AsyncAzureOpenAI, AsyncOpenAI | ||
|
|
@@ -14,6 +18,8 @@ | |
| from openai.types.chat.chat_completion import ChatCompletion | ||
| from openai.types.chat.chat_completion_chunk import ChatCompletionChunk | ||
| from openai.types.completion_usage import CompletionUsage | ||
| from PIL import Image as PILImage | ||
| from PIL import UnidentifiedImageError | ||
|
|
||
| import astrbot.core.message.components as Comp | ||
| from astrbot import logger | ||
|
|
@@ -133,6 +139,191 @@ def _context_contains_image(contexts: list[dict]) -> bool: | |
| return True | ||
| return False | ||
|
|
||
| @staticmethod | ||
| def _get_error_info(error: Exception) -> tuple[str | None, str | None]: | ||
| body = getattr(error, "body", None) | ||
| if not isinstance(body, dict): | ||
| return None, None | ||
|
|
||
| err_obj = body.get("error") | ||
| if not isinstance(err_obj, dict): | ||
| return None, None | ||
|
|
||
| code = err_obj.get("code") | ||
| message = err_obj.get("message") | ||
| return ( | ||
| code.lower() if isinstance(code, str) else None, | ||
| message.lower() if isinstance(message, str) else None, | ||
| ) | ||
|
|
||
def _is_invalid_attachment_error(self, error: Exception) -> bool:
    """Tell whether ``error`` reports an invalid or unreachable attachment.

    The error counts as an invalid-attachment error when any lower-cased
    text candidate (from ``_extract_error_text_candidates`` plus the
    structured error message, if present) contains ``invalid_attachment``,
    or contains both ``download attachment`` and ``404``; or when the
    structured error code equals ``invalid_attachment``.
    """
    code, message = self._get_error_info(error)

    texts = [text.lower() for text in self._extract_error_text_candidates(error)]
    if message:
        # The structured message is already lower-cased by _get_error_info.
        texts.append(message)

    matched_text = any(
        "invalid_attachment" in text
        or ("download attachment" in text and "404" in text)
        for text in texts
    )
    if matched_text:
        return True
    return code == "invalid_attachment"
|
|
||
| @staticmethod | ||
| def _image_format_to_mime_type(image_format: str | None) -> str: | ||
| return { | ||
| "JPEG": "image/jpeg", | ||
| "PNG": "image/png", | ||
| "GIF": "image/gif", | ||
| "WEBP": "image/webp", | ||
| "BMP": "image/bmp", | ||
| }.get(str(image_format or "").upper(), "image/jpeg") | ||
|
|
||
| @staticmethod | ||
| def _read_file_bytes( | ||
| image_path: str, | ||
| *, | ||
| suppress_errors: bool = True, | ||
| ) -> bytes | None: | ||
| try: | ||
| return Path(image_path).read_bytes() | ||
| except OSError: | ||
| if not suppress_errors: | ||
| raise | ||
| return None | ||
|
|
||
@staticmethod
def _detect_image_format(image_bytes: bytes) -> str | None:
    """Return the upper-cased Pillow format name of ``image_bytes``.

    The bytes are opened with Pillow and passed through ``verify()``.

    Returns:
        The format name (for example ``"PNG"``), or ``None`` when the data
        cannot be identified, fails verification, or reports no format.
    """
    try:
        with PILImage.open(BytesIO(image_bytes)) as image:
            image.verify()
            detected = str(image.format or "").upper()
    except (
        # UnidentifiedImageError is an OSError subclass, so it is covered.
        OSError,
        # verify() may raise plugin-specific errors, e.g. SyntaxError for a
        # PNG with a bad checksum; treat them as "not a valid image" too.
        SyntaxError,
        ValueError,
        PILImage.DecompressionBombError,
    ):
        return None
    # An empty format string carries no information; report it as unknown.
    return detected or None
|
|
||
@classmethod
def _encode_image_file_to_data_url(cls, image_path: str) -> str | None:
    """Encode the image file at ``image_path`` as a base64 ``data:`` URL.

    Returns:
        ``data:<mime>;base64,<payload>``, or ``None`` when the file cannot
        be read or its image format cannot be detected.
    """
    raw = cls._read_file_bytes(image_path)
    # Only attempt format detection when the file was actually read.
    detected = None if raw is None else cls._detect_image_format(raw)
    if raw is None or detected is None:
        return None

    mime_type = cls._image_format_to_mime_type(detected)
    encoded = base64.b64encode(raw).decode("utf-8")
    return f"data:{mime_type};base64,{encoded}"
|
|
||
|
Comment on lines
+195
to
+203
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The suggestion: `return PILImage.MIME.get(image_format, "image/jpeg")` |
||
| @staticmethod | ||
| def _file_uri_to_path(file_uri: str) -> str: | ||
| parsed = urlparse(file_uri) | ||
| if parsed.scheme != "file": | ||
| return file_uri | ||
|
|
||
| netloc = unquote(parsed.netloc or "") | ||
| path = unquote(parsed.path or "") | ||
| if re.fullmatch(r"[A-Za-z]:", netloc): | ||
| return str(Path(f"{netloc}{path}")) | ||
| if re.match(r"^/[A-Za-z]:/", path): | ||
| path = path[1:] | ||
| if netloc and netloc != "localhost": | ||
| path = f"//{netloc}{path}" | ||
| return str(Path(path)) | ||
|
|
||
async def _load_image_data(self, image_url: str) -> str | None:
    """Load an image reference and return it as a ``data:`` URL.

    ``base64://`` references are delegated to ``encode_image_bs64``.
    References starting with ``http`` are downloaded first, ``file://``
    URIs are converted to a local path, and anything else is used as a
    path as-is; the resulting file is then encoded via
    ``_encode_image_file_to_data_url``.
    """
    if image_url.startswith("base64://"):
        return await self.encode_image_bs64(image_url)

    # Default: treat the reference itself as a local path.
    local_path = image_url
    if image_url.startswith("http"):
        local_path = await download_image_by_url(image_url)
    elif image_url.startswith("file://"):
        local_path = self._file_uri_to_path(image_url)

    return self._encode_image_file_to_data_url(local_path)
|
|
||
async def _resolve_image_part(
    self,
    image_url: str,
    *,
    image_detail: str | None = None,
) -> dict | None:
    """Build an ``image_url`` content part for ``image_url``.

    ``data:`` URLs are used unchanged; any other reference is loaded through
    ``_load_image_data``. When loading yields an empty result a warning is
    logged and ``None`` is returned.

    Args:
        image_url: Image reference to resolve.
        image_detail: Optional value stored under the ``"detail"`` key of
            the image payload when truthy.

    Returns:
        ``{"type": "image_url", "image_url": {...}}`` or ``None``.
    """
    resolved_url = image_url
    if not image_url.startswith("data:"):
        resolved_url = await self._load_image_data(image_url)
        if not resolved_url:
            logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
            return None

    image_payload = {"url": resolved_url}
    if image_detail:
        image_payload["detail"] = image_detail
    return {"type": "image_url", "image_url": image_payload}
|
|
||
async def _materialize_context_image_parts(self, context_query: list[dict]) -> None:
    """Resolve ``image_url`` parts of list-valued message contents in place.

    For every message whose ``"content"`` is a list, each dict part with
    ``type == "image_url"`` is passed through ``_resolve_image_part``. A part
    is kept as-is when it is malformed, when resolution raises, or when
    resolution returns ``None``. ``message["content"]`` is only reassigned
    when at least one resolved part differs from the original part.
    """
    for message in context_query:
        parts = message.get("content")
        if not isinstance(parts, list):
            continue

        rebuilt: list[dict] = []
        replaced_any = False
        for part in parts:
            is_image_part = isinstance(part, dict) and part.get("type") == "image_url"
            if not is_image_part:
                rebuilt.append(part)
                continue

            spec = part.get("image_url")
            if not isinstance(spec, dict):
                logger.warning("图片内容块格式无效,将保留原始内容。")
                rebuilt.append(part)
                continue

            url = spec.get("url")
            if not (isinstance(url, str) and url):
                logger.warning("图片内容块缺少有效 URL,将保留原始内容。")
                rebuilt.append(part)
                continue

            # Only a string detail is forwarded; anything else is dropped.
            detail = spec.get("detail")
            detail = detail if isinstance(detail, str) else None

            try:
                resolved = await self._resolve_image_part(url, image_detail=detail)
            except Exception as exc:
                logger.warning(
                    "图片 %s 预处理失败,将保留原始内容。错误: %s",
                    url,
                    exc,
                )
                rebuilt.append(part)
                continue

            if resolved is None:
                rebuilt.append(part)
                continue

            rebuilt.append(resolved)
            if resolved != part:
                replaced_any = True

        if replaced_any:
            message["content"] = rebuilt
|
|
||
| async def _fallback_to_text_only_and_retry( | ||
| self, | ||
| payloads: dict, | ||
|
|
@@ -594,7 +785,7 @@ async def _prepare_chat_payload( | |
| new_record = await self.assemble_context( | ||
| prompt, image_urls, extra_user_content_parts | ||
| ) | ||
| context_query = self._ensure_message_to_dicts(contexts) | ||
| context_query = copy.deepcopy(self._ensure_message_to_dicts(contexts)) | ||
| if new_record: | ||
| context_query.append(new_record) | ||
| if system_prompt: | ||
|
|
@@ -612,6 +803,8 @@ async def _prepare_chat_payload( | |
| for tcr in tool_calls_result: | ||
| context_query.extend(tcr.to_openai_messages()) | ||
|
|
||
| await self._materialize_context_image_parts(context_query) | ||
|
|
||
| model = model or self.get_model() | ||
|
|
||
| payloads = {"messages": context_query, "model": model} | ||
|
|
@@ -712,6 +905,18 @@ async def _handle_api_error( | |
| "image_content_moderated", | ||
| image_fallback_used=True, | ||
| ) | ||
| if self._is_invalid_attachment_error(e): | ||
| if image_fallback_used or not self._context_contains_image(context_query): | ||
| raise e | ||
| return await self._fallback_to_text_only_and_retry( | ||
| payloads, | ||
| context_query, | ||
| chosen_key, | ||
| available_api_keys, | ||
| func_tool, | ||
| "invalid_attachment", | ||
| image_fallback_used=True, | ||
| ) | ||
|
|
||
| if ( | ||
| "Function calling is not enabled" in str(e) | ||
|
|
@@ -913,23 +1118,6 @@ async def assemble_context( | |
| ) -> dict: | ||
| """组装成符合 OpenAI 格式的 role 为 user 的消息段""" | ||
|
|
||
| async def resolve_image_part(image_url: str) -> dict | None: | ||
| if image_url.startswith("http"): | ||
| image_path = await download_image_by_url(image_url) | ||
| image_data = await self.encode_image_bs64(image_path) | ||
| elif image_url.startswith("file:///"): | ||
| image_path = image_url.replace("file:///", "") | ||
| image_data = await self.encode_image_bs64(image_path) | ||
| else: | ||
| image_data = await self.encode_image_bs64(image_url) | ||
| if not image_data: | ||
| logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") | ||
| return None | ||
| return { | ||
| "type": "image_url", | ||
| "image_url": {"url": image_data}, | ||
| } | ||
|
|
||
| # 构建内容块列表 | ||
| content_blocks = [] | ||
|
|
||
|
|
@@ -949,7 +1137,9 @@ async def resolve_image_part(image_url: str) -> dict | None: | |
| if isinstance(part, TextPart): | ||
| content_blocks.append({"type": "text", "text": part.text}) | ||
| elif isinstance(part, ImageURLPart): | ||
| image_part = await resolve_image_part(part.image_url.url) | ||
| image_part = await self._resolve_image_part( | ||
| part.image_url.url, | ||
| ) | ||
| if image_part: | ||
| content_blocks.append(image_part) | ||
| else: | ||
|
|
@@ -958,7 +1148,7 @@ async def resolve_image_part(image_url: str) -> dict | None: | |
| # 3. 图片内容 | ||
| if image_urls: | ||
| for image_url in image_urls: | ||
| image_part = await resolve_image_part(image_url) | ||
| image_part = await self._resolve_image_part(image_url) | ||
| if image_part: | ||
| content_blocks.append(image_part) | ||
|
|
||
|
|
@@ -979,9 +1169,16 @@ async def encode_image_bs64(self, image_url: str) -> str: | |
| """将图片转换为 base64""" | ||
| if image_url.startswith("base64://"): | ||
| return image_url.replace("base64://", "data:image/jpeg;base64,") | ||
| with open(image_url, "rb") as f: | ||
| image_bs64 = base64.b64encode(f.read()).decode("utf-8") | ||
| return "data:image/jpeg;base64," + image_bs64 | ||
| image_bytes = self._read_file_bytes(image_url, suppress_errors=False) | ||
| if image_bytes is None: | ||
| raise FileNotFoundError(image_url) | ||
|
|
||
|
sourcery-ai[bot] marked this conversation as resolved.
Outdated
|
||
| image_format = self._detect_image_format(image_bytes) | ||
| if image_format is None: | ||
| raise ValueError(f"Invalid image file: {image_url}") | ||
| mime_type = self._image_format_to_mime_type(image_format) | ||
| image_bs64 = base64.b64encode(image_bytes).decode("utf-8") | ||
| return f"data:{mime_type};base64,{image_bs64}" | ||
|
|
||
| async def terminate(self): | ||
| if self.client: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.