Skip to content

Commit 3df5603

Browse files
committed
refactor: simplify OpenAI attachment helpers
1 parent 69a4231 commit 3df5603

2 files changed

Lines changed: 141 additions & 78 deletions

File tree

astrbot/core/provider/sources/openai_source.py

Lines changed: 67 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from collections.abc import AsyncGenerator
99
from io import BytesIO
1010
from pathlib import Path
11-
from typing import Any, cast
11+
from typing import Any
1212
from urllib.parse import unquote, urlparse
1313

1414
import httpx
@@ -139,25 +139,17 @@ def _context_contains_image(contexts: list[dict]) -> bool:
139139
return True
140140
return False
141141

142-
@staticmethod
143-
def _get_error_info(error: Exception) -> tuple[str | None, str | None]:
144-
body = getattr(error, "body", None)
145-
if not isinstance(body, dict):
146-
return None, None
147-
148-
err_obj = body.get("error")
149-
if not isinstance(err_obj, dict):
150-
return None, None
151-
152-
code = err_obj.get("code")
153-
message = err_obj.get("message")
154-
return (
155-
code.lower() if isinstance(code, str) else None,
156-
message.lower() if isinstance(message, str) else None,
157-
)
158-
159142
def _is_invalid_attachment_error(self, error: Exception) -> bool:
160-
code, message = self._get_error_info(error)
143+
body = getattr(error, "body", None)
144+
code = None
145+
message = None
146+
if isinstance(body, dict):
147+
err_obj = body.get("error")
148+
if isinstance(err_obj, dict):
149+
raw_code = err_obj.get("code")
150+
raw_message = err_obj.get("message")
151+
code = raw_code.lower() if isinstance(raw_code, str) else None
152+
message = raw_message.lower() if isinstance(raw_message, str) else None
161153

162154
candidates = [
163155
candidate.lower()
@@ -183,45 +175,26 @@ def _image_format_to_mime_type(image_format: str | None) -> str:
183175
"BMP": "image/bmp",
184176
}.get(str(image_format or "").upper(), "image/jpeg")
185177

186-
@staticmethod
187-
def _read_file_bytes(
178+
@classmethod
179+
def _encode_image_file_to_data_url(
180+
cls,
188181
image_path: str,
189182
*,
190183
suppress_errors: bool = True,
191-
) -> bytes | None:
184+
raise_on_invalid_image: bool = False,
185+
) -> str | None:
192186
try:
193-
return Path(image_path).read_bytes()
187+
image_bytes = Path(image_path).read_bytes()
194188
except OSError:
195189
if not suppress_errors:
196190
raise
197191
return None
198192

