Skip to content

Commit 3597726

Browse files
authored
fix(core): terminate active events on reset/new/del to prevent stale responses (#5225)
* fix(core): terminate active events on reset/new/del to prevent stale responses Closes #5222 * style: fix import sorting in scheduler.py
1 parent a4a37c2 commit 3597726

3 files changed

Lines changed: 74 additions & 11 deletions

File tree

astrbot/builtin_stars/builtin_commands/commands/conversation.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from astrbot.api.event import AstrMessageEvent, MessageEventResult
55
from astrbot.core.platform.astr_message_event import MessageSession
66
from astrbot.core.platform.message_type import MessageType
7+
from astrbot.core.utils.active_event_registry import active_event_registry
78

89
from .utils.rst_scene import RstScene
910

@@ -62,6 +63,7 @@ async def reset(self, message: AstrMessageEvent) -> None:
6263

6364
agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
6465
if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY:
66+
active_event_registry.stop_all(umo, exclude=message)
6567
await sp.remove_async(
6668
scope="umo",
6769
scope_id=umo,
@@ -86,6 +88,8 @@ async def reset(self, message: AstrMessageEvent) -> None:
8688
)
8789
return
8890

91+
active_event_registry.stop_all(umo, exclude=message)
92+
8993
await self.context.conversation_manager.update_conversation(
9094
umo,
9195
cid,
@@ -221,6 +225,7 @@ async def new_conv(self, message: AstrMessageEvent) -> None:
221225
cfg = self.context.get_config(umo=message.unified_msg_origin)
222226
agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
223227
if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY:
228+
active_event_registry.stop_all(message.unified_msg_origin, exclude=message)
224229
await sp.remove_async(
225230
scope="umo",
226231
scope_id=message.unified_msg_origin,
@@ -229,6 +234,7 @@ async def new_conv(self, message: AstrMessageEvent) -> None:
229234
message.set_result(MessageEventResult().message("已创建新对话。"))
230235
return
231236

237+
active_event_registry.stop_all(message.unified_msg_origin, exclude=message)
232238
cpersona = await self._get_current_persona_id(message.unified_msg_origin)
233239
cid = await self.context.conversation_manager.new_conversation(
234240
message.unified_msg_origin,
@@ -321,7 +327,8 @@ async def rename_conv(self, message: AstrMessageEvent, new_name: str = "") -> No
321327

322328
async def del_conv(self, message: AstrMessageEvent) -> None:
323329
"""删除当前对话"""
324-
cfg = self.context.get_config(umo=message.unified_msg_origin)
330+
umo = message.unified_msg_origin
331+
cfg = self.context.get_config(umo=umo)
325332
is_unique_session = cfg["platform_settings"]["unique_session"]
326333
if message.get_group_id() and not is_unique_session and message.role != "admin":
327334
# 群聊,没开独立会话,发送人不是管理员
@@ -334,18 +341,17 @@ async def del_conv(self, message: AstrMessageEvent) -> None:
334341

335342
agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
336343
if agent_runner_type in THIRD_PARTY_AGENT_RUNNER_KEY:
344+
active_event_registry.stop_all(umo, exclude=message)
337345
await sp.remove_async(
338346
scope="umo",
339-
scope_id=message.unified_msg_origin,
347+
scope_id=umo,
340348
key=THIRD_PARTY_AGENT_RUNNER_KEY[agent_runner_type],
341349
)
342350
message.set_result(MessageEventResult().message("重置对话成功。"))
343351
return
344352

345353
session_curr_cid = (
346-
await self.context.conversation_manager.get_curr_conversation_id(
347-
message.unified_msg_origin,
348-
)
354+
await self.context.conversation_manager.get_curr_conversation_id(umo)
349355
)
350356

351357
if not session_curr_cid:
@@ -356,8 +362,10 @@ async def del_conv(self, message: AstrMessageEvent) -> None:
356362
)
357363
return
358364

365+
active_event_registry.stop_all(umo, exclude=message)
366+
359367
await self.context.conversation_manager.delete_conversation(
360-
message.unified_msg_origin,
368+
umo,
361369
session_curr_cid,
362370
)
363371

astrbot/core/pipeline/scheduler.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from astrbot.core.platform.sources.wecom_ai_bot.wecomai_event import (
77
WecomAIBotMessageEvent,
88
)
9+
from astrbot.core.utils.active_event_registry import active_event_registry
910

1011
from . import STAGES_ORDER
1112
from .context import PipelineContext
@@ -79,10 +80,14 @@ async def execute(self, event: AstrMessageEvent) -> None:
7980
event (AstrMessageEvent): 事件对象
8081
8182
"""
82-
await self._process_stages(event)
83+
active_event_registry.register(event)
84+
try:
85+
await self._process_stages(event)
8386

84-
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
85-
if isinstance(event, WebChatMessageEvent | WecomAIBotMessageEvent):
86-
await event.send(None)
87+
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
88+
if isinstance(event, WebChatMessageEvent | WecomAIBotMessageEvent):
89+
await event.send(None)
8790

88-
logger.debug("pipeline 执行完毕。")
91+
logger.debug("pipeline 执行完毕。")
92+
finally:
93+
active_event_registry.unregister(event)
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from __future__ import annotations
2+
3+
from collections import defaultdict
4+
from typing import TYPE_CHECKING
5+
6+
if TYPE_CHECKING:
7+
from astrbot.core.platform import AstrMessageEvent
8+
9+
10+
class ActiveEventRegistry:
11+
"""维护 unified_msg_origin 到活跃事件的映射。
12+
13+
用于在 reset 等场景下终止该会话正在处理的事件。
14+
"""
15+
16+
def __init__(self) -> None:
17+
self._events: dict[str, set[AstrMessageEvent]] = defaultdict(set)
18+
19+
def register(self, event: AstrMessageEvent) -> None:
20+
self._events[event.unified_msg_origin].add(event)
21+
22+
def unregister(self, event: AstrMessageEvent) -> None:
23+
umo = event.unified_msg_origin
24+
self._events[umo].discard(event)
25+
if not self._events[umo]:
26+
del self._events[umo]
27+
28+
def stop_all(
29+
self,
30+
umo: str,
31+
exclude: AstrMessageEvent | None = None,
32+
) -> int:
33+
"""终止指定 UMO 的所有活跃事件。
34+
35+
Args:
36+
umo: 统一消息来源标识符。
37+
exclude: 需要排除的事件(通常是发起 reset 的事件本身)。
38+
39+
Returns:
40+
被终止的事件数量。
41+
"""
42+
count = 0
43+
for event in list(self._events.get(umo, [])):
44+
if event is not exclude:
45+
event.stop_event()
46+
count += 1
47+
return count
48+
49+
50+
active_event_registry = ActiveEventRegistry()

0 commit comments

Comments
 (0)