Skip to content

Commit 0de6d0e

Browse files
authored
Merge pull request #1256 from Raven95676/better-stream
perf: 为不支持流式输出的平台提供fallback。
2 parents 9fedaa9 + 4c546f2 commit 0de6d0e

11 files changed

Lines changed: 117 additions & 46 deletions

File tree

astrbot/core/config/default.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
"max_context_length": -1,
5454
"dequeue_context_length": 1,
5555
"streaming_response": False,
56+
"streaming_segmented": False,
5657
},
5758
"provider_stt_settings": {
5859
"enable": False,
@@ -1028,6 +1029,11 @@
10281029
"type": "bool",
10291030
"hint": "启用后,将会流式输出 LLM 的响应。目前仅支持 OpenAI API提供商 以及 Telegram、QQ Official 私聊 两个平台",
10301031
},
1032+
"streaming_segmented": {
1033+
"description": "不支持流式回复的平台分段输出",
1034+
"type": "bool",
1035+
"hint": "启用后,若平台不支持流式回复,会分段输出。目前仅支持 aiocqhttp 和 gewechat 两个平台,不支持或无需使用流式分段输出的平台会静默忽略此选项",
1036+
},
10311037
},
10321038
},
10331039
"persona": {

astrbot/core/pipeline/respond/stage.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,9 +146,12 @@ async def process(
146146

147147
if result.result_content_type == ResultContentType.STREAMING_RESULT:
148148
# 流式结果直接交付平台适配器处理
149+
use_fallback = self.config.get("provider_settings", {}).get(
150+
"streaming_segmented", False
151+
)
149152
logger.info(f"应用流式输出({event.get_platform_name()})")
150153
await event._pre_send()
151-
await event.send_streaming(result.async_stream)
154+
await event.send_streaming(result.async_stream, use_fallback)
152155
await event._post_send()
153156
return
154157
elif len(result.chain) > 0:
@@ -159,7 +162,7 @@ async def process(
159162
# 支持 File 消息段的路径映射。
160163
component.file = path_Mapping(mappings, component.file)
161164
event.get_result().chain[idx] = component
162-
165+
163166
await event._pre_send()
164167

165168
# 检查消息链是否为空

astrbot/core/platform/astr_message_event.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import abc
22
import asyncio
3+
import re
34
import hashlib
45
import uuid
56
from dataclasses import dataclass
@@ -207,9 +208,26 @@ def is_admin(self) -> bool:
207208
"""
208209
return self.role == "admin"
209210

210-
async def send_streaming(self, generator: AsyncGenerator[MessageChain, None]):
211+
async def process_buffer(self, buffer: str, pattern: re.Pattern) -> str:
212+
"""
213+
将消息缓冲区中的文本按指定正则表达式分割后发送至消息平台,作为不支持流式输出平台的Fallback。
214+
"""
215+
while True:
216+
match = re.search(pattern, buffer)
217+
if not match:
218+
break
219+
matched_text = match.group()
220+
await self.send(MessageChain([Plain(matched_text)]))
221+
buffer = buffer[match.end() :]
222+
await asyncio.sleep(1.5) # 限速
223+
return buffer
224+
225+
async def send_streaming(
226+
self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
227+
):
211228
"""发送流式消息到消息平台,使用异步生成器。
212229
目前仅支持: telegram,qq official 私聊。
230+
Fallback仅支持 aiocqhttp, gewechat。
213231
"""
214232
asyncio.create_task(
215233
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)

astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import asyncio
2-
import typing
2+
import re
3+
from typing import AsyncGenerator, Dict, List
4+
from aiocqhttp import CQHttp
35
from astrbot.api.event import AstrMessageEvent, MessageChain
6+
from astrbot.api.message_components import At, Image, Node, Nodes, Plain, Record
47
from astrbot.api.platform import Group, MessageMember
5-
from astrbot.api.message_components import Plain, Image, Record, At, Node, Nodes
6-
from aiocqhttp import CQHttp
78

89

910
class AiocqhttpMessageEvent(AstrMessageEvent):
@@ -82,18 +83,39 @@ async def send(self, message: MessageChain):
8283

8384
await super().send(message)
8485

85-
async def send_streaming(self, generator):
86-
buffer = None
87-
async for chain in generator:
86+
async def send_streaming(
87+
self, generator: AsyncGenerator, use_fallback: bool = False
88+
):
89+
if not use_fallback:
90+
buffer = None
91+
async for chain in generator:
92+
if not buffer:
93+
buffer = chain
94+
else:
95+
buffer.chain.extend(chain.chain)
8896
if not buffer:
89-
buffer = chain
90-
else:
91-
buffer.chain.extend(chain.chain)
92-
if not buffer:
93-
return
94-
buffer.squash_plain()
95-
await self.send(buffer)
96-
return await super().send_streaming(generator)
97+
return
98+
buffer.squash_plain()
99+
await self.send(buffer)
100+
return await super().send_streaming(generator, use_fallback)
101+
102+
buffer = ""
103+
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
104+
105+
async for chain in generator:
106+
if isinstance(chain, MessageChain):
107+
for comp in chain.chain:
108+
if isinstance(comp, Plain):
109+
buffer += comp.text
110+
if any(p in buffer for p in "。?!~…"):
111+
buffer = await self.process_buffer(buffer, pattern)
112+
else:
113+
await self.send(MessageChain(chain=[comp]))
114+
await asyncio.sleep(1.5) # 限速
115+
116+
if buffer.strip():
117+
await self.send(MessageChain([Plain(buffer)]))
118+
return await super().send_streaming(generator, use_fallback)
97119

98120
async def get_group(self, group_id=None, **kwargs):
99121
if isinstance(group_id, str) and group_id.isdigit():
@@ -108,7 +130,7 @@ async def get_group(self, group_id=None, **kwargs):
108130
group_id=group_id,
109131
)
110132

111-
members: typing.List[typing.Dict] = await self.bot.call_action(
133+
members: List[Dict] = await self.bot.call_action(
112134
"get_group_member_list",
113135
group_id=group_id,
114136
)

astrbot/core/platform/sources/dingtalk/dingtalk_event.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ async def send(self, message: MessageChain):
6161
await self.send_with_client(self.client, message)
6262
await super().send(message)
6363

64-
async def send_streaming(self, generator):
64+
async def send_streaming(self, generator, use_fallback: bool = False):
6565
buffer = None
6666
async for chain in generator:
6767
if not buffer:
@@ -72,4 +72,4 @@ async def send_streaming(self, generator):
7272
return
7373
buffer.squash_plain()
7474
await self.send(buffer)
75-
return await super().send_streaming(generator)
75+
return await super().send_streaming(generator, use_fallback)

astrbot/core/platform/sources/gewechat/gewechat_event.py

Lines changed: 35 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
import asyncio
2+
import re
13
import wave
24
import uuid
35
import traceback
46
import os
57

8+
from typing import AsyncGenerator
69
from astrbot.core.utils.io import save_temp_img, download_file
710
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
811
from astrbot.api import logger
@@ -217,15 +220,36 @@ async def get_group(self, group_id=None, **kwargs):
217220
members=members,
218221
)
219222

220-
async def send_streaming(self, generator):
221-
buffer = None
222-
async for chain in generator:
223+
async def send_streaming(
224+
self, generator: AsyncGenerator, use_fallback: bool = False
225+
):
226+
if not use_fallback:
227+
buffer = None
228+
async for chain in generator:
229+
if not buffer:
230+
buffer = chain
231+
else:
232+
buffer.chain.extend(chain.chain)
223233
if not buffer:
224-
buffer = chain
225-
else:
226-
buffer.chain.extend(chain.chain)
227-
if not buffer:
228-
return
229-
buffer.squash_plain()
230-
await self.send(buffer)
231-
return await super().send_streaming(generator)
234+
return
235+
buffer.squash_plain()
236+
await self.send(buffer)
237+
return await super().send_streaming(generator, use_fallback)
238+
239+
buffer = ""
240+
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
241+
242+
async for chain in generator:
243+
if isinstance(chain, MessageChain):
244+
for comp in chain.chain:
245+
if isinstance(comp, Plain):
246+
buffer += comp.text
247+
if any(p in buffer for p in "。?!~…"):
248+
buffer = await self.process_buffer(buffer, pattern)
249+
else:
250+
await self.send(MessageChain(chain=[comp]))
251+
await asyncio.sleep(1.5) # 限速
252+
253+
if buffer.strip():
254+
await self.send(MessageChain([Plain(buffer)]))
255+
return await super().send_streaming(generator, use_fallback)

astrbot/core/platform/sources/lark/lark_event.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ async def send(self, message: MessageChain):
104104

105105
await super().send(message)
106106

107-
async def send_streaming(self, generator):
107+
async def send_streaming(self, generator, use_fallback: bool = False):
108108
buffer = None
109109
async for chain in generator:
110110
if not buffer:
@@ -115,4 +115,4 @@ async def send_streaming(self, generator):
115115
return
116116
buffer.squash_plain()
117117
await self.send(buffer)
118-
return await super().send_streaming(generator)
118+
return await super().send_streaming(generator, use_fallback)

astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ async def send(self, message: MessageChain):
3333
else:
3434
self.send_buffer.chain.extend(message.chain)
3535

36-
async def send_streaming(self, generator):
36+
async def send_streaming(self, generator, use_fallback: bool = False):
3737
"""流式输出仅支持消息列表私聊"""
3838
stream_payload = {"state": 1, "id": None, "index": 0, "reset": False}
3939
last_edit_time = 0 # 上次编辑消息的时间
@@ -66,7 +66,7 @@ async def send_streaming(self, generator):
6666
logger.error(f"发送流式消息时出错: {e}", exc_info=True)
6767
self.send_buffer = None
6868

69-
return await super().send_streaming(generator)
69+
return await super().send_streaming(generator, use_fallback)
7070

7171
async def _post_send(self, stream: dict = None):
7272
if not self.send_buffer:
@@ -97,7 +97,7 @@ async def _post_send(self, stream: dict = None):
9797
"msg_id": self.message_obj.message_id,
9898
}
9999

100-
if not isinstance(source, (botpy.message.Message,botpy.message.DirectMessage)):
100+
if not isinstance(source, (botpy.message.Message, botpy.message.DirectMessage)):
101101
payload["msg_seq"] = random.randint(1, 10000)
102102

103103
match type(source):

astrbot/core/platform/sources/telegram/tg_event.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ async def send(self, message: MessageChain):
9191
await self.send_with_client(self.client, message, self.get_sender_id())
9292
await super().send(message)
9393

94-
async def send_streaming(self, generator):
94+
async def send_streaming(self, generator, use_fallback: bool = False):
9595
message_thread_id = None
9696

9797
if self.get_message_type() == MessageType.GROUP_MESSAGE:
@@ -183,16 +183,14 @@ async def send_streaming(self, generator):
183183
text=markdown_text,
184184
chat_id=payload["chat_id"],
185185
message_id=message_id,
186-
parse_mode="MarkdownV2"
186+
parse_mode="MarkdownV2",
187187
)
188188
except Exception as e:
189189
logger.warning(f"Markdown转换失败,使用普通文本: {e!s}")
190190
await self.client.edit_message_text(
191-
text=delta,
192-
chat_id=payload["chat_id"],
193-
message_id=message_id
191+
text=delta, chat_id=payload["chat_id"], message_id=message_id
194192
)
195193
except Exception as e:
196194
logger.warning(f"编辑消息失败(streaming): {e!s}")
197195

198-
return await super().send_streaming(generator)
196+
return await super().send_streaming(generator, use_fallback)

astrbot/core/platform/sources/webchat/webchat_event.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ async def send(self, message: MessageChain):
106106
)
107107
await super().send(message)
108108

109-
async def send_streaming(self, generator):
109+
async def send_streaming(self, generator, use_fallback: bool = False):
110110
final_data = ""
111111
async for chain in generator:
112112
final_data += await WebChatMessageEvent._send(
@@ -121,4 +121,4 @@ async def send_streaming(self, generator):
121121
"cid": self.session_id.split("!")[-1],
122122
}
123123
)
124-
await super().send_streaming(generator)
124+
await super().send_streaming(generator, use_fallback)

0 commit comments

Comments
 (0)