Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 220 additions & 23 deletions astrbot/core/provider/sources/openai_source.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import asyncio
import base64
import copy
import inspect
import json
import random
import re
from collections.abc import AsyncGenerator
from io import BytesIO
from pathlib import Path
from typing import Any
from urllib.parse import unquote, urlparse

import httpx
from openai import AsyncAzureOpenAI, AsyncOpenAI
Expand All @@ -14,6 +18,8 @@
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.completion_usage import CompletionUsage
from PIL import Image as PILImage
from PIL import UnidentifiedImageError

import astrbot.core.message.components as Comp
from astrbot import logger
Expand Down Expand Up @@ -133,6 +139,191 @@ def _context_contains_image(contexts: list[dict]) -> bool:
return True
return False

@staticmethod
def _get_error_info(error: Exception) -> tuple[str | None, str | None]:
body = getattr(error, "body", None)
if not isinstance(body, dict):
return None, None

err_obj = body.get("error")
if not isinstance(err_obj, dict):
return None, None

code = err_obj.get("code")
message = err_obj.get("message")
return (
code.lower() if isinstance(code, str) else None,
message.lower() if isinstance(message, str) else None,
)

def _is_invalid_attachment_error(self, error: Exception) -> bool:
code, message = self._get_error_info(error)

candidates = [
candidate.lower()
for candidate in self._extract_error_text_candidates(error)
]
if message:
candidates.append(message)

for candidate in candidates:
if "invalid_attachment" in candidate:
return True
if "download attachment" in candidate and "404" in candidate:
return True
return code == "invalid_attachment"

@staticmethod
def _image_format_to_mime_type(image_format: str | None) -> str:
return {
"JPEG": "image/jpeg",
"PNG": "image/png",
"GIF": "image/gif",
"WEBP": "image/webp",
"BMP": "image/bmp",
}.get(str(image_format or "").upper(), "image/jpeg")

@staticmethod
def _read_file_bytes(
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Outdated
image_path: str,
*,
suppress_errors: bool = True,
) -> bytes | None:
try:
return Path(image_path).read_bytes()
except OSError:
if not suppress_errors:
raise
return None

@staticmethod
def _detect_image_format(image_bytes: bytes) -> str | None:
try:
with PILImage.open(BytesIO(image_bytes)) as image:
image.verify()
return str(image.format or "").upper()
except (OSError, UnidentifiedImageError):
return None

@classmethod
def _encode_image_file_to_data_url(cls, image_path: str) -> str | None:
image_bytes = cls._read_file_bytes(image_path)
if image_bytes is None:
return None

image_format = cls._detect_image_format(image_bytes)
if image_format is None:
return None

mime_type = cls._image_format_to_mime_type(image_format)
image_bs64 = base64.b64encode(image_bytes).decode("utf-8")
return f"data:{mime_type};base64,{image_bs64}"

Comment on lines +195 to +203
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The _detect_image_mime_type method uses a limited dictionary to map image formats to MIME types and defaults to image/jpeg for unrecognized formats. While _is_valid_image_file ensures the file is an image, using a more comprehensive approach like PILImage.MIME would provide more accurate MIME types for a wider range of image formats, improving correctness and compatibility with APIs that might be strict about MIME types. For example, a TIFF image would currently be incorrectly identified as image/jpeg.

        return PILImage.MIME.get(image_format, "image/jpeg")

@staticmethod
def _file_uri_to_path(file_uri: str) -> str:
parsed = urlparse(file_uri)
if parsed.scheme != "file":
return file_uri

netloc = unquote(parsed.netloc or "")
path = unquote(parsed.path or "")
if re.fullmatch(r"[A-Za-z]:", netloc):
return str(Path(f"{netloc}{path}"))
if re.match(r"^/[A-Za-z]:/", path):
path = path[1:]
if netloc and netloc != "localhost":
path = f"//{netloc}{path}"
return str(Path(path))

async def _load_image_data(self, image_url: str) -> str | None:
if image_url.startswith("base64://"):
return await self.encode_image_bs64(image_url)

if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
elif image_url.startswith("file://"):
image_path = self._file_uri_to_path(image_url)
else:
image_path = image_url

return self._encode_image_file_to_data_url(image_path)

async def _resolve_image_part(
self,
image_url: str,
*,
image_detail: str | None = None,
) -> dict | None:
if image_url.startswith("data:"):
image_payload = {"url": image_url}
else:
image_data = await self._load_image_data(image_url)
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
return None
image_payload = {"url": image_data}

if image_detail:
image_payload["detail"] = image_detail
return {
"type": "image_url",
"image_url": image_payload,
}

