Skip to content

Commit 4e29684

Browse files
authored
fix: add plugin set and knowledge bases selection in custom rules page (#3813)
fixes: #3806
1 parent 0e17e35 commit 4e29684

7 files changed

Lines changed: 274 additions & 347 deletions

File tree

astrbot/core/star/session_llm_manager.py

Lines changed: 0 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -171,110 +171,3 @@ def is_session_enabled(session_id: str) -> bool:
171171

172172
# 如果没有配置,默认为启用(兼容性考虑)
173173
return True
174-
175-
@staticmethod
176-
def set_session_status(session_id: str, enabled: bool) -> None:
177-
"""设置会话的整体启停状态
178-
179-
Args:
180-
session_id: 会话ID (unified_msg_origin)
181-
enabled: True表示启用,False表示禁用
182-
183-
"""
184-
session_config = (
185-
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
186-
)
187-
session_config["session_enabled"] = enabled
188-
sp.put(
189-
"session_service_config",
190-
session_config,
191-
scope="umo",
192-
scope_id=session_id,
193-
)
194-
195-
logger.info(
196-
f"会话 {session_id} 的整体状态已更新为: {'启用' if enabled else '禁用'}",
197-
)
198-
199-
@staticmethod
200-
def should_process_session_request(event: AstrMessageEvent) -> bool:
201-
"""检查是否应该处理会话请求(会话整体启停检查)
202-
203-
Args:
204-
event: 消息事件
205-
206-
Returns:
207-
bool: True表示应该处理,False表示跳过
208-
209-
"""
210-
session_id = event.unified_msg_origin
211-
return SessionServiceManager.is_session_enabled(session_id)
212-
213-
# =============================================================================
214-
# 会话命名相关方法
215-
# =============================================================================
216-
217-
@staticmethod
218-
def get_session_custom_name(session_id: str) -> str | None:
219-
"""获取会话的自定义名称
220-
221-
Args:
222-
session_id: 会话ID (unified_msg_origin)
223-
224-
Returns:
225-
str: 自定义名称,如果没有设置则返回None
226-
227-
"""
228-
session_services = sp.get(
229-
"session_service_config",
230-
{},
231-
scope="umo",
232-
scope_id=session_id,
233-
)
234-
return session_services.get("custom_name")
235-
236-
@staticmethod
237-
def set_session_custom_name(session_id: str, custom_name: str) -> None:
238-
"""设置会话的自定义名称
239-
240-
Args:
241-
session_id: 会话ID (unified_msg_origin)
242-
custom_name: 自定义名称,可以为空字符串来清除名称
243-
244-
"""
245-
session_config = (
246-
sp.get("session_service_config", {}, scope="umo", scope_id=session_id) or {}
247-
)
248-
if custom_name and custom_name.strip():
249-
session_config["custom_name"] = custom_name.strip()
250-
else:
251-
# 如果传入空名称,则删除自定义名称
252-
session_config.pop("custom_name", None)
253-
sp.put(
254-
"session_service_config",
255-
session_config,
256-
scope="umo",
257-
scope_id=session_id,
258-
)
259-
260-
logger.info(
261-
f"会话 {session_id} 的自定义名称已更新为: {custom_name.strip() if custom_name and custom_name.strip() else '已清除'}",
262-
)
263-
264-
@staticmethod
265-
def get_session_display_name(session_id: str) -> str:
266-
"""获取会话的显示名称(优先显示自定义名称,否则显示原始session_id的最后一段)
267-
268-
Args:
269-
session_id: 会话ID (unified_msg_origin)
270-
271-
Returns:
272-
str: 显示名称
273-
274-
"""
275-
custom_name = SessionServiceManager.get_session_custom_name(session_id)
276-
if custom_name:
277-
return custom_name
278-
279-
# 如果没有自定义名称,返回session_id的最后一段
280-
return session_id.split(":")[2] if session_id.count(":") >= 2 else session_id

astrbot/core/star/session_plugin_manager.py

Lines changed: 0 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -42,87 +42,6 @@ def is_plugin_enabled_for_session(session_id: str, plugin_name: str) -> bool:
4242
# 如果都没有配置,默认为启用(兼容性考虑)
4343
return True
4444

45-
@staticmethod
46-
def set_plugin_status_for_session(
47-
session_id: str,
48-
plugin_name: str,
49-
enabled: bool,
50-
) -> None:
51-
"""设置插件在指定会话中的启停状态
52-
53-
Args:
54-
session_id: 会话ID (unified_msg_origin)
55-
plugin_name: 插件名称
56-
enabled: True表示启用,False表示禁用
57-
58-
"""
59-
# 获取当前配置
60-
session_plugin_config = sp.get(
61-
"session_plugin_config",
62-
{},
63-
scope="umo",
64-
scope_id=session_id,
65-
)
66-
if session_id not in session_plugin_config:
67-
session_plugin_config[session_id] = {
68-
"enabled_plugins": [],
69-
"disabled_plugins": [],
70-
}
71-
72-
session_config = session_plugin_config[session_id]
73-
enabled_plugins = session_config.get("enabled_plugins", [])
74-
disabled_plugins = session_config.get("disabled_plugins", [])
75-
76-
if enabled:
77-
# 启用插件
78-
if plugin_name in disabled_plugins:
79-
disabled_plugins.remove(plugin_name)
80-
if plugin_name not in enabled_plugins:
81-
enabled_plugins.append(plugin_name)
82-
else:
83-
# 禁用插件
84-
if plugin_name in enabled_plugins:
85-
enabled_plugins.remove(plugin_name)
86-
if plugin_name not in disabled_plugins:
87-
disabled_plugins.append(plugin_name)
88-
89-
# 保存配置
90-
session_config["enabled_plugins"] = enabled_plugins
91-
session_config["disabled_plugins"] = disabled_plugins
92-
session_plugin_config[session_id] = session_config
93-
sp.put(
94-
"session_plugin_config",
95-
session_plugin_config,
96-
scope="umo",
97-
scope_id=session_id,
98-
)
99-
100-
logger.info(
101-
f"会话 {session_id} 的插件 {plugin_name} 状态已更新为: {'启用' if enabled else '禁用'}",
102-
)
103-
104-
@staticmethod
105-
def get_session_plugin_config(session_id: str) -> dict[str, list[str]]:
106-
"""获取指定会话的插件配置
107-
108-
Args:
109-
session_id: 会话ID (unified_msg_origin)
110-
111-
Returns:
112-
Dict[str, List[str]]: 包含enabled_plugins和disabled_plugins的字典
113-
114-
"""
115-
session_plugin_config = sp.get(
116-
"session_plugin_config",
117-
{},
118-
scope="umo",
119-
scope_id=session_id,
120-
)
121-
return session_plugin_config.get(
122-
session_id,
123-
{"enabled_plugins": [], "disabled_plugins": []},
124-
)
125-
12645
@staticmethod
12746
def filter_handlers_by_session(event: AstrMessageEvent, handlers: list) -> list:
12847
"""根据会话配置过滤处理器列表

astrbot/dashboard/routes/knowledge_base.py

Lines changed: 0 additions & 156 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,6 @@ def __init__(
6060
# "/kb/media/delete": ("POST", self.delete_media),
6161
# 检索
6262
"/kb/retrieve": ("POST", self.retrieve),
63-
# 会话知识库配置
64-
"/kb/session/config/get": ("GET", self.get_session_kb_config),
65-
"/kb/session/config/set": ("POST", self.set_session_kb_config),
66-
"/kb/session/config/delete": ("POST", self.delete_session_kb_config),
6763
}
6864
self.register_routes()
6965

@@ -920,158 +916,6 @@ async def retrieve(self):
920916
logger.error(traceback.format_exc())
921917
return Response().error(f"检索失败: {e!s}").__dict__
922918

923-
# ===== 会话知识库配置 API =====
924-
925-
async def get_session_kb_config(self):
926-
"""获取会话的知识库配置
927-
928-
Query 参数:
929-
- session_id: 会话 ID (必填)
930-
931-
返回:
932-
- kb_ids: 知识库 ID 列表
933-
- top_k: 返回结果数量
934-
- enable_rerank: 是否启用重排序
935-
"""
936-
try:
937-
from astrbot.core import sp
938-
939-
session_id = request.args.get("session_id")
940-
941-
if not session_id:
942-
return Response().error("缺少参数 session_id").__dict__
943-
944-
# 从 SharedPreferences 获取配置
945-
config = await sp.session_get(session_id, "kb_config", default={})
946-
947-
logger.debug(f"[KB配置] 读取到配置: session_id={session_id}")
948-
949-
# 如果没有配置,返回默认值
950-
if not config:
951-
config = {"kb_ids": [], "top_k": 5, "enable_rerank": True}
952-
953-
return Response().ok(config).__dict__
954-
955-
except Exception as e:
956-
logger.error(f"[KB配置] 获取配置时出错: {e}", exc_info=True)
957-
return Response().error(f"获取会话知识库配置失败: {e!s}").__dict__
958-
959-
async def set_session_kb_config(self):
960-
"""设置会话的知识库配置
961-
962-
Body:
963-
- scope: 配置范围 (目前只支持 "session")
964-
- scope_id: 会话 ID (必填)
965-
- kb_ids: 知识库 ID 列表 (必填)
966-
- top_k: 返回结果数量 (可选, 默认 5)
967-
- enable_rerank: 是否启用重排序 (可选, 默认 true)
968-
"""
969-
try:
970-
from astrbot.core import sp
971-
972-
data = await request.json
973-
974-
scope = data.get("scope")
975-
scope_id = data.get("scope_id")
976-
kb_ids = data.get("kb_ids", [])
977-
top_k = data.get("top_k", 5)
978-
enable_rerank = data.get("enable_rerank", True)
979-
980-
# 验证参数
981-
if scope != "session":
982-
return Response().error("目前仅支持 session 范围的配置").__dict__
983-
984-
if not scope_id:
985-
return Response().error("缺少参数 scope_id").__dict__
986-
987-
if not isinstance(kb_ids, list):
988-
return Response().error("kb_ids 必须是列表").__dict__
989-
990-
# 验证知识库是否存在
991-
kb_mgr = self._get_kb_manager()
992-
invalid_ids = []
993-
valid_ids = []
994-
for kb_id in kb_ids:
995-
kb_helper = await kb_mgr.get_kb(kb_id)
996-
if kb_helper:
997-
valid_ids.append(kb_id)
998-
else:
999-
invalid_ids.append(kb_id)
1000-
logger.warning(f"[KB配置] 知识库不存在: {kb_id}")
1001-
1002-
if invalid_ids:
1003-
logger.warning(f"[KB配置] 以下知识库ID无效: {invalid_ids}")
1004-
1005-
# 允许保存空列表,表示明确不使用任何知识库
1006-
if kb_ids and not valid_ids:
1007-
# 只有当用户提供了 kb_ids 但全部无效时才报错
1008-
return Response().error(f"所有提供的知识库ID都无效: {kb_ids}").__dict__
1009-
1010-
# 如果 kb_ids 为空列表,表示用户想清空配置
1011-
if not kb_ids:
1012-
valid_ids = []
1013-
1014-
# 构建配置对象(只保存有效的ID)
1015-
config = {
1016-
"kb_ids": valid_ids,
1017-
"top_k": top_k,
1018-
"enable_rerank": enable_rerank,
1019-
}
1020-
1021-
# 保存到 SharedPreferences
1022-
await sp.session_put(scope_id, "kb_config", config)
1023-
1024-
# 立即验证是否保存成功
1025-
verify_config = await sp.session_get(scope_id, "kb_config", default={})
1026-
1027-
if verify_config == config:
1028-
return (
1029-
Response()
1030-
.ok(
1031-
{"valid_ids": valid_ids, "invalid_ids": invalid_ids},
1032-
"保存知识库配置成功",
1033-
)
1034-
.__dict__
1035-
)
1036-
logger.error("[KB配置] 配置保存失败,验证不匹配")
1037-
return Response().error("配置保存失败").__dict__
1038-
1039-
except Exception as e:
1040-
logger.error(f"[KB配置] 设置配置时出错: {e}", exc_info=True)
1041-
return Response().error(f"设置会话知识库配置失败: {e!s}").__dict__
1042-
1043-
async def delete_session_kb_config(self):
1044-
"""删除会话的知识库配置
1045-
1046-
Body:
1047-
- scope: 配置范围 (目前只支持 "session")
1048-
- scope_id: 会话 ID (必填)
1049-
"""
1050-
try:
1051-
from astrbot.core import sp
1052-
1053-
data = await request.json
1054-
1055-
scope = data.get("scope")
1056-
scope_id = data.get("scope_id")
1057-
1058-
# 验证参数
1059-
if scope != "session":
1060-
return Response().error("目前仅支持 session 范围的配置").__dict__
1061-
1062-
if not scope_id:
1063-
return Response().error("缺少参数 scope_id").__dict__
1064-
1065-
# 从 SharedPreferences 删除配置
1066-
await sp.session_remove(scope_id, "kb_config")
1067-
1068-
return Response().ok(message="删除知识库配置成功").__dict__
1069-
1070-
except Exception as e:
1071-
logger.error(f"删除会话知识库配置失败: {e}")
1072-
logger.error(traceback.format_exc())
1073-
return Response().error(f"删除会话知识库配置失败: {e!s}").__dict__
1074-
1075919
async def upload_document_from_url(self):
1076920
"""从 URL 上传文档
1077921

0 commit comments

Comments
 (0)