Skip to content

Commit a8f874b

Browse files
authored
fix: 修复分段回复时,引用消息单独发送导致第一条消息内容为空的问题 (#2757)
1 parent 9d9917e commit a8f874b

1 file changed

Lines changed: 103 additions & 77 deletions

File tree

astrbot/core/pipeline/respond/stage.py

Lines changed: 103 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
import random
22
import asyncio
33
import math
4-
import traceback
54
import astrbot.core.message.components as Comp
65
from typing import Union, AsyncGenerator
76
from ..stage import register_stage, Stage
8-
from ..context import PipelineContext
7+
from ..context import PipelineContext, call_event_hook
98
from astrbot.core.platform.astr_message_event import AstrMessageEvent
109
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
1110
from astrbot.core import logger
12-
from astrbot.core.message.message_event_result import BaseMessageComponent
13-
from astrbot.core.star.star_handler import star_handlers_registry, EventType
14-
from astrbot.core.star.star import star_map
11+
from astrbot.core.message.components import BaseMessageComponent, ComponentType
12+
from astrbot.core.star.star_handler import EventType
1513
from astrbot.core.utils.path_util import path_Mapping
1614
from astrbot.core.utils.session_lock import session_lock_manager
1715

@@ -114,6 +112,43 @@ async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]):
114112
# 如果所有组件都为空
115113
return True
116114

