Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 62 additions & 11 deletions astrbot/core/astr_main_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
retrieve_knowledge_base,
)
from astrbot.core.conversation_mgr import Conversation
from astrbot.core.message.components import File, Image, Reply
from astrbot.core.message.components import File, Image, Record, Reply
from astrbot.core.persona_error_reply import (
extract_persona_custom_error_message_from_persona,
set_persona_custom_error_message_on_event,
Expand Down Expand Up @@ -515,6 +515,18 @@ def _append_quoted_image_attachment(req: ProviderRequest, image_path: str) -> No
)


def _append_audio_attachment(req: ProviderRequest, audio_path: str) -> None:
req.extra_user_content_parts.append(
TextPart(text=f"[Audio Attachment: path {audio_path}]")
)


def _append_quoted_audio_attachment(req: ProviderRequest, audio_path: str) -> None:
req.extra_user_content_parts.append(
TextPart(text=f"[Audio Attachment in quoted message: path {audio_path}]")
)


def _get_quoted_message_parser_settings(
provider_settings: dict[str, object] | None,
) -> QuotedMessageParserSettings:
Expand Down Expand Up @@ -753,12 +765,25 @@ def _modalities_fix(provider: Provider, req: ProviderRequest) -> None:
"Provider %s does not support image, using placeholder.", provider
)
image_count = len(req.image_urls)
placeholder = " ".join(["[图片]"] * image_count)
placeholder = " ".join(["[Image]"] * image_count)
if req.prompt:
req.prompt = f"{placeholder} {req.prompt}"
else:
req.prompt = placeholder
req.image_urls = []
if req.audio_urls:
provider_cfg = provider.provider_config.get("modalities", ["audio"])
if "audio" not in provider_cfg:
logger.debug(
"Provider %s does not support audio, using placeholder.", provider
)
audio_count = len(req.audio_urls)
placeholder = " ".join(["[Audio]"] * audio_count)
if req.prompt:
req.prompt = f"{placeholder} {req.prompt}"
else:
req.prompt = placeholder
req.audio_urls = []
if req.func_tool:
provider_cfg = provider.provider_config.get("modalities", ["tool_use"])
if "tool_use" not in provider_cfg:
Expand All @@ -781,12 +806,14 @@ def _sanitize_context_by_modalities(
if not modalities or not isinstance(modalities, list):
return
supports_image = bool("image" in modalities)
supports_audio = bool("audio" in modalities)
supports_tool_use = bool("tool_use" in modalities)
if supports_image and supports_tool_use:
if supports_image and supports_audio and supports_tool_use:
return

sanitized_contexts: list[dict] = []
removed_image_blocks = 0
removed_audio_blocks = 0
removed_tool_messages = 0
removed_tool_calls = 0

Expand All @@ -808,20 +835,27 @@ def _sanitize_context_by_modalities(
new_msg.pop("tool_calls", None)
new_msg.pop("tool_call_id", None)

if not supports_image:
if not supports_image or not supports_audio:
content = new_msg.get("content")
if isinstance(content, list):
filtered_parts: list = []
removed_any_image = False
removed_any_multimodal = False
for part in content:
if isinstance(part, dict):
part_type = str(part.get("type", "")).lower()
if part_type in {"image_url", "image"}:
removed_any_image = True
if not supports_image and part_type in {"image_url", "image"}:
removed_any_multimodal = True
removed_image_blocks += 1
continue
if not supports_audio and part_type in {
"audio_url",
"input_audio",
}:
removed_any_multimodal = True
removed_audio_blocks += 1
continue
filtered_parts.append(part)
if removed_any_image:
if removed_any_multimodal:
new_msg["content"] = filtered_parts

if role == "assistant":
Expand All @@ -835,11 +869,18 @@ def _sanitize_context_by_modalities(

sanitized_contexts.append(new_msg)

if removed_image_blocks or removed_tool_messages or removed_tool_calls:
if (
removed_image_blocks
or removed_audio_blocks
or removed_tool_messages
or removed_tool_calls
):
logger.debug(
"sanitize_context_by_modalities applied: "
"removed_image_blocks=%s, removed_tool_messages=%s, removed_tool_calls=%s",
"removed_image_blocks=%s, removed_audio_blocks=%s, "
"removed_tool_messages=%s, removed_tool_calls=%s",
removed_image_blocks,
removed_audio_blocks,
removed_tool_messages,
removed_tool_calls,
)
Expand Down Expand Up @@ -1101,6 +1142,7 @@ async def build_main_agent(
req = ProviderRequest()
req.prompt = ""
req.image_urls = []
req.audio_urls = []
if sel_model := event.get_extra("selected_model"):
req.model = sel_model
if config.provider_wake_prefix and not event.message_str.startswith(
Expand All @@ -1124,6 +1166,10 @@ async def build_main_agent(
req.extra_user_content_parts.append(
TextPart(text=f"[Image Attachment: path {image_path}]")
)
elif isinstance(comp, Record):
audio_path = await comp.convert_to_file_path()
req.audio_urls.append(audio_path)
_append_audio_attachment(req, audio_path)
elif isinstance(comp, File):
file_path = await comp.get_file()
file_name = comp.name or os.path.basename(file_path)
Expand Down Expand Up @@ -1155,6 +1201,10 @@ async def build_main_agent(
event.track_temporary_local_file(image_path)
req.image_urls.append(image_path)
_append_quoted_image_attachment(req, image_path)
elif isinstance(reply_comp, Record):
audio_path = await reply_comp.convert_to_file_path()
req.audio_urls.append(audio_path)
_append_quoted_audio_attachment(req, audio_path)
elif isinstance(reply_comp, File):
file_path = await reply_comp.get_file()
file_name = reply_comp.name or os.path.basename(file_path)
Expand Down Expand Up @@ -1222,14 +1272,15 @@ async def build_main_agent(
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
req.image_urls = normalize_and_dedupe_strings(req.image_urls)
req.audio_urls = normalize_and_dedupe_strings(req.audio_urls)

if config.file_extract_enabled:
try:
await _apply_file_extract(event, req, config)
except Exception as exc: # noqa: BLE001
logger.error("Error occurred while applying file extract: %s", exc)

if not req.prompt and not req.image_urls:
if not req.prompt and not req.image_urls and not req.audio_urls:
if not event.get_group_id() and req.extra_user_content_parts:
req.prompt = "<attachment>"
else:
Expand Down
4 changes: 2 additions & 2 deletions astrbot/core/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -1874,8 +1874,8 @@ class ChatProviderTemplate(TypedDict):
"description": "模型能力",
"type": "list",
"items": {"type": "string"},
"options": ["text", "image", "tool_use"],
"labels": ["文本", "图像", "工具使用"],
"options": ["text", "image", "audio", "tool_use"],
"labels": ["文本", "图像", "音频", "工具使用"],
"render_type": "checkbox",
"hint": "模型支持的模态。如所填写的模型不支持图像,请取消勾选图像。",
},
Expand Down
23 changes: 0 additions & 23 deletions astrbot/core/message/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ class ComponentType(str, Enum):
Music = "Music"
Json = "Json"
Unknown = "Unknown"
WechatEmoji = "WechatEmoji" # Wechat 下的 emoji 表情包


class BaseMessageComponent(BaseModel):
Expand All @@ -91,7 +90,6 @@ async def to_dict(self) -> dict:
class Plain(BaseMessageComponent):
type: ComponentType = ComponentType.Plain
text: str
convert: bool | None = True

def __init__(self, text: str, convert: bool = True, **_) -> None:
super().__init__(text=text, convert=convert, **_)
Expand All @@ -114,11 +112,7 @@ def __init__(self, **_) -> None:
class Record(BaseMessageComponent):
type: ComponentType = ComponentType.Record
file: str | None = ""
magic: bool | None = False
url: str | None = ""
cache: bool | None = True
proxy: bool | None = True
timeout: int | None = 0
# Original text content (e.g. TTS source text), used as caption in fallback scenarios
text: str | None = None
# 额外
Expand Down Expand Up @@ -224,7 +218,6 @@ class Video(BaseMessageComponent):
type: ComponentType = ComponentType.Video
file: str
cover: str | None = ""
c: int | None = 2
# 额外
path: str | None = ""

Expand Down Expand Up @@ -401,14 +394,9 @@ class Image(BaseMessageComponent):
type: ComponentType = ComponentType.Image
file: str | None = ""
_type: str | None = ""
subType: int | None = 0
url: str | None = ""
cache: bool | None = True
id: int | None = 40000
c: int | None = 2
# 额外
path: str | None = ""
file_unique: str | None = "" # 某些平台可能有图片缓存的唯一标识

def __init__(self, file: str | None, **_) -> None:
super().__init__(file=file, **_)
Expand Down Expand Up @@ -839,16 +827,6 @@ async def to_dict(self):
}


class WechatEmoji(BaseMessageComponent):
type: ComponentType = ComponentType.WechatEmoji
md5: str | None = ""
md5_len: int | None = 0
cdnurl: str | None = ""

def __init__(self, **_) -> None:
super().__init__(**_)


ComponentTypes = {
# Basic Message Segments
"plain": Plain,
Expand All @@ -874,5 +852,4 @@ def __init__(self, **_) -> None:
"nodes": Nodes,
"json": Json,
"unknown": Unknown,
"WechatEmoji": WechatEmoji,
}
16 changes: 16 additions & 0 deletions astrbot/core/pipeline/preprocess_stage/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from astrbot.core import logger
from astrbot.core.message.components import Image, Plain, Record
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.utils.media_utils import ensure_wav

from ..context import PipelineContext
from ..stage import Stage, register_stage
Expand Down Expand Up @@ -64,6 +65,21 @@ async def process(
logger.debug(f"路径映射: {url} -> {component.url}")
message_chain[idx] = component

# In here, we convert all Record components to wav format and update the file path.
message_chain = event.get_messages()
for idx, component in enumerate(message_chain):
if isinstance(component, Record):
try:
original_path = await component.convert_to_file_path()
record_path = await ensure_wav(original_path)
if record_path != original_path:
event.track_temporary_local_file(record_path)
component.file = record_path
component.path = record_path
message_chain[idx] = component
except Exception as e:
logger.warning(f"Voice processing failed: {e}")

# STT
if self.stt_settings.get("enable", False):
# TODO: 独立
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
MainAgentBuildResult,
build_main_agent,
)
from astrbot.core.message.components import File, Image
from astrbot.core.message.components import File, Image, Record, Video
from astrbot.core.message.message_event_result import (
MessageChain,
MessageEventResult,
Expand Down Expand Up @@ -153,7 +153,8 @@ async def process(
has_provider_request = event.get_extra("provider_request") is not None
has_valid_message = bool(event.message_str and event.message_str.strip())
has_media_content = any(
isinstance(comp, Image | File) for comp in event.message_obj.message
isinstance(comp, (Image, File, Record, Video))
for comp in event.message_obj.message
)

if (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)
from astrbot.core.agent.runners.dify.dify_agent_runner import DifyAgentRunner
from astrbot.core.astr_agent_hooks import MAIN_AGENT_HOOKS
from astrbot.core.message.components import Image
from astrbot.core.message.components import Image, Record
from astrbot.core.message.message_event_result import (
MessageChain,
MessageEventResult,
Expand Down Expand Up @@ -317,8 +317,11 @@ async def process(
if isinstance(comp, Image):
image_path = await comp.convert_to_base64()
req.image_urls.append(image_path)
elif isinstance(comp, Record):
audio_path = await comp.convert_to_file_path()
req.audio_urls.append(audio_path)

if not req.prompt and not req.image_urls:
if not req.prompt and not req.image_urls and not req.audio_urls:
return

custom_error_message = await self._resolve_persona_custom_error_message(event)
Expand Down
1 change: 0 additions & 1 deletion astrbot/core/pipeline/respond/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ class RespondStage(Stage):
Comp.Node: lambda comp: bool(comp.content), # 转发节点
Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点
Comp.File: lambda comp: bool(comp.file_ or comp.url),
Comp.WechatEmoji: lambda comp: comp.md5 is not None, # 微信表情
Comp.Json: lambda comp: bool(comp.data), # Json 卡片
Comp.Share: lambda comp: bool(comp.url) or bool(comp.title),
Comp.Music: lambda comp: (
Expand Down
6 changes: 6 additions & 0 deletions astrbot/core/platform/astr_message_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ def request_llm(
tool_set: ToolSet | None = None,
session_id: str = "",
image_urls: list[str] | None = None,
audio_urls: list[str] | None = None,
contexts: list | None = None,
system_prompt: str = "",
conversation: Conversation | None = None,
Expand All @@ -432,6 +433,8 @@ def request_llm(

image_urls: 可以是 base64:// 或者 http:// 开头的图片链接,也可以是本地图片路径。

audio_urls: 音频 URL 列表,也支持本地路径。

contexts: 当指定 contexts 时,将会使用 contexts 作为上下文。如果同时传入了 conversation,将会忽略 conversation。

func_tool_manager: [Deprecated] 函数工具管理器,用于调用函数工具。用 self.context.get_llm_tool_manager() 获取。已过时,请使用 tool_set 参数代替。
Expand All @@ -441,6 +444,8 @@ def request_llm(
"""
if image_urls is None:
image_urls = []
if audio_urls is None:
audio_urls = []
if contexts is None:
contexts = []
if len(contexts) > 0 and conversation:
Expand All @@ -450,6 +455,7 @@ def request_llm(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
audio_urls=audio_urls,
# func_tool=func_tool_manager,
func_tool=tool_set,
contexts=contexts,
Expand Down
Loading
Loading