Skip to content

Commit b2d71e2

Browse files
Soulter, a61995987, gemini-code-assist[bot]
authored
feat: supports image compressing (#6794)
* feat: supports image compressing (#6463) Co-authored-by: Soulter <905617992@qq.com> * feat: 增加图像压缩最大尺寸至1280 * Update astrbot/core/astr_main_agent.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * feat: 增强临时文件管理,添加图像压缩路径跟踪与清理功能 * feat: 更新图片压缩功能提示,移除对 chat_completion 提供商的限制说明 --------- Co-authored-by: Chen <42998804+a61995987@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 63dd28d commit b2d71e2

8 files changed

Lines changed: 296 additions & 4 deletions

File tree

astrbot/core/astr_main_agent.py

Lines changed: 104 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,11 @@
6969
)
7070
from astrbot.core.utils.file_extract import extract_file_moonshotai
7171
from astrbot.core.utils.llm_metadata import LLM_METADATAS
72+
from astrbot.core.utils.media_utils import (
73+
IMAGE_COMPRESS_DEFAULT_MAX_SIZE,
74+
IMAGE_COMPRESS_DEFAULT_QUALITY,
75+
compress_image,
76+
)
7277
from astrbot.core.utils.quoted_message.settings import (
7378
SETTINGS as DEFAULT_QUOTED_MESSAGE_SETTINGS,
7479
)
@@ -473,16 +478,23 @@ async def _request_img_caption(
473478

474479

475480
async def _ensure_img_caption(
481+
event: AstrMessageEvent,
476482
req: ProviderRequest,
477483
cfg: dict,
478484
plugin_context: Context,
479485
image_caption_provider: str,
480486
) -> None:
481487
try:
488+
compressed_urls = []
489+
for url in req.image_urls:
490+
compressed_url = await _compress_image_for_provider(url, cfg)
491+
compressed_urls.append(compressed_url)
492+
if _is_generated_compressed_image_path(url, compressed_url):
493+
event.track_temporary_local_file(compressed_url)
482494
caption = await _request_img_caption(
483495
image_caption_provider,
484496
cfg,
485-
req.image_urls,
497+
compressed_urls,
486498
plugin_context,
487499
)
488500
if caption:
@@ -492,6 +504,9 @@ async def _ensure_img_caption(
492504
req.image_urls = []
493505
except Exception as exc: # noqa: BLE001
494506
logger.error("处理图片描述失败: %s", exc)
507+
req.extra_user_content_parts.append(TextPart(text="[Image Captioning Failed]"))
508+
finally:
509+
req.image_urls = []
495510

496511

497512
def _append_quoted_image_attachment(req: ProviderRequest, image_path: str) -> None:
@@ -511,12 +526,64 @@ def _get_quoted_message_parser_settings(
511526
return DEFAULT_QUOTED_MESSAGE_SETTINGS.with_overrides(overrides)
512527

513528

529+
def _get_image_compress_args(
    provider_settings: dict[str, object] | None,
) -> tuple[bool, int, int]:
    """Resolve the image-compression switches from the provider settings.

    Reads ``image_compress_enabled`` and ``image_compress_options``
    (``max_size`` / ``quality``) and sanitizes each value independently, so a
    single malformed entry never disables or breaks compression.

    Args:
        provider_settings: The provider settings mapping. Anything that is not
            a ``dict`` (including ``None``) yields the defaults.

    Returns:
        A ``(enabled, max_size, quality)`` tuple where ``enabled`` defaults to
        ``True``, ``max_size`` is at least 1 and ``quality`` is clamped to
        the 1-100 range.
    """
    if not isinstance(provider_settings, dict):
        return True, IMAGE_COMPRESS_DEFAULT_MAX_SIZE, IMAGE_COMPRESS_DEFAULT_QUALITY

    def _int_or_default(value: object, default: int) -> int:
        # ``bool`` is a subclass of ``int``: without the explicit exclusion a
        # config value of ``true`` would be accepted and turned into 1.
        if isinstance(value, bool) or not isinstance(value, int):
            return default
        return value

    enabled = provider_settings.get("image_compress_enabled", True)
    if not isinstance(enabled, bool):
        enabled = True

    raw_options = provider_settings.get("image_compress_options", {})
    options = raw_options if isinstance(raw_options, dict) else {}

    max_size = _int_or_default(
        options.get("max_size"),
        IMAGE_COMPRESS_DEFAULT_MAX_SIZE,
    )
    max_size = max(max_size, 1)

    quality = _int_or_default(
        options.get("quality"),
        IMAGE_COMPRESS_DEFAULT_QUALITY,
    )
    quality = min(max(quality, 1), 100)

    return enabled, max_size, quality
553+
554+
555+
async def _compress_image_for_provider(
    url_or_path: str,
    provider_settings: dict[str, object] | None,
) -> str:
    """Compress an image before it is handed to a provider, best effort.

    Args:
        url_or_path: The image reference to compress.
        provider_settings: Provider settings used to resolve whether
            compression is enabled and with which ``max_size`` / ``quality``.

    Returns:
        The value returned by ``compress_image`` on success. The original
        ``url_or_path`` is returned unchanged when compression is disabled,
        when it raises, or when it yields an empty result, so callers always
        receive a usable string.
    """
    try:
        enabled, max_size, quality = _get_image_compress_args(provider_settings)
        if not enabled:
            return url_or_path
        compressed = await compress_image(
            url_or_path,
            max_size=max_size,
            quality=quality,
        )
    except Exception as exc:  # noqa: BLE001
        # Compression is an optimization only: never let it block the request.
        logger.error("Image compression failed: %s", exc)
        return url_or_path
    # Guard the ``-> str`` contract: callers append the result straight into
    # ``image_urls``, so an empty/None result must fall back to the input.
    return compressed or url_or_path
567+
568+
569+
def _is_generated_compressed_image_path(
570+
original_path: str,
571+
compressed_path: str | None,
572+
) -> bool:
573+
if not compressed_path or compressed_path == original_path:
574+
return False
575+
if compressed_path.startswith("http") or compressed_path.startswith("data:image"):
576+
return False
577+
return os.path.exists(compressed_path)
578+
579+
514580
async def _process_quote_message(
515581
event: AstrMessageEvent,
516582
req: ProviderRequest,
517583
img_cap_prov_id: str,
518584
plugin_context: Context,
519585
quoted_message_settings: QuotedMessageParserSettings = DEFAULT_QUOTED_MESSAGE_SETTINGS,
586+
config: MainAgentBuildConfig | None = None,
520587
) -> None:
521588
quote = None
522589
for comp in event.message_obj.message:
@@ -549,15 +616,24 @@ async def _process_quote_message(
549616
if image_seg:
550617
try:
551618
prov = None
619+
path = None
620+
compress_path = None
552621
if img_cap_prov_id:
553622
prov = plugin_context.get_provider_by_id(img_cap_prov_id)
554623
if prov is None:
555624
prov = plugin_context.get_using_provider(event.unified_msg_origin)
556625

557626
if prov and isinstance(prov, Provider):
627+
path = await image_seg.convert_to_file_path()
628+
compress_path = await _compress_image_for_provider(
629+
path,
630+
config.provider_settings if config else None,
631+
)
632+
if path and _is_generated_compressed_image_path(path, compress_path):
633+
event.track_temporary_local_file(compress_path)
558634
llm_resp = await prov.text_chat(
559635
prompt="Please describe the image content.",
560-
image_urls=[await image_seg.convert_to_file_path()],
636+
image_urls=[compress_path],
561637
)
562638
if llm_resp.completion_text:
563639
content_parts.append(
@@ -567,6 +643,16 @@ async def _process_quote_message(
567643
logger.warning("No provider found for image captioning in quote.")
568644
except BaseException as exc:
569645
logger.error("处理引用图片失败: %s", exc)
646+
finally:
647+
if (
648+
compress_path
649+
and compress_path != path
650+
and os.path.exists(compress_path)
651+
):
652+
try:
653+
os.remove(compress_path)
654+
except Exception as exc: # noqa: BLE001
655+
logger.warning("Fail to remove temporary compressed image: %s", exc)
570656

571657
quoted_content = "\n".join(content_parts)
572658
quoted_text = f"<Quoted Message>\n{quoted_content}\n</Quoted Message>"
@@ -635,6 +721,7 @@ async def _decorate_llm_request(
635721
img_cap_prov_id: str = cfg.get("default_image_caption_provider_id") or ""
636722
if img_cap_prov_id and req.image_urls:
637723
await _ensure_img_caption(
724+
event,
638725
req,
639726
cfg,
640727
plugin_context,
@@ -649,6 +736,7 @@ async def _decorate_llm_request(
649736
img_cap_prov_id,
650737
plugin_context,
651738
quoted_message_settings,
739+
config,
652740
)
653741

654742
tz = config.timezone
@@ -1025,7 +1113,13 @@ async def build_main_agent(
10251113
# media files attachments
10261114
for comp in event.message_obj.message:
10271115
if isinstance(comp, Image):
1028-
image_path = await comp.convert_to_file_path()
1116+
path = await comp.convert_to_file_path()
1117+
image_path = await _compress_image_for_provider(
1118+
path,
1119+
config.provider_settings,
1120+
)
1121+
if _is_generated_compressed_image_path(path, image_path):
1122+
event.track_temporary_local_file(image_path)
10291123
req.image_urls.append(image_path)
10301124
req.extra_user_content_parts.append(
10311125
TextPart(text=f"[Image Attachment: path {image_path}]")
@@ -1052,7 +1146,13 @@ async def build_main_agent(
10521146
for reply_comp in comp.chain:
10531147
if isinstance(reply_comp, Image):
10541148
has_embedded_image = True
1055-
image_path = await reply_comp.convert_to_file_path()
1149+
path = await reply_comp.convert_to_file_path()
1150+
image_path = await _compress_image_for_provider(
1151+
path,
1152+
config.provider_settings,
1153+
)
1154+
if _is_generated_compressed_image_path(path, image_path):
1155+
event.track_temporary_local_file(image_path)
10561156
req.image_urls.append(image_path)
10571157
_append_quoted_image_attachment(req, image_path)
10581158
elif isinstance(reply_comp, File):

astrbot/core/config/default.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,11 @@
174174
"shipyard_neo_profile": "python-default",
175175
"shipyard_neo_ttl": 3600,
176176
},
177+
"image_compress_enabled": True,
178+
"image_compress_options": {
179+
"max_size": 1280,
180+
"quality": 95,
181+
},
177182
},
178183
# SubAgent orchestrator mode:
179184
# - main_enable = False: disabled; main LLM mounts tools normally (persona selection).
@@ -3452,6 +3457,29 @@ class ChatProviderTemplate(TypedDict):
34523457
"type": "string",
34533458
"hint": "可使用 {{prompt}} 作为用户输入的占位符。如果不输入占位符则代表添加在用户输入的前面。",
34543459
},
3460+
"provider_settings.image_compress_enabled": {
3461+
"description": "启用图片压缩",
3462+
"type": "bool",
3463+
"hint": "启用后,发送给多模态模型前会先压缩本地大图片。",
3464+
},
3465+
"provider_settings.image_compress_options.max_size": {
3466+
"description": "最大边长",
3467+
"type": "int",
3468+
"hint": "压缩后图片的最长边,单位为像素。超过该尺寸时会按比例缩放。",
3469+
"condition": {
3470+
"provider_settings.image_compress_enabled": True,
3471+
},
3472+
"slider": {"min": 256, "max": 4096, "step": 64},
3473+
},
3474+
"provider_settings.image_compress_options.quality": {
3475+
"description": "压缩质量",
3476+
"type": "int",
3477+
"hint": "JPEG 输出质量,范围为 1-100。值越高,画质越好,文件也越大。",
3478+
"condition": {
3479+
"provider_settings.image_compress_enabled": True,
3480+
},
3481+
"slider": {"min": 1, "max": 100, "step": 1},
3482+
},
34553483
"provider_tts_settings.dual_output": {
34563484
"description": "开启 TTS 时同时输出语音和文字内容",
34573485
"type": "bool",

astrbot/core/pipeline/scheduler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,4 +92,5 @@ async def execute(self, event: AstrMessageEvent) -> None:
9292

9393
logger.debug("pipeline 执行完毕。")
9494
finally:
95+
event.cleanup_temporary_local_files()
9596
active_event_registry.unregister(event)

astrbot/core/platform/astr_message_event.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import abc
22
import asyncio
33
import hashlib
4+
import os
45
import re
56
import uuid
67
from collections.abc import AsyncGenerator
@@ -88,6 +89,8 @@ def __init__(
8889
"""在此次事件中是否有过至少一次发送消息的操作"""
8990
self.call_llm = False
9091
"""是否在此消息事件中禁止默认的 LLM 请求"""
92+
self._temporary_local_files: list[str] = []
93+
"""Temporary local files created during this event and safe to delete when it finishes."""
9194

9295
self.plugins_name: list[str] | None = None
9396
"""该事件启用的插件名称列表。None 表示所有插件都启用。空列表表示没有启用任何插件。"""
@@ -228,6 +231,24 @@ def clear_extra(self) -> None:
228231
logger.info(f"清除 {self.get_platform_name()} 的额外信息: {self._extras}")
229232
self._extras.clear()
230233

234+
def track_temporary_local_file(self, path: str) -> None:
    """Remember ``path`` as a temporary file owned by this event.

    Empty paths are ignored and a path is recorded at most once; insertion
    order is preserved.
    """
    if not path:
        return
    tracked = self._temporary_local_files
    if path in tracked:
        return
    tracked.append(path)
237+
238+
def cleanup_temporary_local_files(self) -> None:
    """Delete every temporary file tracked on this event, best effort.

    The tracking list is emptied before any deletion is attempted, so calling
    this method again is a no-op. A file that is already gone is treated as
    cleaned up; any other ``OSError`` is logged as a warning and does not stop
    the remaining files from being removed.
    """
    paths = list(self._temporary_local_files)
    self._temporary_local_files.clear()
    for path in paths:
        try:
            # Remove directly instead of checking existence first: avoids the
            # check-then-delete race and one extra filesystem call per file.
            os.remove(path)
        except FileNotFoundError:
            continue
        except OSError as e:
            logger.warning(
                "Failed to remove temporary local file %s: %s",
                path,
                e,
            )
251+
231252
def is_private_chat(self) -> bool:
232253
"""是否是私聊。"""
233254
return self.get_message_type() == MessageType.FRIEND_MESSAGE

0 commit comments

Comments
 (0)