115+
def is_seg_reply_required(self, event: AstrMessageEvent) -> bool:
116+
"""检查是否需要分段回复"""
117+
if not self.enable_seg:
118+
return False
119+
120+
if self.only_llm_result and not event.get_result().is_llm_result():
121+
return False
122+
123+
if event.get_platform_name() in [
124+
"qq_official",
125+
"weixin_official_account",
126+
"dingtalk",
127+
]:
128+
return False
129+
130+
return True
131+
132+
def _extract_comp(
133+
self,
134+
raw_chain: list[BaseMessageComponent],
135+
extract_types: set[ComponentType],
136+
modify_raw_chain: bool = True,
137+
):
138+
extracted = []
139+
if modify_raw_chain:
140+
remaining = []
141+
for comp in raw_chain:
142+
if comp.type in extract_types:
143+
extracted.append(comp)
144+
else:
145+
remaining.append(comp)
146+
raw_chain[:] = remaining
147+
else:
148+
extracted = [comp for comp in raw_chain if comp.type in extract_types]
149+
150+
return extracted
151+
117152
async def process(
118153
self, event: AstrMessageEvent
119154
) -> Union[None, AsyncGenerator[None, None]]:
@@ -123,7 +158,14 @@ async def process(
123158
if result.result_content_type == ResultContentType.STREAMING_FINISH:
124159
return
125160

161+
logger.info(
162+
f"Prepare to send - {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
163+
)
164+
126165
if result.result_content_type == ResultContentType.STREAMING_RESULT:
166+
if result.async_stream is None:
167+
logger.warning("async_stream 为空,跳过发送。")
168+
return
127169
# 流式结果直接交付平台适配器处理
128170
use_fallback = self.config.get("provider_settings", {}).get(
129171
"streaming_segmented", False
@@ -148,87 +190,71 @@ async def process(
148190
except Exception as e:
149191
logger.warning(f"空内容检查异常: {e}")
150192

151-
record_comps = [c for c in result.chain if isinstance(c, Comp.Record)]
152-
non_record_comps = [
153-
c for c in result.chain if not isinstance(c, Comp.Record)
154-
]
155-
156-
if (
157-
self.enable_seg
158-
and (
159-
(self.only_llm_result and result.is_llm_result())
160-
or not self.only_llm_result
193+
# 发送消息链
194+
# Record 需要强制单独发送
195+
need_separately = {ComponentType.Record}
196+
if self.is_seg_reply_required(event):
197+
header_comps = self._extract_comp(
198+
result.chain,
199+
{ComponentType.Reply, ComponentType.At},
200+
modify_raw_chain=True,
161201
)
162-
and event.get_platform_name()
163-
not in ["qq_official", "weixin_official_account", "dingtalk"]
164-
):
165-
decorated_comps = []
166-
if self.reply_with_mention:
167-
for comp in result.chain:
168-
if isinstance(comp, Comp.At):
169-
decorated_comps.append(comp)
170-
result.chain.remove(comp)
171-
break
172-
if self.reply_with_quote:
173-
for comp in result.chain:
174-
if isinstance(comp, Comp.Reply):
175-
decorated_comps.append(comp)
176-
result.chain.remove(comp)
177-
break
178-
179-
# leverage lock to guarentee the order of message sending among different events
202+
if not result.chain or len(result.chain) == 0:
203+
# may fix #2670
204+
logger.warning(
205+
f"实际消息链为空, 跳过发送阶段。header_chain: {header_comps}, actual_chain: {result.chain}"
206+
)
207+
return
180208
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
181-
for rcomp in record_comps:
182-
i = await self._calc_comp_interval(rcomp)
183-
await asyncio.sleep(i)
184-
try:
185-
await event.send(MessageChain([rcomp]))
186-
except Exception as e:
187-
logger.error(f"发送消息失败: {e} chain: {result.chain}")
188-
break
189-
# 分段回复
190-
for comp in non_record_comps:
209+
for comp in result.chain:
191210
i = await self._calc_comp_interval(comp)
192211
await asyncio.sleep(i)
193212
try:
194-
await event.send(MessageChain([*decorated_comps, comp]))
195-
decorated_comps = [] # 清空已发送的装饰组件
213+
if comp.type in need_separately:
214+
await event.send(MessageChain([comp]))
215+
else:
216+
await event.send(MessageChain([*header_comps, comp]))
217+
header_comps.clear()
196218
except Exception as e:
197-
logger.error(f"发送消息失败: {e} chain: {result.chain}")
198-
break
219+
logger.error(
220+
f"发送消息链失败: chain = {MessageChain([comp])}, error = {e}",
221+
exc_info=True,
222+
)
199223
else:
200-
for rcomp in record_comps:
224+
if all(
225+
comp.type in {ComponentType.Reply, ComponentType.At}
226+
for comp in result.chain
227+
):
228+
# may fix #2670
229+
logger.warning(
230+
f"消息链全为 Reply 和 At 消息段, 跳过发送阶段。chain: {result.chain}"
231+
)
232+
return
233+
sep_comps = self._extract_comp(
234+
result.chain,
235+
need_separately,
236+
modify_raw_chain=True,
237+
)
238+
for comp in sep_comps:
239+
chain = MessageChain([comp])
201240
try:
202-
await event.send(MessageChain([rcomp]))
241+
await event.send(chain)
203242
except Exception as e:
204-
logger.error(f"发送消息失败: {e} chain: {result.chain}")
205-
206-
try:
207-
await event.send(MessageChain(non_record_comps))
208-
except Exception as e:
209-
logger.error(traceback.format_exc())
210-
logger.error(f"发送消息失败: {e} chain: {result.chain}")
211-
212-
logger.info(
213-
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
214-
)
215-
216-
handlers = star_handlers_registry.get_handlers_by_event_type(
217-
EventType.OnAfterMessageSentEvent, plugins_name=event.plugins_name
218-
)
219-
for handler in handlers:
220-
try:
221-
logger.debug(
222-
f"hook(on_after_message_sent) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
223-
)
224-
await handler.handler(event)
225-
except BaseException:
226-
logger.error(traceback.format_exc())
243+
logger.error(
244+
f"发送消息链失败: chain = {chain}, error = {e}",
245+
exc_info=True,
246+
)
247+
chain = MessageChain(result.chain)
248+
if result.chain and len(result.chain) > 0:
249+
try:
250+
await event.send(chain)
251+
except Exception as e:
252+
logger.error(
253+
f"发送消息链失败: chain = {chain}, error = {e}",
254+
exc_info=True,
255+
)
227256

228-
if event.is_stopped():
229-
logger.info(
230-
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
231-
)
232-
return
257+
if await call_event_hook(event, EventType.OnAfterMessageSentEvent):
258+
return
233259

234260
event.clear_result()

0 commit comments

Comments
 (0)