Skip to content

Commit d21b35f

Browse files
committed
refactor: simplify OpenAI helper flow
1 parent 3df5603 commit d21b35f

1 file changed

Lines changed: 42 additions & 45 deletions

File tree

astrbot/core/provider/sources/openai_source.py

Lines changed: 42 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,8 @@ def _context_contains_image(contexts: list[dict]) -> bool:
141141

142142
def _is_invalid_attachment_error(self, error: Exception) -> bool:
143143
body = getattr(error, "body", None)
144-
code = None
145-
message = None
144+
code: str | None = None
145+
message: str | None = None
146146
if isinstance(body, dict):
147147
err_obj = body.get("error")
148148
if isinstance(err_obj, dict):
@@ -151,30 +151,20 @@ def _is_invalid_attachment_error(self, error: Exception) -> bool:
151151
code = raw_code.lower() if isinstance(raw_code, str) else None
152152
message = raw_message.lower() if isinstance(raw_message, str) else None
153153

154-
candidates = [
155-
candidate.lower()
156-
for candidate in self._extract_error_text_candidates(error)
157-
]
154+
parts: list[str] = []
155+
if code:
156+
parts.append(code)
158157
if message:
159-
candidates.append(message)
160-
161-
for candidate in candidates:
162-
if "invalid_attachment" in candidate:
163-
return True
164-
if "download attachment" in candidate and "404" in candidate:
165-
return True
158+
parts.append(message)
159+
parts.extend(map(str, self._extract_error_text_candidates(error)))
160+
161+
error_text = " ".join(part.lower() for part in parts if part)
162+
if "invalid_attachment" in error_text:
163+
return True
164+
if "download attachment" in error_text and "404" in error_text:
165+
return True
166166
return code == "invalid_attachment"
167167

168-
@staticmethod
169-
def _image_format_to_mime_type(image_format: str | None) -> str:
170-
return {
171-
"JPEG": "image/jpeg",
172-
"PNG": "image/png",
173-
"GIF": "image/gif",
174-
"WEBP": "image/webp",
175-
"BMP": "image/bmp",
176-
}.get(str(image_format or "").upper(), "image/jpeg")
177-
178168
@classmethod
179169
def _encode_image_file_to_data_url(
180170
cls,
@@ -199,7 +189,13 @@ def _encode_image_file_to_data_url(
199189
raise ValueError(f"Invalid image file: {image_path}")
200190
return None
201191

202-
mime_type = cls._image_format_to_mime_type(image_format)
192+
mime_type = {
193+
"JPEG": "image/jpeg",
194+
"PNG": "image/png",
195+
"GIF": "image/gif",
196+
"WEBP": "image/webp",
197+
"BMP": "image/bmp",
198+
}.get(image_format, "image/jpeg")
203199
image_bs64 = base64.b64encode(image_bytes).decode("utf-8")
204200
return f"data:{mime_type};base64,{image_bs64}"
205201

@@ -224,11 +220,6 @@ def _file_uri_to_path(file_uri: str) -> str:
224220
path = f"//{netloc}{path}"
225221
return str(Path(path))
226222

227-
def _normalize_image_path(self, image_url: str) -> str:
228-
if image_url.startswith("file://"):
229-
return self._file_uri_to_path(image_url)
230-
return image_url
231-
232223
async def _image_ref_to_data_url(
233224
self,
234225
image_ref: str,
@@ -241,15 +232,34 @@ async def _image_ref_to_data_url(
241232

242233
if image_ref.startswith("http"):
243234
image_path = await download_image_by_url(image_ref)
235+
elif image_ref.startswith("file://"):
236+
image_path = self._file_uri_to_path(image_ref)
244237
else:
245-
image_path = self._normalize_image_path(image_ref)
238+
image_path = image_ref
246239

247240
return self._encode_image_file_to_data_url(
248241
image_path,
249242
suppress_errors=suppress_errors,
250243
raise_on_invalid_image=raise_on_invalid_image,
251244
)
252245

246+
async def _image_ref_to_data_url_safe(self, image_ref: str) -> str | None:
247+
return await self._image_ref_to_data_url(
248+
image_ref,
249+
suppress_errors=True,
250+
raise_on_invalid_image=False,
251+
)
252+
253+
async def _image_ref_to_data_url_strict(self, image_ref: str) -> str:
254+
image_data = await self._image_ref_to_data_url(
255+
image_ref,
256+
suppress_errors=False,
257+
raise_on_invalid_image=True,
258+
)
259+
if image_data is None:
260+
raise RuntimeError(f"Failed to encode image data: {image_ref}")
261+
return image_data
262+
253263
async def _resolve_image_part(
254264
self,
255265
image_url: str,
@@ -259,7 +269,7 @@ async def _resolve_image_part(
259269
if image_url.startswith("data:"):
260270
image_payload = {"url": image_url}
261271
else:
262-
image_data = await self._image_ref_to_data_url(image_url)
272+
image_data = await self._image_ref_to_data_url_safe(image_url)
263273
if not image_data:
264274
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
265275
return None
@@ -298,7 +308,6 @@ async def _materialize_context_image_parts(self, context_query: list[dict]) -> N
298308
continue
299309

300310
new_content: list[dict] = []
301-
content_changed = False
302311
for part in content:
303312
url, image_detail = self._extract_image_part_info(part)
304313
if not url:
@@ -323,11 +332,6 @@ async def _materialize_context_image_parts(self, context_query: list[dict]) -> N
323332
continue
324333

325334
new_content.append(resolved_part)
326-
if resolved_part != part:
327-
content_changed = True
328-
329-
if not content_changed:
330-
continue
331335
message["content"] = new_content
332336

333337
async def _fallback_to_text_only_and_retry(
@@ -1174,14 +1178,7 @@ async def assemble_context(
11741178

11751179
async def encode_image_bs64(self, image_url: str) -> str:
11761180
"""将图片转换为 base64"""
1177-
image_data = await self._image_ref_to_data_url(
1178-
image_url,
1179-
suppress_errors=False,
1180-
raise_on_invalid_image=True,
1181-
)
1182-
if image_data is None:
1183-
raise RuntimeError(f"Failed to encode image data: {image_url}")
1184-
return image_data
1181+
return await self._image_ref_to_data_url_strict(image_url)
11851182

11861183
async def terminate(self):
11871184
if self.client:

0 commit comments

Comments
 (0)