Skip to content

Commit c501840

Browse files
committed
fix: clear deleted knowledge bases from sessions
1 parent c4693fa commit c501840

2 files changed

Lines changed: 81 additions & 1 deletion

File tree

astrbot/dashboard/routes/knowledge_base.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import aiofiles
1010
from quart import request
1111

12-
from astrbot.core import logger
12+
from astrbot.core import logger, sp
1313
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
1414
from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
1515
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
@@ -69,6 +69,29 @@ def __init__(
6969
def _get_kb_manager(self):
7070
return self.core_lifecycle.kb_manager
7171

72+
@staticmethod
73+
async def _remove_kb_from_session_configs(kb_id: str) -> int:
74+
prefs = await sp.session_get(None, "kb_config")
75+
updated = 0
76+
77+
for pref in prefs:
78+
value = pref.value.get("val") if isinstance(pref.value, dict) else None
79+
if not isinstance(value, dict):
80+
continue
81+
82+
kb_ids = value.get("kb_ids")
83+
if not isinstance(kb_ids, list) or kb_id not in kb_ids:
84+
continue
85+
86+
new_value = {
87+
**value,
88+
"kb_ids": [item for item in kb_ids if item != kb_id],
89+
}
90+
await sp.session_put(pref.scope_id, "kb_config", new_value)
91+
updated += 1
92+
93+
return updated
94+
7295
def _init_task(self, task_id: str, status: str = "pending") -> None:
7396
self.upload_tasks[task_id] = {
7497
"status": status,
@@ -569,6 +592,12 @@ async def delete_kb(self):
569592
if not success:
570593
return Response().error("知识库不存在").__dict__
571594

595+
updated_sessions = await self._remove_kb_from_session_configs(kb_id)
596+
if updated_sessions:
597+
logger.info(
598+
f"已从 {updated_sessions} 个会话配置中移除已删除知识库 {kb_id}",
599+
)
600+
572601
return Response().ok(message="删除知识库成功").__dict__
573602

574603
except ValueError as e:

tests/test_kb_import.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import asyncio
2+
from types import SimpleNamespace
23
from unittest.mock import AsyncMock, MagicMock
34

45
import pytest
56
import pytest_asyncio
67
from quart import Quart
78

9+
import astrbot.dashboard.routes.knowledge_base as knowledge_base_route_module
810
from astrbot.core import LogBroker
911
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
1012
from astrbot.core.db.sqlite import SQLiteDatabase
@@ -117,6 +119,55 @@ async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecyc
117119
return {"Authorization": f"Bearer {token}"}
118120

119121

122+
@pytest.mark.asyncio
123+
async def test_remove_deleted_kb_from_session_configs(monkeypatch: pytest.MonkeyPatch):
124+
updates = []
125+
126+
class FakeSharedPreferences:
127+
async def session_get(self, umo, key):
128+
assert umo is None
129+
assert key == "kb_config"
130+
return [
131+
SimpleNamespace(
132+
scope_id="platform:GroupMessage:group!alice",
133+
value={"val": {"kb_ids": ["kb-old", "kb-keep"], "top_k": 3}},
134+
),
135+
SimpleNamespace(
136+
scope_id="platform:GroupMessage:group!bob",
137+
value={"val": {"kb_ids": ["kb-old"]}},
138+
),
139+
SimpleNamespace(
140+
scope_id="platform:FriendMessage:charlie",
141+
value={"val": {"kb_ids": ["kb-keep"]}},
142+
),
143+
SimpleNamespace(
144+
scope_id="platform:FriendMessage:broken",
145+
value={"val": {"kb_ids": "kb-old"}},
146+
),
147+
]
148+
149+
async def session_put(self, umo, key, value):
150+
updates.append((umo, key, value))
151+
152+
monkeypatch.setattr(knowledge_base_route_module, "sp", FakeSharedPreferences())
153+
154+
updated = await KnowledgeBaseRoute._remove_kb_from_session_configs("kb-old")
155+
156+
assert updated == 2
157+
assert updates == [
158+
(
159+
"platform:GroupMessage:group!alice",
160+
"kb_config",
161+
{"kb_ids": ["kb-keep"], "top_k": 3},
162+
),
163+
(
164+
"platform:GroupMessage:group!bob",
165+
"kb_config",
166+
{"kb_ids": []},
167+
),
168+
]
169+
170+
120171
@pytest.mark.asyncio
121172
async def test_import_documents(
122173
app: Quart, authenticated_header: dict, core_lifecycle_td: AstrBotCoreLifecycle

0 commit comments

Comments
 (0)