199-
@staticmethod
200-
def _detect_image_format(image_bytes: bytes) -> str | None:
201193
try:
202194
with PILImage.open(BytesIO(image_bytes)) as image:
203195
image.verify()
204-
return str(image.format or "").upper()
196+
image_format = str(image.format or "").upper()
205197
except (OSError, UnidentifiedImageError):
206-
return None
207-
208-
@classmethod
209-
def _encode_image_file_to_data_url(
210-
cls,
211-
image_path: str,
212-
*,
213-
suppress_errors: bool = True,
214-
raise_on_invalid_image: bool = False,
215-
) -> str | None:
216-
image_bytes = cls._read_file_bytes(
217-
image_path,
218-
suppress_errors=suppress_errors,
219-
)
220-
if image_bytes is None:
221-
return None
222-
223-
image_format = cls._detect_image_format(image_bytes)
224-
if image_format is None:
225198
if raise_on_invalid_image:
226199
raise ValueError(f"Invalid image file: {image_path}")
227200
return None
@@ -232,6 +205,11 @@ def _encode_image_file_to_data_url(
232205

233206
@staticmethod
234207
def _file_uri_to_path(file_uri: str) -> str:
208+
"""Normalize file URIs to paths.
209+
210+
`file://localhost/...` and drive-letter forms are treated as local paths.
211+
Other non-empty hosts are preserved as UNC-style paths.
212+
"""
235213
parsed = urlparse(file_uri)
236214
if parsed.scheme != "file":
237215
return file_uri
@@ -251,16 +229,26 @@ def _normalize_image_path(self, image_url: str) -> str:
251229
return self._file_uri_to_path(image_url)
252230
return image_url
253231

254-
async def _load_image_data(self, image_url: str) -> str | None:
255-
if image_url.startswith("base64://"):
256-
return await self.encode_image_bs64(image_url)
232+
async def _image_ref_to_data_url(
233+
self,
234+
image_ref: str,
235+
*,
236+
suppress_errors: bool = True,
237+
raise_on_invalid_image: bool = False,
238+
) -> str | None:
239+
if image_ref.startswith("base64://"):
240+
return image_ref.replace("base64://", "data:image/jpeg;base64,")
257241

258-
if image_url.startswith("http"):
259-
image_path = await download_image_by_url(image_url)
242+
if image_ref.startswith("http"):
243+
image_path = await download_image_by_url(image_ref)
260244
else:
261-
image_path = self._normalize_image_path(image_url)
245+
image_path = self._normalize_image_path(image_ref)
262246

263-
return self._encode_image_file_to_data_url(image_path)
247+
return self._encode_image_file_to_data_url(
248+
image_path,
249+
suppress_errors=suppress_errors,
250+
raise_on_invalid_image=raise_on_invalid_image,
251+
)
264252

265253
async def _resolve_image_part(
266254
self,
@@ -271,7 +259,7 @@ async def _resolve_image_part(
271259
if image_url.startswith("data:"):
272260
image_payload = {"url": image_url}
273261
else:
274-
image_data = await self._load_image_data(image_url)
262+
image_data = await self._image_ref_to_data_url(image_url)
275263
if not image_data:
276264
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
277265
return None
@@ -284,6 +272,25 @@ async def _resolve_image_part(
284272
"image_url": image_payload,
285273
}
286274

275+
def _extract_image_part_info(self, part: dict) -> tuple[str | None, str | None]:
276+
if not isinstance(part, dict) or part.get("type") != "image_url":
277+
return None, None
278+
279+
image_url_data = part.get("image_url")
280+
if not isinstance(image_url_data, dict):
281+
logger.warning("图片内容块格式无效,将保留原始内容。")
282+
return None, None
283+
284+
url = image_url_data.get("url")
285+
if not isinstance(url, str) or not url:
286+
logger.warning("图片内容块缺少有效 URL,将保留原始内容。")
287+
return None, None
288+
289+
image_detail = image_url_data.get("detail")
290+
if not isinstance(image_detail, str):
291+
image_detail = None
292+
return url, image_detail
293+
287294
async def _materialize_context_image_parts(self, context_query: list[dict]) -> None:
288295
for message in context_query:
289296
content = message.get("content")
@@ -293,26 +300,11 @@ async def _materialize_context_image_parts(self, context_query: list[dict]) -> N
293300
new_content: list[dict] = []
294301
content_changed = False
295302
for part in content:
296-
if not isinstance(part, dict) or part.get("type") != "image_url":
303+
url, image_detail = self._extract_image_part_info(part)
304+
if not url:
297305
new_content.append(part)
298306
continue
299307

300-
image_url_data = part.get("image_url")
301-
if not isinstance(image_url_data, dict):
302-
logger.warning("图片内容块格式无效,将保留原始内容。")
303-
new_content.append(part)
304-
continue
305-
306-
url = image_url_data.get("url")
307-
if not isinstance(url, str) or not url:
308-
logger.warning("图片内容块缺少有效 URL,将保留原始内容。")
309-
new_content.append(part)
310-
continue
311-
312-
image_detail = image_url_data.get("detail")
313-
if not isinstance(image_detail, str):
314-
image_detail = None
315-
316308
try:
317309
resolved_part = await self._resolve_image_part(
318310
url, image_detail=image_detail
@@ -1182,15 +1174,14 @@ async def assemble_context(
11821174

11831175
async def encode_image_bs64(self, image_url: str) -> str:
11841176
"""将图片转换为 base64"""
1185-
if image_url.startswith("base64://"):
1186-
return image_url.replace("base64://", "data:image/jpeg;base64,")
1187-
image_path = self._normalize_image_path(image_url)
1188-
image_data = self._encode_image_file_to_data_url(
1189-
image_path,
1177+
image_data = await self._image_ref_to_data_url(
1178+
image_url,
11901179
suppress_errors=False,
11911180
raise_on_invalid_image=True,
11921181
)
1193-
return cast(str, image_data)
1182+
if image_data is None:
1183+
raise RuntimeError(f"Failed to encode image data: {image_url}")
1184+
return image_data
11941185

11951186
async def terminate(self):
11961187
if self.client:

tests/test_openai_source.py

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,7 @@ async def fake_download(url: str) -> str:
593593
assert url == "https://example.com/quoted.png"
594594
return "/tmp/quoted.png"
595595

596-
def fake_encode(image_path: str) -> str:
596+
def fake_encode(image_path: str, **_kwargs) -> str:
597597
assert image_path == "/tmp/quoted.png"
598598
return "data:image/png;base64,abcd"
599599

@@ -669,6 +669,40 @@ async def fail_if_called(_context_query):
669669
await provider.terminate()
670670

671671

672+
@pytest.mark.asyncio
673+
async def test_prepare_chat_payload_skips_materialization_for_text_only_parts(
674+
monkeypatch,
675+
):
676+
provider = _make_provider()
677+
try:
678+
679+
async def fail_if_called(_context_query):
680+
raise AssertionError("materialization should be skipped")
681+
682+
monkeypatch.setattr(
683+
provider, "_materialize_context_image_parts", fail_if_called
684+
)
685+
686+
payloads, _ = await provider._prepare_chat_payload(
687+
prompt=None,
688+
contexts=[
689+
{
690+
"role": "user",
691+
"content": [{"type": "text", "text": "hello"}],
692+
}
693+
],
694+
)
695+
696+
assert payloads["messages"] == [
697+
{
698+
"role": "user",
699+
"content": [{"type": "text", "text": "hello"}],
700+
}
701+
]
702+
finally:
703+
await provider.terminate()
704+
705+
672706
@pytest.mark.asyncio
673707
async def test_prepare_chat_payload_materializes_context_http_image_urls_with_detected_mime(
674708
monkeypatch, tmp_path
@@ -764,6 +798,17 @@ async def test_file_uri_to_path_preserves_windows_netloc_drive_letter():
764798
await provider.terminate()
765799

766800

801+
@pytest.mark.asyncio
802+
async def test_file_uri_to_path_preserves_remote_netloc_as_unc_path():
803+
provider = _make_provider()
804+
try:
805+
assert provider._file_uri_to_path("file://server/share/quoted-image.png") == (
806+
"//server/share/quoted-image.png"
807+
)
808+
finally:
809+
await provider.terminate()
810+
811+
767812
@pytest.mark.asyncio
768813
async def test_resolve_image_part_rejects_invalid_local_file(tmp_path):
769814
provider = _make_provider()
@@ -812,6 +857,17 @@ async def test_encode_image_bs64_invalid_file_raises(tmp_path):
812857
await provider.terminate()
813858

814859

860+
@pytest.mark.asyncio
861+
async def test_encode_image_bs64_supports_base64_scheme():
862+
provider = _make_provider()
863+
try:
864+
image_data = await provider.encode_image_bs64("base64://abcd")
865+
866+
assert image_data == "data:image/jpeg;base64,abcd"
867+
finally:
868+
await provider.terminate()
869+
870+
815871
@pytest.mark.asyncio
816872
async def test_encode_image_bs64_supports_file_uri(tmp_path):
817873
provider = _make_provider()
@@ -826,6 +882,18 @@ async def test_encode_image_bs64_supports_file_uri(tmp_path):
826882
await provider.terminate()
827883

828884

885+
@pytest.mark.asyncio
886+
async def test_resolve_image_part_supports_base64_scheme():
887+
provider = _make_provider()
888+
try:
889+
assert await provider._resolve_image_part("base64://abcd") == {
890+
"type": "image_url",
891+
"image_url": {"url": "data:image/jpeg;base64,abcd"},
892+
}
893+
finally:
894+
await provider.terminate()
895+
896+
829897
@pytest.mark.asyncio
830898
async def test_prepare_chat_payload_materializes_context_localhost_file_uri_image_urls(
831899
tmp_path,
@@ -875,7 +943,11 @@ async def fake_download(url: str) -> str:
875943
"astrbot.core.provider.sources.openai_source.download_image_by_url",
876944
fake_download,
877945
)
878-
monkeypatch.setattr(provider, "_encode_image_file_to_data_url", lambda _: None)
946+
monkeypatch.setattr(
947+
provider,
948+
"_encode_image_file_to_data_url",
949+
lambda _image_path, **_kwargs: None,
950+
)
879951

880952
payloads, _ = await provider._prepare_chat_payload(
881953
prompt=None,

0 commit comments

Comments
 (0)