Skip to content
Closed
33 changes: 13 additions & 20 deletions astrbot/core/provider/entities.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import base64
import enum
import json
from dataclasses import dataclass, field
Expand All @@ -11,7 +10,6 @@
from openai.types.chat.chat_completion import ChatCompletion

import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.core.agent.message import (
AssistantMessageSegment,
ContentPart,
Expand All @@ -21,7 +19,11 @@
from astrbot.core.agent.tool import ToolSet
from astrbot.core.db.po import Conversation
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.utils.io import (
download_image_by_url,
image_source_to_data_uri,
is_http_or_https_url,
)


class ProviderType(enum.Enum):
Expand Down Expand Up @@ -187,17 +189,12 @@ async def assemble_context(self) -> dict:
# 3. 图片内容
if self.image_urls:
for image_url in self.image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self._encode_image_bs64(image_path)
elif image_url.startswith("file:///"):
image_path = image_url.replace("file:///", "")
image_data = await self._encode_image_bs64(image_path)
else:
image_data = await self._encode_image_bs64(image_url)
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
image_source = (
await download_image_by_url(image_url)
if is_http_or_https_url(image_url)
else image_url
)
image_data = await self._encode_image_bs64(image_source)
content_blocks.append(
{"type": "image_url", "image_url": {"url": image_data}},
)
Expand All @@ -216,12 +213,8 @@ async def assemble_context(self) -> dict:

async def _encode_image_bs64(self, image_url: str) -> str:
"""将图片转换为 base64"""
if image_url.startswith("base64://"):
return image_url.replace("base64://", "data:image/jpeg;base64,")
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
return "data:image/jpeg;base64," + image_bs64
return ""
data_uri, _ = image_source_to_data_uri(image_url)
return data_uri


@dataclass
Expand Down
31 changes: 13 additions & 18 deletions astrbot/core/provider/sources/openai_source.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import base64
import inspect
import json
import random
Expand All @@ -22,7 +21,11 @@
from astrbot.core.agent.tool import ToolSet
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse, TokenUsage, ToolCallsResult
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.utils.io import (
download_image_by_url,
image_source_to_data_uri,
is_http_or_https_url,
)
from astrbot.core.utils.network_utils import (
create_proxy_client,
is_connection_error,
Expand Down Expand Up @@ -924,17 +927,12 @@ async def assemble_context(
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""

async def resolve_image_part(image_url: str) -> dict | None:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
elif image_url.startswith("file:///"):
image_path = image_url.replace("file:///", "")
image_data = await self.encode_image_bs64(image_path)
else:
image_data = await self.encode_image_bs64(image_url)
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
return None
image_source = (
await download_image_by_url(image_url)
if is_http_or_https_url(image_url)
else image_url
)
image_data = await self.encode_image_bs64(image_source)
return {
"type": "image_url",
"image_url": {"url": image_data},
Expand Down Expand Up @@ -987,11 +985,8 @@ async def resolve_image_part(image_url: str) -> dict | None:

async def encode_image_bs64(self, image_url: str) -> str:
"""将图片转换为 base64"""
if image_url.startswith("base64://"):
return image_url.replace("base64://", "data:image/jpeg;base64,")
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
return "data:image/jpeg;base64," + image_bs64
data_uri, _ = image_source_to_data_uri(image_url)
return data_uri

async def terminate(self):
if self.client:
Expand Down
76 changes: 76 additions & 0 deletions astrbot/core/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import uuid
import zipfile
from pathlib import Path
from urllib.parse import unquote, urlsplit
from urllib.request import url2pathname

import aiohttp
import certifi
Expand Down Expand Up @@ -206,6 +208,80 @@ def file_to_base64(file_path: str) -> str:
return "base64://" + base64_str


DEFAULT_IMAGE_MIME_TYPE = "image/jpeg"


def is_http_or_https_url(source: str) -> bool:
    """Tell whether ``source`` carries an ``http`` or ``https`` URL scheme.

    The scheme is taken from ``urlsplit`` and compared case-insensitively,
    so ``HTTPS://...`` counts as a remote URL as well.
    """
    remote_schemes = ("http", "https")
    scheme = urlsplit(source).scheme
    return scheme.lower() in remote_schemes


def detect_image_mime_type(data: bytes) -> str:
    """Guess an image MIME type from the magic bytes at the start of ``data``.

    PNG, JPEG, GIF (87a/89a) and WEBP signatures are recognised.  Anything
    else, including empty input, falls back to ``DEFAULT_IMAGE_MIME_TYPE``.
    """
    if data.startswith(b"\x89PNG\r\n\x1a\n"):
        return "image/png"
    if data.startswith(b"\xff\xd8"):
        # JPEG start-of-image marker; JPEG is also the module default.
        return DEFAULT_IMAGE_MIME_TYPE
    if data.startswith((b"GIF87a", b"GIF89a")):
        return "image/gif"
    # WEBP lives in a RIFF container: "RIFF" <4 size bytes> "WEBP".
    if data.startswith(b"RIFF") and data[8:12] == b"WEBP":
        return "image/webp"
    return DEFAULT_IMAGE_MIME_TYPE


def image_source_to_data_uri(image_source: str) -> tuple[str, str]:
    """Convert a local or inline image source into a data URI.

    Supported sources are ``data:image/...`` URIs (returned unchanged),
    ``base64://...`` payloads, ``file://...`` URIs and plain local paths.
    The ``data:``, ``base64://`` and ``file://`` prefixes are matched
    case-insensitively.  Remote ``http(s)`` URLs are rejected: the caller
    must download them to a local file before calling this helper.

    Args:
        image_source: Image location or inline payload.

    Returns:
        A ``(data_uri, mime_type)`` tuple.  For ``data:`` URIs the MIME type
        is the lower-cased media type declared in the URI.  For ``base64://``
        payloads and files it is whatever ``detect_image_mime_type`` reports
        for the image bytes.

    Raises:
        ValueError: If a ``data:`` URI does not declare an ``image/*`` media
            type, if the source is an HTTP(S) URL, or if it uses any other
            ``scheme://`` prefix.
        OSError: If the local file cannot be opened or read.
    """
    lower_source = image_source.lower()

    if lower_source.startswith("data:"):
        # Work on the lower-cased header so that "DATA:image/png" is parsed
        # the same way the case-insensitive prefix check above accepted it.
        header = lower_source.split(",", 1)[0]
        mime_type = header.split(";", 1)[0].removeprefix("data:")
        if not mime_type.startswith("image/"):
            raise ValueError(
                f"Only image data URI is supported, got MIME type: {mime_type or 'unknown'}",
            )
        return image_source, mime_type

    if is_http_or_https_url(image_source):
        raise ValueError(
            "Remote image URL is not supported in image_source_to_data_uri; download the file before calling this helper.",
        )

    base64_prefix = "base64://"
    if lower_source.startswith(base64_prefix):
        raw_base64 = image_source[len(base64_prefix):]
        try:
            image_bytes = base64.b64decode(raw_base64)
        except ValueError:
            # Catch only decode failures instead of a blanket ``Exception``:
            # ``binascii.Error`` (bad padding/alphabet) subclasses
            # ``ValueError``, which is also raised for non-ASCII input.
            # Best effort: pass the payload through with the default type.
            mime_type = DEFAULT_IMAGE_MIME_TYPE
        else:
            mime_type = detect_image_mime_type(image_bytes)
        return f"data:{mime_type};base64,{raw_base64}", mime_type

    if lower_source.startswith("file://"):
        parsed = urlsplit(image_source)
        if parsed.netloc and parsed.netloc != "localhost":
            # Keep a non-local host as a leading "//host" part of the path.
            raw_path = f"//{parsed.netloc}{parsed.path}"
        else:
            raw_path = parsed.path
        # url2pathname() already percent-decodes; decoding beforehand as well
        # would corrupt names containing an encoded "%" (e.g. "%2520").
        image_source = url2pathname(raw_path)
    elif "://" in image_source:
        scheme = image_source.split("://", 1)[0].lower()
        raise ValueError(
            f"Unsupported image source scheme: {scheme}://",
        )

    with open(image_source, "rb") as f:
        image_bytes = f.read()
    mime_type = detect_image_mime_type(image_bytes)
    image_bs64 = base64.b64encode(image_bytes).decode("utf-8")
    return f"data:{mime_type};base64,{image_bs64}", mime_type


def get_local_ip_addresses():
net_interfaces = psutil.net_if_addrs()
network_ips = []
Expand Down
11 changes: 11 additions & 0 deletions tests/fixtures/image_samples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import base64

# Tiny header-only samples: each starts with the magic bytes of its format.
# They are not complete images, only enough data for signature sniffing.
PNG_BYTES = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR"
PNG_BASE64 = str(base64.b64encode(PNG_BYTES), "ascii")

GIF_BYTES = b"GIF89a\x01\x00\x01\x00\x80\x00\x00"
GIF_BASE64 = str(base64.b64encode(GIF_BYTES), "ascii")

WEBP_BYTES = b"RIFF\x0c\x00\x00\x00WEBPVP8 "
WEBP_BASE64 = str(base64.b64encode(WEBP_BYTES), "ascii")

JPEG_BYTES = b"\xff\xd8\xff\xe0\x00\x10JFIF"
JPEG_BASE64 = str(base64.b64encode(JPEG_BYTES), "ascii")
151 changes: 151 additions & 0 deletions tests/test_openai_source.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
from pathlib import Path
from types import SimpleNamespace

import pytest
from openai.types.chat.chat_completion import ChatCompletion

from astrbot.core.agent.message import ImageURLPart
from astrbot.core.provider.sources.groq_source import ProviderGroq
from astrbot.core.provider.sources.openai_source import ProviderOpenAIOfficial
from tests.fixtures.image_samples import (
GIF_BASE64,
GIF_BYTES,
JPEG_BASE64,
JPEG_BYTES,
PNG_BASE64,
PNG_BYTES,
WEBP_BASE64,
WEBP_BYTES,
)


class _ErrorWithBody(Exception):
Expand Down Expand Up @@ -533,3 +545,142 @@ async def fake_create(**kwargs):
assert extra_body["temperature"] == 0.1
finally:
await provider.terminate()


@pytest.mark.asyncio
async def test_openai_encode_image_bs64_detects_base64_mime():
    """``base64://`` payloads are returned as data URIs with a matching MIME type."""
    expected_prefix_by_payload = {
        PNG_BASE64: "data:image/png;base64,",
        GIF_BASE64: "data:image/gif;base64,",
        WEBP_BASE64: "data:image/webp;base64,",
    }
    provider = _make_provider()
    try:
        for payload, expected_prefix in expected_prefix_by_payload.items():
            encoded = await provider.encode_image_bs64(f"base64://{payload}")
            assert encoded.startswith(expected_prefix)
    finally:
        await provider.terminate()


@pytest.mark.asyncio
async def test_openai_encode_image_bs64_detects_local_file_mime(tmp_path: Path):
    """Local files are encoded with the MIME type that matches their content."""
    samples = (
        ("pixel.png", PNG_BYTES, "data:image/png;base64,"),
        ("pixel.gif", GIF_BYTES, "data:image/gif;base64,"),
        ("pixel.jpg", JPEG_BYTES, "data:image/jpeg;base64,"),
        ("pixel.webp", WEBP_BYTES, "data:image/webp;base64,"),
    )
    provider = _make_provider()
    for file_name, payload, _ in samples:
        (tmp_path / file_name).write_bytes(payload)
    try:
        for file_name, _, expected_prefix in samples:
            encoded = await provider.encode_image_bs64(str(tmp_path / file_name))
            assert encoded.startswith(expected_prefix)
    finally:
        await provider.terminate()


@pytest.mark.asyncio
async def test_openai_encode_image_bs64_keeps_data_uri():
    """An image data URI passed in comes back exactly as given."""
    provider = _make_provider()
    data_uris = [
        f"data:image/png;base64,{PNG_BASE64}",
        f"data:image/gif;base64,{GIF_BASE64}",
        f"data:image/jpeg;base64,{JPEG_BASE64}",
    ]
    try:
        for data_uri in data_uris:
            assert await provider.encode_image_bs64(data_uri) == data_uri
    finally:
        await provider.terminate()


@pytest.mark.asyncio
async def test_openai_encode_image_bs64_invalid_base64_fallback_to_jpeg():
    """An undecodable ``base64://`` payload is passed through labelled as JPEG."""
    provider = _make_provider()
    try:
        encoded = await provider.encode_image_bs64("base64://not-valid-base64")
        expected = "data:image/jpeg;base64,not-valid-base64"
        assert encoded == expected
    finally:
        await provider.terminate()


@pytest.mark.asyncio
async def test_openai_encode_image_bs64_rejects_non_image_data_uri():
    """A data URI whose media type is not ``image/*`` raises ``ValueError``."""
    provider = _make_provider()
    try:
        expect_rejection = pytest.raises(
            ValueError, match="Only image data URI is supported"
        )
        with expect_rejection:
            await provider.encode_image_bs64("data:text/plain;base64,SGVsbG8=")
    finally:
        await provider.terminate()


@pytest.mark.asyncio
async def test_openai_encode_image_bs64_rejects_unsupported_uri_scheme():
    """A source with an unknown ``scheme://`` prefix raises ``ValueError``."""
    provider = _make_provider()
    try:
        expect_rejection = pytest.raises(
            ValueError, match="Unsupported image source scheme"
        )
        with expect_rejection:
            await provider.encode_image_bs64("s3://bucket/path/image.png")
    finally:
        await provider.terminate()


@pytest.mark.asyncio
async def test_openai_assemble_context_extra_image_file_uri_mime(tmp_path: Path):
    """An extra ``file:///`` image part is assembled as a PNG data URI."""
    provider = _make_provider()
    image_file = tmp_path / "agent-request.png"
    image_file.write_bytes(PNG_BYTES)
    try:
        file_uri = f"file:///{image_file.as_posix()}"
        extra_part = ImageURLPart(image_url=ImageURLPart.ImageURL(url=file_uri))
        assembled = await provider.assemble_context(
            text="hello",
            extra_user_content_parts=[extra_part],
        )

        content = assembled["content"]
        assert isinstance(content, list)
        # Index 1 is the image part in the assembled content list.
        image_part = content[1]
        assert image_part["type"] == "image_url"
        encoded_url = image_part["image_url"]["url"]
        assert encoded_url.startswith("data:image/png;base64,")
    finally:
        await provider.terminate()


@pytest.mark.asyncio
async def test_openai_assemble_context_uppercase_https_image_url(
    tmp_path: Path, monkeypatch
):
    """An upper-case ``HTTPS://`` image URL is routed through the downloader."""
    provider = _make_provider()
    downloaded_file = tmp_path / "remote.png"
    downloaded_file.write_bytes(PNG_BYTES)

    async def stub_download(url: str) -> str:
        # Stand-in for the real download: checks the URL and returns a local file.
        assert url == "HTTPS://example.com/asset.png"
        return str(downloaded_file)

    monkeypatch.setattr(
        "astrbot.core.provider.sources.openai_source.download_image_by_url",
        stub_download,
    )
    try:
        assembled = await provider.assemble_context(
            text="hello",
            image_urls=["HTTPS://example.com/asset.png"],
        )

        image_part = next(
            filter(
                lambda part: part.get("type") == "image_url",
                assembled["content"],
            )
        )
        encoded_url = image_part["image_url"]["url"]
        assert encoded_url.startswith("data:image/png;base64,")
    finally:
        await provider.terminate()
Loading
Loading