Skip to content

Commit 5ce1775

Browse files
committed
fix: clear deleted knowledge bases from sessions
1 parent c4693fa commit 5ce1775

2 files changed

Lines changed: 119 additions & 1 deletion

File tree

astrbot/dashboard/routes/knowledge_base.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import aiofiles
1010
from quart import request
1111

12-
from astrbot.core import logger
12+
from astrbot.core import logger, sp
1313
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
1414
from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
1515
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
@@ -69,6 +69,36 @@ def __init__(
6969
def _get_kb_manager(self):
7070
return self.core_lifecycle.kb_manager
7171

72+
@staticmethod
73+
async def _remove_kb_from_session_configs(kb_id: str) -> int:
74+
prefs = await sp.session_get(None, "kb_config")
75+
if not isinstance(prefs, list):
76+
return 0
77+
78+
updated = 0
79+
80+
for pref in prefs:
81+
scope_id = getattr(pref, "scope_id", None)
82+
if not isinstance(scope_id, str):
83+
continue
84+
85+
value = await sp.session_get(scope_id, "kb_config")
86+
if not isinstance(value, dict):
87+
continue
88+
89+
kb_ids = value.get("kb_ids")
90+
if not isinstance(kb_ids, list) or kb_id not in kb_ids:
91+
continue
92+
93+
new_value = {
94+
**value,
95+
"kb_ids": [item for item in kb_ids if item != kb_id],
96+
}
97+
await sp.session_put(scope_id, "kb_config", new_value)
98+
updated += 1
99+
100+
return updated
101+
72102
def _init_task(self, task_id: str, status: str = "pending") -> None:
73103
self.upload_tasks[task_id] = {
74104
"status": status,
@@ -569,6 +599,12 @@ async def delete_kb(self):
569599
if not success:
570600
return Response().error("知识库不存在").__dict__
571601

602+
updated_sessions = await self._remove_kb_from_session_configs(kb_id)
603+
if updated_sessions:
604+
logger.info(
605+
f"已从 {updated_sessions} 个会话配置中移除已删除知识库 {kb_id}",
606+
)
607+
572608
return Response().ok(message="删除知识库成功").__dict__
573609

574610
except ValueError as e:

tests/test_kb_import.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import asyncio
2+
from types import SimpleNamespace
23
from unittest.mock import AsyncMock, MagicMock
34

45
import pytest
56
import pytest_asyncio
67
from quart import Quart
78

9+
import astrbot.dashboard.routes.knowledge_base as knowledge_base_route_module
810
from astrbot.core import LogBroker
911
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
1012
from astrbot.core.db.sqlite import SQLiteDatabase
@@ -117,6 +119,86 @@ async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecyc
117119
return {"Authorization": f"Bearer {token}"}
118120

119121

122+
@pytest.mark.asyncio
123+
async def test_remove_deleted_kb_from_session_configs(monkeypatch: pytest.MonkeyPatch):
124+
updates = []
125+
configs = {
126+
"platform:GroupMessage:group!alice": {
127+
"kb_ids": ["kb-old", "kb-keep"],
128+
"top_k": 3,
129+
},
130+
"platform:GroupMessage:group!bob": {"kb_ids": ["kb-old"]},
131+
"platform:FriendMessage:charlie": {"kb_ids": ["kb-keep"]},
132+
"platform:FriendMessage:broken": {"kb_ids": "kb-old"},
133+
}
134+
135+
class FakeSharedPreferences:
136+
async def session_get(self, umo, key):
137+
assert key == "kb_config"
138+
if umo is not None:
139+
return configs.get(umo)
140+
141+
return [
142+
SimpleNamespace(
143+
scope_id="platform:GroupMessage:group!alice",
144+
value={"val": configs["platform:GroupMessage:group!alice"]},
145+
),
146+
SimpleNamespace(
147+
scope_id="platform:GroupMessage:group!bob",
148+
value={"val": configs["platform:GroupMessage:group!bob"]},
149+
),
150+
SimpleNamespace(
151+
scope_id="platform:FriendMessage:charlie",
152+
value={"val": configs["platform:FriendMessage:charlie"]},
153+
),
154+
SimpleNamespace(
155+
scope_id="platform:FriendMessage:broken",
156+
value={"val": configs["platform:FriendMessage:broken"]},
157+
),
158+
]
159+
160+
async def session_put(self, umo, key, value):
161+
updates.append((umo, key, value))
162+
163+
monkeypatch.setattr(knowledge_base_route_module, "sp", FakeSharedPreferences())
164+
165+
updated = await KnowledgeBaseRoute._remove_kb_from_session_configs("kb-old")
166+
167+
assert updated == 2
168+
assert updates == [
169+
(
170+
"platform:GroupMessage:group!alice",
171+
"kb_config",
172+
{"kb_ids": ["kb-keep"], "top_k": 3},
173+
),
174+
(
175+
"platform:GroupMessage:group!bob",
176+
"kb_config",
177+
{"kb_ids": []},
178+
),
179+
]
180+
181+
182+
@pytest.mark.asyncio
183+
async def test_remove_deleted_kb_ignores_missing_session_list(
184+
monkeypatch: pytest.MonkeyPatch,
185+
):
186+
class FakeSharedPreferences:
187+
async def session_get(self, umo, key):
188+
assert umo is None
189+
assert key == "kb_config"
190+
return None
191+
192+
async def session_put(self, umo, key, value):
193+
raise AssertionError("session_put should not be called")
194+
195+
monkeypatch.setattr(knowledge_base_route_module, "sp", FakeSharedPreferences())
196+
197+
updated = await KnowledgeBaseRoute._remove_kb_from_session_configs("kb-old")
198+
199+
assert updated == 0
200+
201+
120202
@pytest.mark.asyncio
121203
async def test_import_documents(
122204
app: Quart, authenticated_header: dict, core_lifecycle_td: AstrBotCoreLifecycle

0 commit comments

Comments
 (0)