Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion astrbot/dashboard/routes/knowledge_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import aiofiles
from quart import request

from astrbot.core import logger
from astrbot.core import logger, sp
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
Expand Down Expand Up @@ -69,6 +69,36 @@ def __init__(
def _get_kb_manager(self):
return self.core_lifecycle.kb_manager

@staticmethod
async def _remove_kb_from_session_configs(kb_id: str) -> int:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic for cleaning up session configurations when a knowledge base is deleted is business logic that should ideally reside in the KnowledgeBaseManager (in astrbot/core/knowledge_base/kb_mgr.py) rather than the API route. Moving it to the core manager ensures that the cleanup happens regardless of how the deletion is triggered (e.g., via a CLI, a scheduled task, or another internal component), maintaining data consistency across the system.

prefs = await sp.session_get(None, "kb_config")
if not isinstance(prefs, list):
return 0

updated = 0

for pref in prefs:
scope_id = getattr(pref, "scope_id", None)
if not isinstance(scope_id, str):
continue

value = await sp.session_get(scope_id, "kb_config")
if not isinstance(value, dict):
continue

kb_ids = value.get("kb_ids")
if not isinstance(kb_ids, list) or kb_id not in kb_ids:
continue

new_value = {
**value,
"kb_ids": [item for item in kb_ids if item != kb_id],
}
await sp.session_put(scope_id, "kb_config", new_value)
updated += 1

return updated

def _init_task(self, task_id: str, status: str = "pending") -> None:
self.upload_tasks[task_id] = {
"status": status,
Expand Down Expand Up @@ -569,6 +599,12 @@ async def delete_kb(self):
if not success:
return Response().error("知识库不存在").__dict__

updated_sessions = await self._remove_kb_from_session_configs(kb_id)
if updated_sessions:
logger.info(
f"已从 {updated_sessions} 个会话配置中移除已删除知识库 {kb_id}",
)

return Response().ok(message="删除知识库成功").__dict__

except ValueError as e:
Expand Down
82 changes: 82 additions & 0 deletions tests/test_kb_import.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock

import pytest
import pytest_asyncio
from quart import Quart

import astrbot.dashboard.routes.knowledge_base as knowledge_base_route_module
from astrbot.core import LogBroker
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db.sqlite import SQLiteDatabase
Expand Down Expand Up @@ -117,6 +119,86 @@ async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecyc
return {"Authorization": f"Bearer {token}"}


@pytest.mark.asyncio
async def test_remove_deleted_kb_from_session_configs(monkeypatch: pytest.MonkeyPatch):
updates = []
configs = {
"platform:GroupMessage:group!alice": {
"kb_ids": ["kb-old", "kb-keep"],
"top_k": 3,
},
"platform:GroupMessage:group!bob": {"kb_ids": ["kb-old"]},
"platform:FriendMessage:charlie": {"kb_ids": ["kb-keep"]},
"platform:FriendMessage:broken": {"kb_ids": "kb-old"},
}

class FakeSharedPreferences:
async def session_get(self, umo, key):
assert key == "kb_config"
if umo is not None:
return configs.get(umo)

return [
SimpleNamespace(
scope_id="platform:GroupMessage:group!alice",
value={"val": configs["platform:GroupMessage:group!alice"]},
),
SimpleNamespace(
scope_id="platform:GroupMessage:group!bob",
value={"val": configs["platform:GroupMessage:group!bob"]},
),
SimpleNamespace(
scope_id="platform:FriendMessage:charlie",
value={"val": configs["platform:FriendMessage:charlie"]},
),
SimpleNamespace(
scope_id="platform:FriendMessage:broken",
value={"val": configs["platform:FriendMessage:broken"]},
),
]

async def session_put(self, umo, key, value):
updates.append((umo, key, value))

monkeypatch.setattr(knowledge_base_route_module, "sp", FakeSharedPreferences())

updated = await KnowledgeBaseRoute._remove_kb_from_session_configs("kb-old")

assert updated == 2
assert updates == [
(
"platform:GroupMessage:group!alice",
"kb_config",
{"kb_ids": ["kb-keep"], "top_k": 3},
),
(
"platform:GroupMessage:group!bob",
"kb_config",
{"kb_ids": []},
),
]


@pytest.mark.asyncio
async def test_remove_deleted_kb_ignores_missing_session_list(
monkeypatch: pytest.MonkeyPatch,
):
class FakeSharedPreferences:
async def session_get(self, umo, key):
assert umo is None
assert key == "kb_config"
return None

async def session_put(self, umo, key, value):
raise AssertionError("session_put should not be called")

monkeypatch.setattr(knowledge_base_route_module, "sp", FakeSharedPreferences())

updated = await KnowledgeBaseRoute._remove_kb_from_session_configs("kb-old")

assert updated == 0


@pytest.mark.asyncio
async def test_import_documents(
app: Quart, authenticated_header: dict, core_lifecycle_td: AstrBotCoreLifecycle
Expand Down
Loading