Skip to content

Commit 224287e

Browse files
authored
feat: add audio input support across providers and chatui recording issue fix (#7378)
* feat: add audio input support across providers and chatui recording issue fix - Introduced audio_urls parameter in Provider class and related methods to handle audio input. - Updated ProviderAnthropic, ProviderGoogleGenAI, and ProviderOpenAIOfficial to process audio URLs. - Enhanced media_utils with functions to ensure audio format compatibility and detect audio types. - Modified dashboard components to display audio input support and handle audio attachments in messages. - Updated localization files to include audio as a supported modality. - Added new icons for audio input in the dashboard UI. * feat: enhance audio handling with temporary file cleanup and format support * feat: track temporary local files for converted audio components * fix: update image placeholder in prompt from "[图片]" to "[Image]"
1 parent 80d5efd commit 224287e

33 files changed

Lines changed: 628 additions & 99 deletions

astrbot/core/astr_main_agent.py

Lines changed: 62 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
retrieve_knowledge_base,
5252
)
5353
from astrbot.core.conversation_mgr import Conversation
54-
from astrbot.core.message.components import File, Image, Reply
54+
from astrbot.core.message.components import File, Image, Record, Reply
5555
from astrbot.core.persona_error_reply import (
5656
extract_persona_custom_error_message_from_persona,
5757
set_persona_custom_error_message_on_event,
@@ -515,6 +515,18 @@ def _append_quoted_image_attachment(req: ProviderRequest, image_path: str) -> No
515515
)
516516

517517

