Skip to content

Commit 7fbc192

Browse files
committed
fix: defer faiss imports during startup
1 parent 8c6c00a commit 7fbc192

12 files changed

Lines changed: 145 additions & 17 deletions

File tree

astrbot/core/config/default.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1556,6 +1556,7 @@ class ChatProviderTemplate(TypedDict):
15561556
"enable": False,
15571557
"id": "whisper_selfhost",
15581558
"model": "tiny",
1559+
"whisper_device": "cpu",
15591560
},
15601561
"SenseVoice(Local)": {
15611562
"type": "sensevoice_stt_selfhost",
@@ -2555,6 +2556,12 @@ class ChatProviderTemplate(TypedDict):
25552556
"type": "string",
25562557
"hint": "启用前请 pip 安装 openai-whisper 库(N卡用户大约下载 2GB,主要是 torch 和 cuda,CPU 用户大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
25572558
},
2559+
"whisper_device": {
2560+
"description": "推理设备",
2561+
"type": "string",
2562+
"hint": "Whisper 推理设备。Apple Silicon 可选 mps;其他环境建议使用 cpu。若指定 mps 但当前环境不可用,将自动回退到 cpu。",
2563+
"options": ["cpu", "mps"],
2564+
},
25582565
"id": {
25592566
"description": "ID",
25602567
"type": "string",

astrbot/core/knowledge_base/kb_db_sqlite.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from contextlib import asynccontextmanager
22
from pathlib import Path
3+
from typing import TYPE_CHECKING
34

45
from sqlalchemy import delete, func, select, text, update
56
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
67
from sqlmodel import col, desc
78

89
from astrbot.core import logger
9-
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
1010
from astrbot.core.knowledge_base.models import (
1111
BaseKBModel,
1212
KBDocument,
@@ -15,6 +15,9 @@
1515
)
1616
from astrbot.core.utils.astrbot_path import get_astrbot_knowledge_base_path
1717

18+
if TYPE_CHECKING:
19+
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
20+
1821

1922
class KBSQLiteDatabase:
2023
def __init__(self, db_path: str | None = None) -> None:
@@ -296,7 +299,7 @@ async def get_documents_with_metadata_batch(
296299

297300
return metadata_map
298301

299-
async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB) -> None:
302+
async def delete_document_by_id(self, doc_id: str, vec_db: "FaissVecDB") -> None:
300303
"""删除单个文档及其相关数据"""
301304
# 在知识库表中删除
302305
async with self.get_db() as session, session.begin():
@@ -324,7 +327,7 @@ async def get_media_by_id(self, media_id: str) -> KBMedia | None:
324327
result = await session.execute(stmt)
325328
return result.scalar_one_or_none()
326329

327-
async def update_kb_stats(self, kb_id: str, vec_db: FaissVecDB) -> None:
330+
async def update_kb_stats(self, kb_id: str, vec_db: "FaissVecDB") -> None:
328331
"""更新知识库统计信息"""
329332
chunk_cnt = await vec_db.count_documents()
330333

astrbot/core/knowledge_base/kb_helper.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
import time
55
import uuid
66
from pathlib import Path
7+
from typing import TYPE_CHECKING
78

89
import aiofiles
910

1011
from astrbot.core import logger
1112
from astrbot.core.db.vec_db.base import BaseVecDB
12-
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
1313
from astrbot.core.provider.manager import ProviderManager
1414
from astrbot.core.provider.provider import (
1515
EmbeddingProvider,
@@ -27,6 +27,9 @@
2727
from .parsers.util import select_parser
2828
from .prompts import TEXT_REPAIR_SYSTEM_PROMPT
2929

30+
if TYPE_CHECKING:
31+
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
32+
3033

3134
class RateLimiter:
3235
"""一个简单的速率限制器"""
@@ -160,7 +163,7 @@ async def get_rp(self) -> RerankProvider | None:
160163
return None
161164
return rp
162165

163-
async def _ensure_vec_db(self) -> FaissVecDB:
166+
async def _ensure_vec_db(self) -> "FaissVecDB":
164167
if not self.kb.embedding_provider_id:
165168
raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider")
166169

@@ -173,6 +176,8 @@ async def _ensure_vec_db(self) -> FaissVecDB:
173176
f"知识库 {self.kb.kb_name}({self.kb.kb_id}) 初始化重排序能力失败,将跳过重排序: {e}",
174177
)
175178

179+
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
180+
176181
vec_db = FaissVecDB(
177182
doc_store_path=str(self.kb_dir / "doc.db"),
178183
index_store_path=str(self.kb_dir / "index.faiss"),

astrbot/core/knowledge_base/retrieval/manager.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,20 @@
55

66
import time
77
from dataclasses import dataclass
8+
from typing import TYPE_CHECKING
89

910
from astrbot import logger
1011
from astrbot.core.db.vec_db.base import Result
11-
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
1212
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
1313
from astrbot.core.knowledge_base.retrieval.rank_fusion import RankFusion
1414
from astrbot.core.knowledge_base.retrieval.sparse_retriever import SparseRetriever
1515
from astrbot.core.provider.provider import RerankProvider
1616

1717
from ..kb_helper import KBHelper
1818

19+
if TYPE_CHECKING:
20+
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
21+
1922

2023
@dataclass
2124
class RetrievalResult:
@@ -170,18 +173,18 @@ async def retrieve(
170173
first_rerank = None
171174
for kb_id in kb_ids:
172175
vec_db = kb_options[kb_id]["vec_db"]
173-
if not isinstance(vec_db, FaissVecDB):
174-
logger.warning(f"vec_db for kb_id {kb_id} is not FaissVecDB")
176+
rerank_provider = getattr(vec_db, "rerank_provider", None)
177+
if rerank_provider is None:
175178
continue
176179

177180
rerank_pi = kb_options[kb_id]["rerank_provider_id"]
178181
if (
179182
vec_db
180-
and vec_db.rerank_provider
183+
and rerank_provider
181184
and rerank_pi
182-
and rerank_pi == vec_db.rerank_provider.meta().id
185+
and rerank_pi == rerank_provider.meta().id
183186
):
184-
first_rerank = vec_db.rerank_provider
187+
first_rerank = rerank_provider
185188
break
186189
if first_rerank and retrieval_results:
187190
try:

astrbot/core/knowledge_base/retrieval/sparse_retriever.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,16 @@
66
import json
77
import os
88
from dataclasses import dataclass
9+
from typing import TYPE_CHECKING
910

1011
import jieba
1112
from rank_bm25 import BM25Okapi
1213

13-
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
1414
from astrbot.core.knowledge_base.kb_db_sqlite import KBSQLiteDatabase
1515

16+
if TYPE_CHECKING:
17+
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
18+
1619

1720
@dataclass
1821
class SparseResult:
@@ -73,7 +76,7 @@ async def retrieve(
7376
top_k_sparse = 0
7477
chunks = []
7578
for kb_id in kb_ids:
76-
vec_db: FaissVecDB = kb_options.get(kb_id, {}).get("vec_db")
79+
vec_db: FaissVecDB | None = kb_options.get(kb_id, {}).get("vec_db")
7780
if not vec_db:
7881
continue
7982
result = await vec_db.document_storage.get_documents(

astrbot/core/provider/sources/whisper_selfhosted_source.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import asyncio
22
import os
33
import uuid
4+
from functools import partial
45
from typing import cast
56

7+
import torch
68
import whisper
79

810
from astrbot.core import logger
@@ -28,17 +30,32 @@ def __init__(
2830
) -> None:
2931
super().__init__(provider_config, provider_settings)
3032
self.set_model(provider_config["model"])
33+
self.device = str(provider_config.get("whisper_device", "cpu")).strip().lower()
3134
self.model = None
3235

36+
def _resolve_device(self) -> str:
37+
if self.device == "mps":
38+
mps_backend = getattr(torch.backends, "mps", None)
39+
if mps_backend and mps_backend.is_available():
40+
return "mps"
41+
logger.warning("Whisper 已配置为使用 MPS,但当前环境不可用,将回退到 CPU。")
42+
return "cpu"
43+
if self.device != "cpu":
44+
logger.warning(
45+
"Whisper 配置了未知 device=%s,将回退到 CPU。",
46+
self.device,
47+
)
48+
return "cpu"
49+
3350
async def initialize(self) -> None:
3451
loop = asyncio.get_running_loop()
52+
device = self._resolve_device()
3553
logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...")
3654
self.model = await loop.run_in_executor(
3755
None,
38-
whisper.load_model,
39-
self.model_name,
56+
partial(whisper.load_model, self.model_name, device=device),
4057
)
41-
logger.info("Whisper 模型加载完成。")
58+
logger.info("Whisper 模型加载完成。device=%s", device)
4259

4360
async def _is_silk_file(self, file_path) -> bool:
4461
silk_header = b"SILK"

astrbot/dashboard/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import base64
22
import traceback
33
from io import BytesIO
4+
from typing import TYPE_CHECKING
45

56
from astrbot.api import logger
6-
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
77
from astrbot.core.knowledge_base.kb_helper import KBHelper
88
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
99

10+
if TYPE_CHECKING:
11+
from astrbot.core.db.vec_db.faiss_impl import FaissVecDB
12+
1013

1114
async def generate_tsne_visualization(
1215
query: str,

dashboard/src/i18n/locales/en-US/features/config-metadata.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1514,6 +1514,10 @@
15141514
"description": "Notes for local Whisper deployment",
15151515
"hint": "Before enabling, install the openai-whisper library (NVIDIA users download ~2GB mainly for torch and cuda; CPU users download ~1GB), and install ffmpeg. Otherwise STT will not work."
15161516
},
1517+
"whisper_device": {
1518+
"description": "Inference device",
1519+
"hint": "Whisper inference device. Apple Silicon can use mps; other environments should use cpu. If mps is selected but unavailable, AstrBot will fall back to cpu."
1520+
},
15171521
"id": {
15181522
"description": "ID"
15191523
},

dashboard/src/i18n/locales/ru-RU/features/config-metadata.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1511,6 +1511,10 @@
15111511
"description": "Заметки по локальному развертыванию Whisper",
15121512
"hint": "Перед включением установите openai-whisper и ffmpeg."
15131513
},
1514+
"whisper_device": {
1515+
"description": "Устройство инференса",
1516+
"hint": "Устройство для инференса Whisper. На Apple Silicon можно выбрать mps; в остальных средах рекомендуется cpu. Если выбран mps, но он недоступен, AstrBot автоматически переключится на cpu."
1517+
},
15141518
"id": {
15151519
"description": "ID провайдера"
15161520
},

dashboard/src/i18n/locales/zh-CN/features/config-metadata.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1516,6 +1516,10 @@
15161516
"description": "本地部署 Whisper 模型须知",
15171517
"hint": "启用前请 pip 安装 openai-whisper 库(N卡用户大约下载 2GB,主要是 torch 和 cuda,CPU 用户大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。"
15181518
},
1519+
"whisper_device": {
1520+
"description": "推理设备",
1521+
"hint": "Whisper 推理设备。Apple Silicon 可选 mps;其他环境建议使用 cpu。若指定 mps 但当前环境不可用,将自动回退到 cpu。"
1522+
},
15191523
"id": {
15201524
"description": "ID"
15211525
},

0 commit comments

Comments
 (0)