Skip to content
64 changes: 58 additions & 6 deletions astrbot/core/knowledge_base/kb_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,41 @@ async def initialize(self) -> None:

async def get_ep(self) -> EmbeddingProvider:
    """Return the embedding provider for this knowledge base.

    The provider is looked up through ``self.prov_mgr`` using
    ``self.kb.embedding_provider_id``. This method does not raise when the
    provider is missing; it degrades as follows:

    * If the knowledge base has no embedding provider ID, an error is
      logged and the ID is set to the placeholder ``"Embedding_Provider"``.
    * If the provider manager returns nothing for the ID, an error is
      logged and a placeholder provider is returned instead.

    Returns:
        The provider returned by the provider manager, or a placeholder
        whose ``get_embedding`` / ``get_embeddings`` raise ``ValueError``
        when called.
    """
    if not self.kb.embedding_provider_id:
        # NOTE(review): this overwrites the ID on the in-memory KB record
        # with a placeholder string — confirm it is never persisted and
        # cannot collide with a real provider ID.
        self.kb.embedding_provider_id = "Embedding_Provider"
        logger.error(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider")

    ep: EmbeddingProvider = await self.prov_mgr.get_provider_by_id(
        self.kb.embedding_provider_id,
    )  # type: ignore
    if not ep:
        logger.error(
            f"无法找到 ID 为 {self.kb.embedding_provider_id} 的 Embedding Provider,使用占位Embedding Provider"
        )

        class TempEmbeddingProvider(EmbeddingProvider):
            """Placeholder provider that fails only when an embedding is requested."""

            def __init__(
                self,
                provider_config: dict,
                provider_settings: dict,
                embedding_provider_id: str,
            ) -> None:
                super().__init__(provider_config, provider_settings)
                # Kept so the deferred error can name the missing provider.
                self.embedding_provider_id = embedding_provider_id

            async def get_embedding(self, text: str) -> list[float]:
                raise ValueError(
                    f"无法找到 ID 为 {self.embedding_provider_id} 的 Embedding Provider"
                )

            async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
                raise ValueError(
                    f"无法找到 ID 为 {self.embedding_provider_id} 的 Embedding Provider"
                )

            def get_dim(self) -> int:
                # NOTE(review): hard-coded placeholder dimension — confirm no
                # caller sizes a vector index from this value.
                return 512

        ep = TempEmbeddingProvider({}, {}, self.kb.embedding_provider_id)
    return ep

Expand All @@ -152,14 +180,38 @@ async def get_rp(self) -> RerankProvider | None:
self.kb.rerank_provider_id,
) # type: ignore
if not rp:
raise ValueError(
f"无法找到 ID 为 {self.kb.rerank_provider_id} 的 Rerank Provider",
logger.error(
f"无法找到 ID 为 {self.kb.rerank_provider_id} 的 Rerank Provider,使用占位Rerank Provider"
)

class TempRerankProvider(RerankProvider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
rerank_provider_id: str,
) -> None:
super().__init__(provider_config, provider_settings)
self.rerank_provider_id = rerank_provider_id
Comment thread
sourcery-ai[bot] marked this conversation as resolved.

async def rerank(
self,
query: str,
documents: list[str],
top_n: int | None = None,
):
raise ValueError(
f"无法找到 ID 为 {self.rerank_provider_id} 的 Rerank Provider"
)

rp: RerankProvider = TempRerankProvider({}, {}, self.kb.rerank_provider_id)
Comment thread
Li-shi-ling marked this conversation as resolved.
return rp

async def _ensure_vec_db(self) -> FaissVecDB:
if not self.kb.embedding_provider_id:
raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider")
self.kb.embedding_provider_id = "Embedding_Provider"
Comment thread
Li-shi-ling marked this conversation as resolved.
logger.error(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider")
# raise ValueError(f"知识库 {self.kb.kb_name} 未配置 Embedding Provider")

ep = await self.get_ep()
rp = await self.get_rp()
Expand Down
Loading