-
-
Notifications
You must be signed in to change notification settings - Fork 2.2k
fix: harden OpenAI attachment recovery #7004
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
c469166
0762ff0
07ca792
69a4231
3df5603
d21b35f
3280a1d
950a214
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,11 +1,15 @@ | ||
| import asyncio | ||
| import base64 | ||
| import copy | ||
| import inspect | ||
| import json | ||
| import random | ||
| import re | ||
| from collections.abc import AsyncGenerator | ||
| from typing import Any | ||
| from io import BytesIO | ||
| from pathlib import Path | ||
| from typing import Any, Literal | ||
| from urllib.parse import unquote, urlparse | ||
|
|
||
| import httpx | ||
| from openai import AsyncAzureOpenAI, AsyncOpenAI | ||
|
|
@@ -14,6 +18,8 @@ | |
| from openai.types.chat.chat_completion import ChatCompletion | ||
| from openai.types.chat.chat_completion_chunk import ChatCompletionChunk | ||
| from openai.types.completion_usage import CompletionUsage | ||
| from PIL import Image as PILImage | ||
| from PIL import UnidentifiedImageError | ||
|
|
||
| import astrbot.core.message.components as Comp | ||
| from astrbot import logger | ||
|
|
@@ -133,6 +139,189 @@ def _context_contains_image(contexts: list[dict]) -> bool: | |
| return True | ||
| return False | ||
|
|
||
| def _is_invalid_attachment_error(self, error: Exception) -> bool: | ||
|
sourcery-ai[bot] marked this conversation as resolved.
sourcery-ai[bot] marked this conversation as resolved.
|
||
| body = getattr(error, "body", None) | ||
| code: str | None = None | ||
| message: str | None = None | ||
| if isinstance(body, dict): | ||
| err_obj = body.get("error") | ||
| if isinstance(err_obj, dict): | ||
| raw_code = err_obj.get("code") | ||
| raw_message = err_obj.get("message") | ||
| code = raw_code.lower() if isinstance(raw_code, str) else None | ||
| message = raw_message.lower() if isinstance(raw_message, str) else None | ||
|
|
||
| parts: list[str] = [] | ||
| if code: | ||
| parts.append(code) | ||
| if message: | ||
| parts.append(message) | ||
| parts.extend(map(str, self._extract_error_text_candidates(error))) | ||
|
|
||
| error_text = " ".join(part.lower() for part in parts if part) | ||
| if "invalid_attachment" in error_text: | ||
| return True | ||
| if "download attachment" in error_text and "404" in error_text: | ||
| return True | ||
| return code == "invalid_attachment" | ||
|
|
||
| @classmethod | ||
|
sourcery-ai[bot] marked this conversation as resolved.
|
||
| def _encode_image_file_to_data_url( | ||
| cls, | ||
| image_path: str, | ||
| *, | ||
| mode: Literal["safe", "strict"], | ||
| ) -> str | None: | ||
| try: | ||
| image_bytes = Path(image_path).read_bytes() | ||
| except OSError: | ||
| if mode == "strict": | ||
| raise | ||
| return None | ||
|
|
||
| try: | ||
| with PILImage.open(BytesIO(image_bytes)) as image: | ||
| image.verify() | ||
| image_format = str(image.format or "").upper() | ||
| except (OSError, UnidentifiedImageError): | ||
| if mode == "strict": | ||
| raise ValueError(f"Invalid image file: {image_path}") | ||
| return None | ||
|
|
||
| mime_type = { | ||
| "JPEG": "image/jpeg", | ||
| "PNG": "image/png", | ||
| "GIF": "image/gif", | ||
| "WEBP": "image/webp", | ||
| "BMP": "image/bmp", | ||
| }.get(image_format, "image/jpeg") | ||
| image_bs64 = base64.b64encode(image_bytes).decode("utf-8") | ||
| return f"data:{mime_type};base64,{image_bs64}" | ||
|
|
||
|
Comment on lines
+195
to
+203
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The return PILImage.MIME.get(image_format, "image/jpeg") |
||
| @staticmethod | ||
| def _file_uri_to_path(file_uri: str) -> str: | ||
| """Normalize file URIs to paths. | ||
|
|
||
| `file://localhost/...` and drive-letter forms are treated as local paths. | ||
| Other non-empty hosts are preserved as UNC-style paths. | ||
| """ | ||
| parsed = urlparse(file_uri) | ||
| if parsed.scheme != "file": | ||
| return file_uri | ||
|
|
||
| netloc = unquote(parsed.netloc or "") | ||
| path = unquote(parsed.path or "") | ||
| if re.fullmatch(r"[A-Za-z]:", netloc): | ||
| return str(Path(f"{netloc}{path}")) | ||
| if re.match(r"^/[A-Za-z]:/", path): | ||
| path = path[1:] | ||
| if netloc and netloc != "localhost": | ||
| path = f"//{netloc}{path}" | ||
| return str(Path(path)) | ||
|
|
||
| async def _image_ref_to_data_url( | ||
| self, | ||
| image_ref: str, | ||
| *, | ||
| mode: Literal["safe", "strict"] = "safe", | ||
| ) -> str | None: | ||
| if image_ref.startswith("base64://"): | ||
| return image_ref.replace("base64://", "data:image/jpeg;base64,") | ||
|
|
||
| if image_ref.startswith("http"): | ||
| image_path = await download_image_by_url(image_ref) | ||
| elif image_ref.startswith("file://"): | ||
| image_path = self._file_uri_to_path(image_ref) | ||
| else: | ||
| image_path = image_ref | ||
|
|
||
| return self._encode_image_file_to_data_url( | ||
| image_path, | ||
| mode=mode, | ||
| ) | ||
|
|
||
| async def _resolve_image_part( | ||
| self, | ||
| image_url: str, | ||
| *, | ||
| image_detail: str | None = None, | ||
| ) -> dict | None: | ||
| if image_url.startswith("data:"): | ||
| image_payload = {"url": image_url} | ||
| else: | ||
| image_data = await self._image_ref_to_data_url(image_url, mode="safe") | ||
| if not image_data: | ||
| logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") | ||
| return None | ||
| image_payload = {"url": image_data} | ||
|
|
||
| if image_detail: | ||
| image_payload["detail"] = image_detail | ||
| return { | ||
| "type": "image_url", | ||
| "image_url": image_payload, | ||
| } | ||
|
|
||
| def _extract_image_part_info(self, part: dict) -> tuple[str | None, str | None]: | ||
| if not isinstance(part, dict) or part.get("type") != "image_url": | ||
| return None, None | ||
|
|
||
| image_url_data = part.get("image_url") | ||
| if not isinstance(image_url_data, dict): | ||
| logger.warning("图片内容块格式无效,将保留原始内容。") | ||
| return None, None | ||
|
|
||
| url = image_url_data.get("url") | ||
| if not isinstance(url, str) or not url: | ||
| logger.warning("图片内容块缺少有效 URL,将保留原始内容。") | ||
| return None, None | ||
|
|
||
| image_detail = image_url_data.get("detail") | ||
| if not isinstance(image_detail, str): | ||
| image_detail = None | ||
| return url, image_detail | ||
|
|
||
| async def _materialize_context_image_parts( | ||
| self, context_query: list[dict] | ||
| ) -> list[dict]: | ||
| new_messages: list[dict] = [] | ||
| for message in context_query: | ||
| content = message.get("content") | ||
| if not isinstance(content, list): | ||
| new_messages.append(copy.deepcopy(message)) | ||
| continue | ||
|
|
||
| new_content: list[dict] = [] | ||
| for part in content: | ||
| url, image_detail = self._extract_image_part_info(part) | ||
| if not url: | ||
| new_content.append(copy.deepcopy(part)) | ||
| continue | ||
|
|
||
| try: | ||
| resolved_part = await self._resolve_image_part( | ||
| url, image_detail=image_detail | ||
| ) | ||
| except Exception as exc: | ||
| logger.warning( | ||
| "图片 %s 预处理失败,将保留原始内容。错误: %s", | ||
| url, | ||
| exc, | ||
| ) | ||
| new_content.append(copy.deepcopy(part)) | ||
| continue | ||
|
|
||
| if resolved_part is None: | ||
| new_content.append(copy.deepcopy(part)) | ||
| continue | ||
|
|
||
| new_content.append(resolved_part) | ||
| new_message = copy.deepcopy(message) | ||
| new_message["content"] = new_content | ||
| new_messages.append(new_message) | ||
|
|
||
| return new_messages | ||
|
|
||
| async def _fallback_to_text_only_and_retry( | ||
| self, | ||
| payloads: dict, | ||
|
|
@@ -594,7 +783,7 @@ async def _prepare_chat_payload( | |
| new_record = await self.assemble_context( | ||
| prompt, image_urls, extra_user_content_parts | ||
| ) | ||
| context_query = self._ensure_message_to_dicts(contexts) | ||
| context_query = copy.deepcopy(self._ensure_message_to_dicts(contexts)) | ||
| if new_record: | ||
| context_query.append(new_record) | ||
| if system_prompt: | ||
|
|
@@ -612,6 +801,9 @@ async def _prepare_chat_payload( | |
| for tcr in tool_calls_result: | ||
| context_query.extend(tcr.to_openai_messages()) | ||
|
|
||
| if self._context_contains_image(context_query): | ||
| context_query = await self._materialize_context_image_parts(context_query) | ||
|
|
||
| model = model or self.get_model() | ||
|
|
||
| payloads = {"messages": context_query, "model": model} | ||
|
|
@@ -712,6 +904,18 @@ async def _handle_api_error( | |
| "image_content_moderated", | ||
| image_fallback_used=True, | ||
| ) | ||
| if self._is_invalid_attachment_error(e): | ||
| if image_fallback_used or not self._context_contains_image(context_query): | ||
| raise e | ||
| return await self._fallback_to_text_only_and_retry( | ||
| payloads, | ||
| context_query, | ||
| chosen_key, | ||
| available_api_keys, | ||
| func_tool, | ||
| "invalid_attachment", | ||
| image_fallback_used=True, | ||
| ) | ||
|
|
||
| if ( | ||
| "Function calling is not enabled" in str(e) | ||
|
|
@@ -913,23 +1117,6 @@ async def assemble_context( | |
| ) -> dict: | ||
| """组装成符合 OpenAI 格式的 role 为 user 的消息段""" | ||
|
|
||
| async def resolve_image_part(image_url: str) -> dict | None: | ||
| if image_url.startswith("http"): | ||
| image_path = await download_image_by_url(image_url) | ||
| image_data = await self.encode_image_bs64(image_path) | ||
| elif image_url.startswith("file:///"): | ||
| image_path = image_url.replace("file:///", "") | ||
| image_data = await self.encode_image_bs64(image_path) | ||
| else: | ||
| image_data = await self.encode_image_bs64(image_url) | ||
| if not image_data: | ||
| logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") | ||
| return None | ||
| return { | ||
| "type": "image_url", | ||
| "image_url": {"url": image_data}, | ||
| } | ||
|
|
||
| # 构建内容块列表 | ||
| content_blocks = [] | ||
|
|
||
|
|
@@ -949,7 +1136,9 @@ async def resolve_image_part(image_url: str) -> dict | None: | |
| if isinstance(part, TextPart): | ||
| content_blocks.append({"type": "text", "text": part.text}) | ||
| elif isinstance(part, ImageURLPart): | ||
| image_part = await resolve_image_part(part.image_url.url) | ||
| image_part = await self._resolve_image_part( | ||
| part.image_url.url, | ||
| ) | ||
| if image_part: | ||
| content_blocks.append(image_part) | ||
| else: | ||
|
|
@@ -958,7 +1147,7 @@ async def resolve_image_part(image_url: str) -> dict | None: | |
| # 3. 图片内容 | ||
| if image_urls: | ||
| for image_url in image_urls: | ||
| image_part = await resolve_image_part(image_url) | ||
| image_part = await self._resolve_image_part(image_url) | ||
| if image_part: | ||
| content_blocks.append(image_part) | ||
|
|
||
|
|
@@ -977,11 +1166,10 @@ async def resolve_image_part(image_url: str) -> dict | None: | |
|
|
||
| async def encode_image_bs64(self, image_url: str) -> str: | ||
| """将图片转换为 base64""" | ||
| if image_url.startswith("base64://"): | ||
| return image_url.replace("base64://", "data:image/jpeg;base64,") | ||
| with open(image_url, "rb") as f: | ||
| image_bs64 = base64.b64encode(f.read()).decode("utf-8") | ||
| return "data:image/jpeg;base64," + image_bs64 | ||
| image_data = await self._image_ref_to_data_url(image_url, mode="strict") | ||
| if image_data is None: | ||
| raise RuntimeError(f"Failed to encode image data: {image_url}") | ||
| return image_data | ||
|
|
||
| async def terminate(self): | ||
| if self.client: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.