Skip to content

Commit b802ede

Browse files
committed
Merge dev and fix conflicts
2 parents be22a99 + b0b6816 commit b802ede

45 files changed

Lines changed: 1455 additions & 361 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

astrbot/core/astr_main_agent.py

Lines changed: 67 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from astrbot.core.astr_agent_run_util import AgentRunner
2222
from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor
2323
from astrbot.core.conversation_mgr import Conversation
24-
from astrbot.core.message.components import File, Image, Reply
24+
from astrbot.core.message.components import File, Image, Record, Reply
2525
from astrbot.core.persona_error_reply import (
2626
extract_persona_custom_error_message_from_persona,
2727
set_persona_custom_error_message_on_event,
@@ -419,6 +419,18 @@ def _append_quoted_image_attachment(req: ProviderRequest, image_path: str) -> No
419419
)
420420

421421

422+
def _append_audio_attachment(req: ProviderRequest, audio_path: str) -> None:
423+
req.extra_user_content_parts.append(
424+
TextPart(text=f"[Audio Attachment: path {audio_path}]")
425+
)
426+
427+
428+
def _append_quoted_audio_attachment(req: ProviderRequest, audio_path: str) -> None:
429+
req.extra_user_content_parts.append(
430+
TextPart(text=f"[Audio Attachment in quoted message: path {audio_path}]")
431+
)
432+
433+
422434
def _get_quoted_message_parser_settings(
423435
provider_settings: dict[str, object] | None,
424436
) -> QuotedMessageParserSettings:
@@ -704,12 +716,25 @@ def _modalities_fix(provider: Provider, req: ProviderRequest) -> None:
704716
"Provider %s does not support image, using placeholder.", provider
705717
)
706718
image_count = len(req.image_urls)
707-
placeholder = " ".join(["[图片]"] * image_count)
719+
placeholder = " ".join(["[Image]"] * image_count)
708720
if req.prompt:
709721
req.prompt = f"{placeholder} {req.prompt}"
710722
else:
711723
req.prompt = placeholder
712724
req.image_urls = []
725+
if req.audio_urls:
726+
provider_cfg = provider.provider_config.get("modalities", ["audio"])
727+
if "audio" not in provider_cfg:
728+
logger.debug(
729+
"Provider %s does not support audio, using placeholder.", provider
730+
)
731+
audio_count = len(req.audio_urls)
732+
placeholder = " ".join(["[Audio]"] * audio_count)
733+
if req.prompt:
734+
req.prompt = f"{placeholder} {req.prompt}"
735+
else:
736+
req.prompt = placeholder
737+
req.audio_urls = []
713738
if req.func_tool:
714739
provider_cfg = provider.provider_config.get("modalities", ["tool_use"])
715740
if "tool_use" not in provider_cfg:
@@ -730,11 +755,13 @@ def _sanitize_context_by_modalities(
730755
if not modalities or not isinstance(modalities, list):
731756
return
732757
supports_image = bool("image" in modalities)
758+
supports_audio = bool("audio" in modalities)
733759
supports_tool_use = bool("tool_use" in modalities)
734-
if supports_image and supports_tool_use:
760+
if supports_image and supports_audio and supports_tool_use:
735761
return
736762
sanitized_contexts: list[dict] = []
737763
removed_image_blocks = 0
764+
removed_audio_blocks = 0
738765
removed_tool_messages = 0
739766
removed_tool_calls = 0
740767
for msg in req.contexts:
@@ -753,20 +780,28 @@ def _sanitize_context_by_modalities(
753780
removed_tool_calls += 1
754781
new_msg.pop("tool_calls", None)
755782
new_msg.pop("tool_call_id", None)
756-
if not supports_image:
783+
784+
if not supports_image or not supports_audio:
757785
content = new_msg.get("content")
758786
if isinstance(content, list):
759787
filtered_parts: list = []
760-
removed_any_image = False
788+
removed_any_multimodal = False
761789
for part in content:
762790
if isinstance(part, dict):
763791
part_type = str(part.get("type", "")).lower()
764-
if part_type in {"image_url", "image"}:
765-
removed_any_image = True
792+
if not supports_image and part_type in {"image_url", "image"}:
793+
removed_any_multimodal = True
766794
removed_image_blocks += 1
767795
continue
796+
if not supports_audio and part_type in {
797+
"audio_url",
798+
"input_audio",
799+
}:
800+
removed_any_multimodal = True
801+
removed_audio_blocks += 1
802+
continue
768803
filtered_parts.append(part)
769-
if removed_any_image:
804+
if removed_any_multimodal:
770805
new_msg["content"] = filtered_parts
771806
if role == "assistant":
772807
content = new_msg.get("content")
@@ -777,10 +812,19 @@ def _sanitize_context_by_modalities(
777812
if isinstance(content, str) and (not content.strip()):
778813
continue
779814
sanitized_contexts.append(new_msg)
780-
if removed_image_blocks or removed_tool_messages or removed_tool_calls:
815+
816+
if (
817+
removed_image_blocks
818+
or removed_audio_blocks
819+
or removed_tool_messages
820+
or removed_tool_calls
821+
):
781822
logger.debug(
782-
"sanitize_context_by_modalities applied: removed_image_blocks=%s, removed_tool_messages=%s, removed_tool_calls=%s",
823+
"sanitize_context_by_modalities applied: "
824+
"removed_image_blocks=%s, removed_audio_blocks=%s, "
825+
"removed_tool_messages=%s, removed_tool_calls=%s",
783826
removed_image_blocks,
827+
removed_audio_blocks,
784828
removed_tool_messages,
785829
removed_tool_calls,
786830
)
@@ -969,6 +1013,7 @@ async def build_main_agent(
9691013
req = ProviderRequest()
9701014
req.prompt = ""
9711015
req.image_urls = []
1016+
req.audio_urls = []
9721017
if sel_model := event.get_extra("selected_model"):
9731018
req.model = sel_model
9741019
if config.provider_wake_prefix and (
@@ -988,6 +1033,10 @@ async def build_main_agent(
9881033
req.extra_user_content_parts.append(
9891034
TextPart(text=f"[Image Attachment: path {image_path}]")
9901035
)
1036+
elif isinstance(comp, Record):
1037+
audio_path = await comp.convert_to_file_path()
1038+
req.audio_urls.append(audio_path)
1039+
_append_audio_attachment(req, audio_path)
9911040
elif isinstance(comp, File):
9921041
file_path = await comp.get_file()
9931042
file_name = comp.name or os.path.basename(file_path)
@@ -1017,6 +1066,10 @@ async def build_main_agent(
10171066
event.track_temporary_local_file(image_path)
10181067
req.image_urls.append(image_path)
10191068
_append_quoted_image_attachment(req, image_path)
1069+
elif isinstance(reply_comp, Record):
1070+
audio_path = await reply_comp.convert_to_file_path()
1071+
req.audio_urls.append(audio_path)
1072+
_append_quoted_audio_attachment(req, audio_path)
10201073
elif isinstance(reply_comp, File):
10211074
file_path = await reply_comp.get_file()
10221075
file_name = reply_comp.name or os.path.basename(file_path)
@@ -1074,12 +1127,15 @@ async def build_main_agent(
10741127
if isinstance(req.contexts, str):
10751128
req.contexts = json.loads(req.contexts)
10761129
req.image_urls = normalize_and_dedupe_strings(req.image_urls)
1130+
req.audio_urls = normalize_and_dedupe_strings(req.audio_urls)
1131+
10771132
if config.file_extract_enabled:
10781133
try:
10791134
await _apply_file_extract(event, req, config)
10801135
except Exception as exc:
10811136
logger.error("Error occurred while applying file extract: %s", exc)
1082-
if not req.prompt and (not req.image_urls):
1137+
1138+
if not req.prompt and not req.image_urls and not req.audio_urls:
10831139
if not event.get_group_id() and req.extra_user_content_parts:
10841140
req.prompt = "<attachment>"
10851141
else:

astrbot/core/config/default.py

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,6 +1247,18 @@ class ChatProviderTemplate(TypedDict):
12471247
"proxy": "",
12481248
"custom_headers": {},
12491249
},
1250+
"LongCat": {
1251+
"id": "longcat",
1252+
"provider": "longcat",
1253+
"type": "longcat_chat_completion",
1254+
"provider_type": "chat_completion",
1255+
"enable": True,
1256+
"key": [],
1257+
"api_base": "https://api.longcat.chat/openai",
1258+
"timeout": 120,
1259+
"proxy": "",
1260+
"custom_headers": {},
1261+
},
12501262
"AIHubMix": {
12511263
"id": "aihubmix",
12521264
"provider": "aihubmix",
@@ -1761,6 +1773,7 @@ class ChatProviderTemplate(TypedDict):
17611773
"enable": True,
17621774
"rerank_api_key": "",
17631775
"rerank_api_base": "http://127.0.0.1:8000",
1776+
"rerank_api_suffix": "/v1/rerank",
17641777
"rerank_model": "BAAI/bge-reranker-base",
17651778
"timeout": 20,
17661779
},
@@ -1789,6 +1802,19 @@ class ChatProviderTemplate(TypedDict):
17891802
"return_documents": False,
17901803
"instruct": "",
17911804
},
1805+
"NVIDIA Rerank": {
1806+
"id": "nvidia_rerank",
1807+
"type": "nvidia_rerank",
1808+
"provider": "nvidia",
1809+
"provider_type": "rerank",
1810+
"enable": True,
1811+
"nvidia_rerank_api_key": "",
1812+
"nvidia_rerank_api_base": "https://ai.api.nvidia.com/v1/retrieval",
1813+
"nvidia_rerank_model": "nv-rerank-qa-mistral-4b:1",
1814+
"nvidia_rerank_model_endpoint": "/reranking",
1815+
"timeout": 20,
1816+
"nvidia_rerank_truncate": "",
1817+
},
17921818
"Xinference STT": {
17931819
"id": "xinference_stt",
17941820
"type": "xinference_stt",
@@ -1826,7 +1852,12 @@ class ChatProviderTemplate(TypedDict):
18261852
"rerank_api_base": {
18271853
"description": "重排序模型 API Base URL",
18281854
"type": "string",
1829-
"hint": "AstrBot 会在请求时在末尾加上 /v1/rerank。",
1855+
"hint": "最终请求路径由 Base URL 和路径后缀拼接而成(默认为 /v1/rerank)。",
1856+
},
1857+
"rerank_api_suffix": {
1858+
"description": "API URL 路径后缀",
1859+
"type": "string",
1860+
"hint": "追加到 base_url 后的路径,如 /v1/rerank。留空则不追加。",
18301861
},
18311862
"rerank_api_key": {
18321863
"description": "API Key",
@@ -1852,12 +1883,40 @@ class ChatProviderTemplate(TypedDict):
18521883
"type": "bool",
18531884
"hint": "如果模型当前未在 Xinference 服务中运行,是否尝试自动启动它。在生产环境中建议关闭。",
18541885
},
1886+
"nvidia_rerank_api_base": {
1887+
"description": "API Base URL",
1888+
"type": "string",
1889+
},
1890+
"nvidia_rerank_api_key": {
1891+
"description": "API Key",
1892+
"type": "string",
1893+
},
1894+
"nvidia_rerank_model": {
1895+
"description": "重排序模型名称",
1896+
"type": "string",
1897+
"hint": "请参照NVIDIA Docs中模型名称填写。",
1898+
},
1899+
"nvidia_rerank_model_endpoint": {
1900+
"description": "自定义模型端点",
1901+
"type": "string",
1902+
"hint": "自定义URL末尾端点,默认为 /reranking",
1903+
},
1904+
"nvidia_rerank_truncate": {
1905+
"description": "文本截断策略",
1906+
"type": "string",
1907+
"hint": "当输入文本过长时,是否截断输入以适应模型的最大上下文长度。",
1908+
"options": [
1909+
"",
1910+
"NONE",
1911+
"END",
1912+
],
1913+
},
18551914
"modalities": {
18561915
"description": "模型能力",
18571916
"type": "list",
18581917
"items": {"type": "string"},
1859-
"options": ["text", "image", "tool_use"],
1860-
"labels": ["文本", "图像", "工具使用"],
1918+
"options": ["text", "image", "audio", "tool_use"],
1919+
"labels": ["文本", "图像", "音频", "工具使用"],
18611920
"render_type": "checkbox",
18621921
"hint": "模型支持的模态。如所填写的模型不支持图像,请取消勾选图像。",
18631922
},

astrbot/core/message/components.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,6 @@ class ComponentType(str, Enum):
7474
Music = "Music"
7575
Json = "Json"
7676
Unknown = "Unknown"
77-
WechatEmoji = "WechatEmoji" # Wechat 下的 emoji 表情包
78-
# Discord-specific component types
79-
DiscordEmbed = "DiscordEmbed"
80-
DiscordButton = "DiscordButton"
81-
DiscordReference = "DiscordReference"
82-
DiscordView = "DiscordView"
8377

8478

8579
class BaseMessageComponent(BaseModel):
@@ -106,7 +100,6 @@ async def to_dict(self) -> dict:
106100
class Plain(BaseMessageComponent):
107101
type: ComponentType = ComponentType.Plain
108102
text: str
109-
convert: bool | None = True
110103

111104
def __init__(self, text: str, convert: bool = True, **_) -> None:
112105
super().__init__(text=text, convert=convert, **_)
@@ -129,11 +122,7 @@ def __init__(self, **_) -> None:
129122
class Record(BaseMessageComponent):
130123
type: ComponentType = ComponentType.Record
131124
file: str | None = ""
132-
magic: bool | None = False
133125
url: str | None = ""
134-
cache: bool | None = True
135-
proxy: bool | None = True
136-
timeout: int | None = 0
137126
# Original text content (e.g. TTS source text), used as caption in fallback scenarios
138127
text: str | None = None
139128
# 额外
@@ -239,7 +228,6 @@ class Video(BaseMessageComponent):
239228
type: ComponentType = ComponentType.Video
240229
file: str
241230
cover: str | None = ""
242-
c: int | None = 2
243231
# 额外
244232
path: str | None = ""
245233

@@ -416,14 +404,9 @@ class Image(BaseMessageComponent):
416404
type: ComponentType = ComponentType.Image
417405
file: str | None = ""
418406
_type: str | None = ""
419-
subType: int | None = 0
420407
url: str | None = ""
421-
cache: bool | None = True
422-
id: int | None = 40000
423-
c: int | None = 2
424408
# 额外
425409
path: str | None = ""
426-
file_unique: str | None = "" # 某些平台可能有图片缓存的唯一标识
427410

428411
def __init__(self, file: str | None, **_) -> None:
429412
super().__init__(file=file, **_)
@@ -854,16 +837,6 @@ async def to_dict(self):
854837
}
855838

856839

857-
class WechatEmoji(BaseMessageComponent):
858-
type: ComponentType = ComponentType.WechatEmoji
859-
md5: str | None = ""
860-
md5_len: int | None = 0
861-
cdnurl: str | None = ""
862-
863-
def __init__(self, **_) -> None:
864-
super().__init__(**_)
865-
866-
867840
ComponentTypes = {
868841
# Basic Message Segments
869842
"plain": Plain,
@@ -889,5 +862,4 @@ def __init__(self, **_) -> None:
889862
"nodes": Nodes,
890863
"json": Json,
891864
"unknown": Unknown,
892-
"WechatEmoji": WechatEmoji,
893865
}

astrbot/core/pipeline/preprocess_stage/stage.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from astrbot.core.pipeline.context import PipelineContext
88
from astrbot.core.pipeline.stage import Stage, register_stage
99
from astrbot.core.platform.astr_message_event import AstrMessageEvent
10+
from astrbot.core.utils.media_utils import ensure_wav
1011

1112

1213
@register_stage
@@ -62,6 +63,21 @@ async def process(
6263
logger.debug(f"路径映射: {url} -> {component.url}")
6364
message_chain[idx] = component
6465

66+
# In here, we convert all Record components to wav format and update the file path.
67+
message_chain = event.get_messages()
68+
for idx, component in enumerate(message_chain):
69+
if isinstance(component, Record):
70+
try:
71+
original_path = await component.convert_to_file_path()
72+
record_path = await ensure_wav(original_path)
73+
if record_path != original_path:
74+
event.track_temporary_local_file(record_path)
75+
component.file = record_path
76+
component.path = record_path
77+
message_chain[idx] = component
78+
except Exception as e:
79+
logger.warning(f"Voice processing failed: {e}")
80+
6581
# STT
6682
if self.stt_settings.get("enable", False):
6783
# TODO: 独立

0 commit comments

Comments
 (0)