async def _materialize_context_image_parts(self, context_query: list[dict]) -> None:
for message in context_query:
content = message.get("content")
if not isinstance(content, list):
continue

new_content: list[dict] = []
content_changed = False
for part in content:
if not isinstance(part, dict) or part.get("type") != "image_url":
new_content.append(part)
continue

image_url_data = part.get("image_url")
if not isinstance(image_url_data, dict):
logger.warning("图片内容块格式无效,将保留原始内容。")
new_content.append(part)
continue

url = image_url_data.get("url")
if not isinstance(url, str) or not url:
logger.warning("图片内容块缺少有效 URL,将保留原始内容。")
new_content.append(part)
continue

image_detail = image_url_data.get("detail")
if not isinstance(image_detail, str):
image_detail = None

try:
resolved_part = await self._resolve_image_part(
url, image_detail=image_detail
)
except Exception as exc:
logger.warning(
"图片 %s 预处理失败,将保留原始内容。错误: %s",
url,
exc,
)
new_content.append(part)
continue

if resolved_part is None:
new_content.append(part)
continue

new_content.append(resolved_part)
if resolved_part != part:
content_changed = True

if not content_changed:
continue
message["content"] = new_content

async def _fallback_to_text_only_and_retry(
self,
payloads: dict,
Expand Down Expand Up @@ -594,7 +785,7 @@ async def _prepare_chat_payload(
new_record = await self.assemble_context(
prompt, image_urls, extra_user_content_parts
)
context_query = self._ensure_message_to_dicts(contexts)
context_query = copy.deepcopy(self._ensure_message_to_dicts(contexts))
if new_record:
context_query.append(new_record)
if system_prompt:
Expand All @@ -612,6 +803,8 @@ async def _prepare_chat_payload(
for tcr in tool_calls_result:
context_query.extend(tcr.to_openai_messages())

await self._materialize_context_image_parts(context_query)

model = model or self.get_model()

payloads = {"messages": context_query, "model": model}
Expand Down Expand Up @@ -712,6 +905,18 @@ async def _handle_api_error(
"image_content_moderated",
image_fallback_used=True,
)
if self._is_invalid_attachment_error(e):
if image_fallback_used or not self._context_contains_image(context_query):
raise e
return await self._fallback_to_text_only_and_retry(
payloads,
context_query,
chosen_key,
available_api_keys,
func_tool,
"invalid_attachment",
image_fallback_used=True,
)

if (
"Function calling is not enabled" in str(e)
Expand Down Expand Up @@ -913,23 +1118,6 @@ async def assemble_context(
) -> dict:
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""

async def resolve_image_part(image_url: str) -> dict | None:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
elif image_url.startswith("file:///"):
image_path = image_url.replace("file:///", "")
image_data = await self.encode_image_bs64(image_path)
else:
image_data = await self.encode_image_bs64(image_url)
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
return None
return {
"type": "image_url",
"image_url": {"url": image_data},
}

# 构建内容块列表
content_blocks = []

Expand All @@ -949,7 +1137,9 @@ async def resolve_image_part(image_url: str) -> dict | None:
if isinstance(part, TextPart):
content_blocks.append({"type": "text", "text": part.text})
elif isinstance(part, ImageURLPart):
image_part = await resolve_image_part(part.image_url.url)
image_part = await self._resolve_image_part(
part.image_url.url,
)
if image_part:
content_blocks.append(image_part)
else:
Expand All @@ -958,7 +1148,7 @@ async def resolve_image_part(image_url: str) -> dict | None:
# 3. 图片内容
if image_urls:
for image_url in image_urls:
image_part = await resolve_image_part(image_url)
image_part = await self._resolve_image_part(image_url)
if image_part:
content_blocks.append(image_part)

Expand All @@ -979,9 +1169,16 @@ async def encode_image_bs64(self, image_url: str) -> str:
"""将图片转换为 base64"""
if image_url.startswith("base64://"):
return image_url.replace("base64://", "data:image/jpeg;base64,")
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
return "data:image/jpeg;base64," + image_bs64
image_bytes = self._read_file_bytes(image_url, suppress_errors=False)
if image_bytes is None:
raise FileNotFoundError(image_url)

Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Outdated
image_format = self._detect_image_format(image_bytes)
if image_format is None:
raise ValueError(f"Invalid image file: {image_url}")
mime_type = self._image_format_to_mime_type(image_format)
image_bs64 = base64.b64encode(image_bytes).decode("utf-8")
return f"data:{mime_type};base64,{image_bs64}"

async def terminate(self):
if self.client:
Expand Down
Loading
Loading