diff --git a/astrbot/core/knowledge_base/kb_helper.py b/astrbot/core/knowledge_base/kb_helper.py index 1e9127d72a..93e0cf9bed 100644 --- a/astrbot/core/knowledge_base/kb_helper.py +++ b/astrbot/core/knowledge_base/kb_helper.py @@ -135,13 +135,41 @@ async def initialize(self) -> None: async def get_ep(self) -> EmbeddingProvider: if not self.kb.embedding_provider_id: - raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider") + self.kb.embedding_provider_id = "Embedding_Provider" + logger.error(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider") ep: EmbeddingProvider = await self.prov_mgr.get_provider_by_id( self.kb.embedding_provider_id, ) # type: ignore if not ep: - raise ValueError( - f"无法找到 ID 为 {self.kb.embedding_provider_id} 的 Embedding Provider", + logger.error( + f"无法找到 ID 为 {self.kb.embedding_provider_id} 的 Embedding Provider,使用占位Embedding Provider" + ) + + class TempEmbeddingProvider(EmbeddingProvider): + def __init__( + self, + provider_config: dict, + provider_settings: dict, + embedding_provider_id: str, + ) -> None: + super().__init__(provider_config, provider_settings) + self.embedding_provider_id = embedding_provider_id + + async def get_embedding(self, text: str) -> list[float]: + raise ValueError( + f"无法找到 ID 为 {self.embedding_provider_id} 的 Embedding Provider" + ) + + async def get_embeddings(self, texts: list[str]) -> list[list[float]]: + raise ValueError( + f"无法找到 ID 为 {self.embedding_provider_id} 的 Embedding Provider" + ) + + def get_dim(self) -> int: + return 512 + + ep: EmbeddingProvider = TempEmbeddingProvider( + {}, {}, self.kb.embedding_provider_id ) return ep @@ -152,14 +180,38 @@ async def get_rp(self) -> RerankProvider | None: self.kb.rerank_provider_id, ) # type: ignore if not rp: - raise ValueError( - f"无法找到 ID 为 {self.kb.rerank_provider_id} 的 Rerank Provider", + logger.error( + f"无法找到 ID 为 {self.kb.rerank_provider_id} 的 Rerank Provider,使用占位Rerank Provider" ) + + class TempRerankProvider(RerankProvider): + def __init__( + self, + provider_config: dict, + provider_settings: dict, + rerank_provider_id: str, + ) -> None: + super().__init__(provider_config, provider_settings) + self.rerank_provider_id = rerank_provider_id + + async def rerank( + self, + query: str, + documents: list[str], + top_n: int | None = None, + ): + raise ValueError( + f"无法找到 ID 为 {self.rerank_provider_id} 的 Rerank Provider" + ) + + rp: RerankProvider = TempRerankProvider({}, {}, self.kb.rerank_provider_id) return rp async def _ensure_vec_db(self) -> FaissVecDB: if not self.kb.embedding_provider_id: - raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider") + self.kb.embedding_provider_id = "Embedding_Provider" + logger.error(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider") + # raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider") ep = await self.get_ep() rp = await self.get_rp()