518+
def _append_audio_attachment(req: ProviderRequest, audio_path: str) -> None:
519+
req.extra_user_content_parts.append(
520+
TextPart(text=f"[Audio Attachment: path {audio_path}]")
521+
)
522+
523+
524+
def _append_quoted_audio_attachment(req: ProviderRequest, audio_path: str) -> None:
525+
req.extra_user_content_parts.append(
526+
TextPart(text=f"[Audio Attachment in quoted message: path {audio_path}]")
527+
)
528+
529+
518530
def _get_quoted_message_parser_settings(
519531
provider_settings: dict[str, object] | None,
520532
) -> QuotedMessageParserSettings:
@@ -753,12 +765,25 @@ def _modalities_fix(provider: Provider, req: ProviderRequest) -> None:
753765
"Provider %s does not support image, using placeholder.", provider
754766
)
755767
image_count = len(req.image_urls)
756-
placeholder = " ".join(["[图片]"] * image_count)
768+
placeholder = " ".join(["[Image]"] * image_count)
757769
if req.prompt:
758770
req.prompt = f"{placeholder} {req.prompt}"
759771
else:
760772
req.prompt = placeholder
761773
req.image_urls = []
774+
if req.audio_urls:
775+
provider_cfg = provider.provider_config.get("modalities", ["audio"])
776+
if "audio" not in provider_cfg:
777+
logger.debug(
778+
"Provider %s does not support audio, using placeholder.", provider
779+
)
780+
audio_count = len(req.audio_urls)
781+
placeholder = " ".join(["[Audio]"] * audio_count)
782+
if req.prompt:
783+
req.prompt = f"{placeholder} {req.prompt}"
784+
else:
785+
req.prompt = placeholder
786+
req.audio_urls = []
762787
if req.func_tool:
763788
provider_cfg = provider.provider_config.get("modalities", ["tool_use"])
764789
if "tool_use" not in provider_cfg:
@@ -781,12 +806,14 @@ def _sanitize_context_by_modalities(
781806
if not modalities or not isinstance(modalities, list):
782807
return
783808
supports_image = bool("image" in modalities)
809+
supports_audio = bool("audio" in modalities)
784810
supports_tool_use = bool("tool_use" in modalities)
785-
if supports_image and supports_tool_use:
811+
if supports_image and supports_audio and supports_tool_use:
786812
return
787813

788814
sanitized_contexts: list[dict] = []
789815
removed_image_blocks = 0
816+
removed_audio_blocks = 0
790817
removed_tool_messages = 0
791818
removed_tool_calls = 0
792819

@@ -808,20 +835,27 @@ def _sanitize_context_by_modalities(
808835
new_msg.pop("tool_calls", None)
809836
new_msg.pop("tool_call_id", None)
810837

811-
if not supports_image:
838+
if not supports_image or not supports_audio:
812839
content = new_msg.get("content")
813840
if isinstance(content, list):
814841
filtered_parts: list = []
815-
removed_any_image = False
842+
removed_any_multimodal = False
816843
for part in content:
817844
if isinstance(part, dict):
818845
part_type = str(part.get("type", "")).lower()
819-
if part_type in {"image_url", "image"}:
820-
removed_any_image = True
846+
if not supports_image and part_type in {"image_url", "image"}:
847+
removed_any_multimodal = True
821848
removed_image_blocks += 1
822849
continue
850+
if not supports_audio and part_type in {
851+
"audio_url",
852+
"input_audio",
853+
}:
854+
removed_any_multimodal = True
855+
removed_audio_blocks += 1
856+
continue
823857
filtered_parts.append(part)
824-
if removed_any_image:
858+
if removed_any_multimodal:
825859
new_msg["content"] = filtered_parts
826860

827861
if role == "assistant":
@@ -835,11 +869,18 @@ def _sanitize_context_by_modalities(
835869

836870
sanitized_contexts.append(new_msg)
837871

838-
if removed_image_blocks or removed_tool_messages or removed_tool_calls:
872+
if (
873+
removed_image_blocks
874+
or removed_audio_blocks
875+
or removed_tool_messages
876+
or removed_tool_calls
877+
):
839878
logger.debug(
840879
"sanitize_context_by_modalities applied: "
841-
"removed_image_blocks=%s, removed_tool_messages=%s, removed_tool_calls=%s",
880+
"removed_image_blocks=%s, removed_audio_blocks=%s, "
881+
"removed_tool_messages=%s, removed_tool_calls=%s",
842882
removed_image_blocks,
883+
removed_audio_blocks,
843884
removed_tool_messages,
844885
removed_tool_calls,
845886
)
@@ -1101,6 +1142,7 @@ async def build_main_agent(
11011142
req = ProviderRequest()
11021143
req.prompt = ""
11031144
req.image_urls = []
1145+
req.audio_urls = []
11041146
if sel_model := event.get_extra("selected_model"):
11051147
req.model = sel_model
11061148
if config.provider_wake_prefix and not event.message_str.startswith(
@@ -1124,6 +1166,10 @@ async def build_main_agent(
11241166
req.extra_user_content_parts.append(
11251167
TextPart(text=f"[Image Attachment: path {image_path}]")
11261168
)
1169+
elif isinstance(comp, Record):
1170+
audio_path = await comp.convert_to_file_path()
1171+
req.audio_urls.append(audio_path)
1172+
_append_audio_attachment(req, audio_path)
11271173
elif isinstance(comp, File):
11281174
file_path = await comp.get_file()
11291175
file_name = comp.name or os.path.basename(file_path)
@@ -1155,6 +1201,10 @@ async def build_main_agent(
11551201
event.track_temporary_local_file(image_path)
11561202
req.image_urls.append(image_path)
11571203
_append_quoted_image_attachment(req, image_path)
1204+
elif isinstance(reply_comp, Record):
1205+
audio_path = await reply_comp.convert_to_file_path()
1206+
req.audio_urls.append(audio_path)
1207+
_append_quoted_audio_attachment(req, audio_path)
11581208
elif isinstance(reply_comp, File):
11591209
file_path = await reply_comp.get_file()
11601210
file_name = reply_comp.name or os.path.basename(file_path)
@@ -1222,14 +1272,15 @@ async def build_main_agent(
12221272
if isinstance(req.contexts, str):
12231273
req.contexts = json.loads(req.contexts)
12241274
req.image_urls = normalize_and_dedupe_strings(req.image_urls)
1275+
req.audio_urls = normalize_and_dedupe_strings(req.audio_urls)
12251276

12261277
if config.file_extract_enabled:
12271278
try:
12281279
await _apply_file_extract(event, req, config)
12291280
except Exception as exc: # noqa: BLE001
12301281
logger.error("Error occurred while applying file extract: %s", exc)
12311282

1232-
if not req.prompt and not req.image_urls:
1283+
if not req.prompt and not req.image_urls and not req.audio_urls:
12331284
if not event.get_group_id() and req.extra_user_content_parts:
12341285
req.prompt = "<attachment>"
12351286
else:

astrbot/core/config/default.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1874,8 +1874,8 @@ class ChatProviderTemplate(TypedDict):
18741874
"description": "模型能力",
18751875
"type": "list",
18761876
"items": {"type": "string"},
1877-
"options": ["text", "image", "tool_use"],
1878-
"labels": ["文本", "图像", "工具使用"],
1877+
"options": ["text", "image", "audio", "tool_use"],
1878+
"labels": ["文本", "图像", "音频", "工具使用"],
18791879
"render_type": "checkbox",
18801880
"hint": "模型支持的模态。如所填写的模型不支持图像,请取消勾选图像。",
18811881
},

astrbot/core/message/components.py

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ class ComponentType(str, Enum):
6464
Music = "Music"
6565
Json = "Json"
6666
Unknown = "Unknown"
67-
WechatEmoji = "WechatEmoji" # Wechat 下的 emoji 表情包
6867

6968

7069
class BaseMessageComponent(BaseModel):
@@ -91,7 +90,6 @@ async def to_dict(self) -> dict:
9190
class Plain(BaseMessageComponent):
9291
type: ComponentType = ComponentType.Plain
9392
text: str
94-
convert: bool | None = True
9593

9694
def __init__(self, text: str, convert: bool = True, **_) -> None:
9795
super().__init__(text=text, convert=convert, **_)
@@ -114,11 +112,7 @@ def __init__(self, **_) -> None:
114112
class Record(BaseMessageComponent):
115113
type: ComponentType = ComponentType.Record
116114
file: str | None = ""
117-
magic: bool | None = False
118115
url: str | None = ""
119-
cache: bool | None = True
120-
proxy: bool | None = True
121-
timeout: int | None = 0
122116
# Original text content (e.g. TTS source text), used as caption in fallback scenarios
123117
text: str | None = None
124118
# 额外
@@ -224,7 +218,6 @@ class Video(BaseMessageComponent):
224218
type: ComponentType = ComponentType.Video
225219
file: str
226220
cover: str | None = ""
227-
c: int | None = 2
228221
# 额外
229222
path: str | None = ""
230223

@@ -401,14 +394,9 @@ class Image(BaseMessageComponent):
401394
type: ComponentType = ComponentType.Image
402395
file: str | None = ""
403396
_type: str | None = ""
404-
subType: int | None = 0
405397
url: str | None = ""
406-
cache: bool | None = True
407-
id: int | None = 40000
408-
c: int | None = 2
409398
# 额外
410399
path: str | None = ""
411-
file_unique: str | None = "" # 某些平台可能有图片缓存的唯一标识
412400

413401
def __init__(self, file: str | None, **_) -> None:
414402
super().__init__(file=file, **_)
@@ -839,16 +827,6 @@ async def to_dict(self):
839827
}
840828

841829

842-
class WechatEmoji(BaseMessageComponent):
843-
type: ComponentType = ComponentType.WechatEmoji
844-
md5: str | None = ""
845-
md5_len: int | None = 0
846-
cdnurl: str | None = ""
847-
848-
def __init__(self, **_) -> None:
849-
super().__init__(**_)
850-
851-
852830
ComponentTypes = {
853831
# Basic Message Segments
854832
"plain": Plain,
@@ -874,5 +852,4 @@ def __init__(self, **_) -> None:
874852
"nodes": Nodes,
875853
"json": Json,
876854
"unknown": Unknown,
877-
"WechatEmoji": WechatEmoji,
878855
}

astrbot/core/pipeline/preprocess_stage/stage.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from astrbot.core import logger
77
from astrbot.core.message.components import Image, Plain, Record
88
from astrbot.core.platform.astr_message_event import AstrMessageEvent
9+
from astrbot.core.utils.media_utils import ensure_wav
910

1011
from ..context import PipelineContext
1112
from ..stage import Stage, register_stage
@@ -64,6 +65,21 @@ async def process(
6465
logger.debug(f"路径映射: {url} -> {component.url}")
6566
message_chain[idx] = component
6667

68+
# In here, we convert all Record components to wav format and update the file path.
69+
message_chain = event.get_messages()
70+
for idx, component in enumerate(message_chain):
71+
if isinstance(component, Record):
72+
try:
73+
original_path = await component.convert_to_file_path()
74+
record_path = await ensure_wav(original_path)
75+
if record_path != original_path:
76+
event.track_temporary_local_file(record_path)
77+
component.file = record_path
78+
component.path = record_path
79+
message_chain[idx] = component
80+
except Exception as e:
81+
logger.warning(f"Voice processing failed: {e}")
82+
6783
# STT
6884
if self.stt_settings.get("enable", False):
6985
# TODO: 独立

astrbot/core/pipeline/process_stage/method/agent_sub_stages/internal.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
MainAgentBuildResult,
1414
build_main_agent,
1515
)
16-
from astrbot.core.message.components import File, Image
16+
from astrbot.core.message.components import File, Image, Record, Video
1717
from astrbot.core.message.message_event_result import (
1818
MessageChain,
1919
MessageEventResult,
@@ -153,7 +153,8 @@ async def process(
153153
has_provider_request = event.get_extra("provider_request") is not None
154154
has_valid_message = bool(event.message_str and event.message_str.strip())
155155
has_media_content = any(
156-
isinstance(comp, Image | File) for comp in event.message_obj.message
156+
isinstance(comp, (Image, File, Record, Video))
157+
for comp in event.message_obj.message
157158
)
158159

159160
if (

astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
)
1818
from astrbot.core.agent.runners.dify.dify_agent_runner import DifyAgentRunner
1919
from astrbot.core.astr_agent_hooks import MAIN_AGENT_HOOKS
20-
from astrbot.core.message.components import Image
20+
from astrbot.core.message.components import Image, Record
2121
from astrbot.core.message.message_event_result import (
2222
MessageChain,
2323
MessageEventResult,
@@ -317,8 +317,11 @@ async def process(
317317
if isinstance(comp, Image):
318318
image_path = await comp.convert_to_base64()
319319
req.image_urls.append(image_path)
320+
elif isinstance(comp, Record):
321+
audio_path = await comp.convert_to_file_path()
322+
req.audio_urls.append(audio_path)
320323

321-
if not req.prompt and not req.image_urls:
324+
if not req.prompt and not req.image_urls and not req.audio_urls:
322325
return
323326

324327
custom_error_message = await self._resolve_persona_custom_error_message(event)

astrbot/core/pipeline/respond/stage.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ class RespondStage(Stage):
3232
Comp.Node: lambda comp: bool(comp.content), # 转发节点
3333
Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点
3434
Comp.File: lambda comp: bool(comp.file_ or comp.url),
35-
Comp.WechatEmoji: lambda comp: comp.md5 is not None, # 微信表情
3635
Comp.Json: lambda comp: bool(comp.data), # Json 卡片
3736
Comp.Share: lambda comp: bool(comp.url) or bool(comp.title),
3837
Comp.Music: lambda comp: (

astrbot/core/platform/astr_message_event.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,7 @@ def request_llm(
414414
tool_set: ToolSet | None = None,
415415
session_id: str = "",
416416
image_urls: list[str] | None = None,
417+
audio_urls: list[str] | None = None,
417418
contexts: list | None = None,
418419
system_prompt: str = "",
419420
conversation: Conversation | None = None,
@@ -432,6 +433,8 @@ def request_llm(
432433
433434
image_urls: 可以是 base64:// 或者 http:// 开头的图片链接,也可以是本地图片路径。
434435
436+
audio_urls: 音频 URL 列表,也支持本地路径。
437+
435438
contexts: 当指定 contexts 时,将会使用 contexts 作为上下文。如果同时传入了 conversation,将会忽略 conversation。
436439
437440
func_tool_manager: [Deprecated] 函数工具管理器,用于调用函数工具。用 self.context.get_llm_tool_manager() 获取。已过时,请使用 tool_set 参数代替。
@@ -441,6 +444,8 @@ def request_llm(
441444
"""
442445
if image_urls is None:
443446
image_urls = []
447+
if audio_urls is None:
448+
audio_urls = []
444449
if contexts is None:
445450
contexts = []
446451
if len(contexts) > 0 and conversation:
@@ -450,6 +455,7 @@ def request_llm(
450455
prompt=prompt,
451456
session_id=session_id,
452457
image_urls=image_urls,
458+
audio_urls=audio_urls,
453459
# func_tool=func_tool_manager,
454460
func_tool=tool_set,
455461
contexts=contexts,

0 commit comments

Comments
 (0)