11import asyncio
22import base64
3+ import copy
34import inspect
45import json
56import random
67import re
78from collections .abc import AsyncGenerator
8- from typing import Any
9+ from io import BytesIO
10+ from pathlib import Path
11+ from typing import Any , Literal
12+ from urllib .parse import unquote , urlparse
913
1014import httpx
1115from openai import AsyncAzureOpenAI , AsyncOpenAI
1418from openai .types .chat .chat_completion import ChatCompletion
1519from openai .types .chat .chat_completion_chunk import ChatCompletionChunk
1620from openai .types .completion_usage import CompletionUsage
21+ from PIL import Image as PILImage
22+ from PIL import UnidentifiedImageError
1723
1824import astrbot .core .message .components as Comp
1925from astrbot import logger
@@ -133,6 +139,186 @@ def _context_contains_image(contexts: list[dict]) -> bool:
133139 return True
134140 return False
135141
142+ def _is_invalid_attachment_error (self , error : Exception ) -> bool :
143+ body = getattr (error , "body" , None )
144+ code : str | None = None
145+ message : str | None = None
146+ if isinstance (body , dict ):
147+ err_obj = body .get ("error" )
148+ if isinstance (err_obj , dict ):
149+ raw_code = err_obj .get ("code" )
150+ raw_message = err_obj .get ("message" )
151+ code = raw_code .lower () if isinstance (raw_code , str ) else None
152+ message = raw_message .lower () if isinstance (raw_message , str ) else None
153+
154+ if code == "invalid_attachment" :
155+ return True
156+
157+ text_sources : list [str ] = []
158+ if message :
159+ text_sources .append (message )
160+ if code :
161+ text_sources .append (code )
162+ text_sources .extend (map (str , self ._extract_error_text_candidates (error )))
163+
164+ error_text = " " .join (text .lower () for text in text_sources if text )
165+ if "invalid_attachment" in error_text :
166+ return True
167+ if "download attachment" in error_text and "404" in error_text :
168+ return True
169+ return False
170+
171+ @classmethod
172+ def _encode_image_file_to_data_url (
173+ cls ,
174+ image_path : str ,
175+ * ,
176+ mode : Literal ["safe" , "strict" ],
177+ ) -> str | None :
178+ try :
179+ image_bytes = Path (image_path ).read_bytes ()
180+ except OSError :
181+ if mode == "strict" :
182+ raise
183+ return None
184+
185+ try :
186+ with PILImage .open (BytesIO (image_bytes )) as image :
187+ image .verify ()
188+ image_format = str (image .format or "" ).upper ()
189+ except (OSError , UnidentifiedImageError ):
190+ if mode == "strict" :
191+ raise ValueError (f"Invalid image file: { image_path } " )
192+ return None
193+
194+ mime_type = {
195+ "JPEG" : "image/jpeg" ,
196+ "PNG" : "image/png" ,
197+ "GIF" : "image/gif" ,
198+ "WEBP" : "image/webp" ,
199+ "BMP" : "image/bmp" ,
200+ }.get (image_format , "image/jpeg" )
201+ image_bs64 = base64 .b64encode (image_bytes ).decode ("utf-8" )
202+ return f"data:{ mime_type } ;base64,{ image_bs64 } "
203+
204+ @staticmethod
205+ def _file_uri_to_path (file_uri : str ) -> str :
206+ """Normalize file URIs to paths.
207+
208+ `file://localhost/...` and drive-letter forms are treated as local paths.
209+ Other non-empty hosts are preserved as UNC-style paths.
210+ """
211+ parsed = urlparse (file_uri )
212+ if parsed .scheme != "file" :
213+ return file_uri
214+
215+ netloc = unquote (parsed .netloc or "" )
216+ path = unquote (parsed .path or "" )
217+ if re .fullmatch (r"[A-Za-z]:" , netloc ):
218+ return str (Path (f"{ netloc } { path } " ))
219+ if re .match (r"^/[A-Za-z]:/" , path ):
220+ path = path [1 :]
221+ if netloc and netloc != "localhost" :
222+ path = f"//{ netloc } { path } "
223+ return str (Path (path ))
224+
225+ async def _image_ref_to_data_url (
226+ self ,
227+ image_ref : str ,
228+ * ,
229+ mode : Literal ["safe" , "strict" ] = "safe" ,
230+ ) -> str | None :
231+ if image_ref .startswith ("base64://" ):
232+ return image_ref .replace ("base64://" , "data:image/jpeg;base64," )
233+
234+ if image_ref .startswith ("http" ):
235+ image_path = await download_image_by_url (image_ref )
236+ elif image_ref .startswith ("file://" ):
237+ image_path = self ._file_uri_to_path (image_ref )
238+ else :
239+ image_path = image_ref
240+
241+ return self ._encode_image_file_to_data_url (
242+ image_path ,
243+ mode = mode ,
244+ )
245+
246+ async def _resolve_image_part (
247+ self ,
248+ image_url : str ,
249+ * ,
250+ image_detail : str | None = None ,
251+ ) -> dict | None :
252+ if image_url .startswith ("data:" ):
253+ image_payload = {"url" : image_url }
254+ else :
255+ image_data = await self ._image_ref_to_data_url (image_url , mode = "safe" )
256+ if not image_data :
257+ logger .warning (f"图片 { image_url } 得到的结果为空,将忽略。" )
258+ return None
259+ image_payload = {"url" : image_data }
260+
261+ if image_detail :
262+ image_payload ["detail" ] = image_detail
263+ return {
264+ "type" : "image_url" ,
265+ "image_url" : image_payload ,
266+ }
267+
268+ def _extract_image_part_info (self , part : dict ) -> tuple [str | None , str | None ]:
269+ if not isinstance (part , dict ) or part .get ("type" ) != "image_url" :
270+ return None , None
271+
272+ image_url_data = part .get ("image_url" )
273+ if not isinstance (image_url_data , dict ):
274+ logger .warning ("图片内容块格式无效,将保留原始内容。" )
275+ return None , None
276+
277+ url = image_url_data .get ("url" )
278+ if not isinstance (url , str ) or not url :
279+ logger .warning ("图片内容块缺少有效 URL,将保留原始内容。" )
280+ return None , None
281+
282+ image_detail = image_url_data .get ("detail" )
283+ if not isinstance (image_detail , str ):
284+ image_detail = None
285+ return url , image_detail
286+
287+ async def _transform_content_part (self , part : dict ) -> dict :
288+ url , image_detail = self ._extract_image_part_info (part )
289+ if not url :
290+ return part
291+
292+ try :
293+ resolved_part = await self ._resolve_image_part (
294+ url , image_detail = image_detail
295+ )
296+ except Exception as exc :
297+ logger .warning (
298+ "图片 %s 预处理失败,将保留原始内容。错误: %s" ,
299+ url ,
300+ exc ,
301+ )
302+ return part
303+
304+ return resolved_part or part
305+
306+ async def _materialize_message_image_parts (self , message : dict ) -> dict :
307+ content = message .get ("content" )
308+ if not isinstance (content , list ):
309+ return {** message }
310+
311+ new_content = [await self ._transform_content_part (part ) for part in content ]
312+ return {** message , "content" : new_content }
313+
314+ async def _materialize_context_image_parts (
315+ self , context_query : list [dict ]
316+ ) -> list [dict ]:
317+ return [
318+ await self ._materialize_message_image_parts (message )
319+ for message in context_query
320+ ]
321+
136322 async def _fallback_to_text_only_and_retry (
137323 self ,
138324 payloads : dict ,
@@ -604,7 +790,7 @@ async def _prepare_chat_payload(
604790 new_record = await self .assemble_context (
605791 prompt , image_urls , extra_user_content_parts
606792 )
607- context_query = self ._ensure_message_to_dicts (contexts )
793+ context_query = copy . deepcopy ( self ._ensure_message_to_dicts (contexts ) )
608794 if new_record :
609795 context_query .append (new_record )
610796 if system_prompt :
@@ -622,6 +808,9 @@ async def _prepare_chat_payload(
622808 for tcr in tool_calls_result :
623809 context_query .extend (tcr .to_openai_messages ())
624810
811+ if self ._context_contains_image (context_query ):
812+ context_query = await self ._materialize_context_image_parts (context_query )
813+
625814 model = model or self .get_model ()
626815 payloads = {** kwargs , "messages" : context_query , "model" : model }
627816
@@ -721,6 +910,18 @@ async def _handle_api_error(
721910 "image_content_moderated" ,
722911 image_fallback_used = True ,
723912 )
913+ if self ._is_invalid_attachment_error (e ):
914+ if image_fallback_used or not self ._context_contains_image (context_query ):
915+ raise e
916+ return await self ._fallback_to_text_only_and_retry (
917+ payloads ,
918+ context_query ,
919+ chosen_key ,
920+ available_api_keys ,
921+ func_tool ,
922+ "invalid_attachment" ,
923+ image_fallback_used = True ,
924+ )
724925
725926 if (
726927 "Function calling is not enabled" in str (e )
@@ -922,23 +1123,6 @@ async def assemble_context(
9221123 ) -> dict :
9231124 """组装成符合 OpenAI 格式的 role 为 user 的消息段"""
9241125
925- async def resolve_image_part (image_url : str ) -> dict | None :
926- if image_url .startswith ("http" ):
927- image_path = await download_image_by_url (image_url )
928- image_data = await self .encode_image_bs64 (image_path )
929- elif image_url .startswith ("file:///" ):
930- image_path = image_url .replace ("file:///" , "" )
931- image_data = await self .encode_image_bs64 (image_path )
932- else :
933- image_data = await self .encode_image_bs64 (image_url )
934- if not image_data :
935- logger .warning (f"图片 { image_url } 得到的结果为空,将忽略。" )
936- return None
937- return {
938- "type" : "image_url" ,
939- "image_url" : {"url" : image_data },
940- }
941-
9421126 # 构建内容块列表
9431127 content_blocks = []
9441128
@@ -958,7 +1142,9 @@ async def resolve_image_part(image_url: str) -> dict | None:
9581142 if isinstance (part , TextPart ):
9591143 content_blocks .append ({"type" : "text" , "text" : part .text })
9601144 elif isinstance (part , ImageURLPart ):
961- image_part = await resolve_image_part (part .image_url .url )
1145+ image_part = await self ._resolve_image_part (
1146+ part .image_url .url ,
1147+ )
9621148 if image_part :
9631149 content_blocks .append (image_part )
9641150 else :
@@ -967,7 +1153,7 @@ async def resolve_image_part(image_url: str) -> dict | None:
9671153 # 3. 图片内容
9681154 if image_urls :
9691155 for image_url in image_urls :
970- image_part = await resolve_image_part (image_url )
1156+ image_part = await self . _resolve_image_part (image_url )
9711157 if image_part :
9721158 content_blocks .append (image_part )
9731159
@@ -986,11 +1172,10 @@ async def resolve_image_part(image_url: str) -> dict | None:
9861172
9871173 async def encode_image_bs64 (self , image_url : str ) -> str :
9881174 """将图片转换为 base64"""
989- if image_url .startswith ("base64://" ):
990- return image_url .replace ("base64://" , "data:image/jpeg;base64," )
991- with open (image_url , "rb" ) as f :
992- image_bs64 = base64 .b64encode (f .read ()).decode ("utf-8" )
993- return "data:image/jpeg;base64," + image_bs64
1175+ image_data = await self ._image_ref_to_data_url (image_url , mode = "strict" )
1176+ if image_data is None :
1177+ raise RuntimeError (f"Failed to encode image data: { image_url } " )
1178+ return image_data
9941179
9951180 async def terminate (self ):
9961181 if self .client :
0 commit comments