diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py
index 846b1a350e..003359064f 100644
--- a/astrbot/core/config/default.py
+++ b/astrbot/core/config/default.py
@@ -1789,6 +1789,19 @@ class ChatProviderTemplate(TypedDict):
         "return_documents": False,
         "instruct": "",
     },
+    "NVIDIA Rerank": {
+        "id": "nvidia_rerank",
+        "type": "nvidia_rerank",
+        "provider": "nvidia",
+        "provider_type": "rerank",
+        "enable": True,
+        "nvidia_rerank_api_key": "",
+        "nvidia_rerank_api_base": "https://ai.api.nvidia.com/v1/retrieval",
+        "nvidia_rerank_model": "nv-rerank-qa-mistral-4b:1",
+        "nvidia_rerank_model_endpoint": "/reranking",
+        "timeout": 20,
+        "nvidia_rerank_truncate": "",
+    },
     "Xinference STT": {
         "id": "xinference_stt",
         "type": "xinference_stt",
@@ -1852,6 +1865,34 @@ class ChatProviderTemplate(TypedDict):
         "type": "bool",
         "hint": "如果模型当前未在 Xinference 服务中运行,是否尝试自动启动它。在生产环境中建议关闭。",
     },
+    "nvidia_rerank_api_base": {
+        "description": "API Base URL",
+        "type": "string",
+    },
+    "nvidia_rerank_api_key": {
+        "description": "API Key",
+        "type": "string",
+    },
+    "nvidia_rerank_model": {
+        "description": "重排序模型名称",
+        "type": "string",
+        "hint": "请参照NVIDIA Docs中模型名称填写。",
+    },
+    "nvidia_rerank_model_endpoint": {
+        "description": "自定义模型端点",
+        "type": "string",
+        "hint": "自定义URL末尾端点,默认为 /reranking",
+    },
+    "nvidia_rerank_truncate": {
+        "description": "文本截断策略",
+        "type": "string",
+        "hint": "当输入文本过长时,是否截断输入以适应模型的最大上下文长度。",
+        "options": [
+            "",
+            "NONE",
+            "END",
+        ],
+    },
     "modalities": {
         "description": "模型能力",
         "type": "list",
diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py
index 7a3e1543a7..dfed10261d 100644
--- a/astrbot/core/provider/manager.py
+++ b/astrbot/core/provider/manager.py
@@ -471,6 +471,10 @@ def dynamic_import_provider(self, type: str) -> None:
                 from .sources.bailian_rerank_source import (
                     BailianRerankProvider as BailianRerankProvider,
                 )
+            case "nvidia_rerank":
+                from .sources.nvidia_rerank_source import (
+                    NvidiaRerankProvider as NvidiaRerankProvider,
+                )
 
     def get_merged_provider_config(self, provider_config: dict) -> dict:
         """获取 provider 配置和 provider_source 配置合并后的结果
diff --git a/astrbot/core/provider/sources/nvidia_rerank_source.py b/astrbot/core/provider/sources/nvidia_rerank_source.py
new file mode 100644
index 0000000000..c168da4a6e
--- /dev/null
+++ b/astrbot/core/provider/sources/nvidia_rerank_source.py
@@ -0,0 +1,164 @@
+import aiohttp
+
+from astrbot import logger
+
+from ..entities import ProviderType, RerankResult
+from ..provider import RerankProvider
+from ..register import register_provider_adapter
+
+
+@register_provider_adapter(
+    "nvidia_rerank", "NVIDIA Rerank 适配器", provider_type=ProviderType.RERANK
+)
+class NvidiaRerankProvider(RerankProvider):
+    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
+        super().__init__(provider_config, provider_settings)
+        self.api_key = provider_config.get("nvidia_rerank_api_key", "")
+        self.base_url = provider_config.get(
+            "nvidia_rerank_api_base", "https://ai.api.nvidia.com/v1/retrieval"
+        ).rstrip("/")
+        self.timeout = provider_config.get("timeout", 20)
+        self.model = provider_config.get(
+            "nvidia_rerank_model", "nv-rerank-qa-mistral-4b:1"
+        )
+        self.model_endpoint = provider_config.get(
+            "nvidia_rerank_model_endpoint", "/reranking"
+        )
+        self.truncate = provider_config.get("nvidia_rerank_truncate", "")
+
+        self.client = None
+        self.set_model(self.model)
+
+    async def _get_client(self):
+        if self.client is None or self.client.closed:
+            headers = {
+                "Authorization": f"Bearer {self.api_key}",
+                "Content-Type": "application/json",
+                "Accept": "application/json",
+            }
+            self.client = aiohttp.ClientSession(
+                headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
+            )
+        return self.client
+
+    def _get_endpoint(self) -> str:
+        """
+        构建完整API URL。
+
+        根据 Nvidia Rerank API 文档来看,当前URL存在不同模型格式不一致的问题。
+        这里针对模型名做一个基础判断用以适配,后续要等Nvidia统一API格式后再做调整。
+
+        例:
+        模型: nv-rerank-qa-mistral-4b:1
+        URL: .../v1/retrieval/nvidia/reranking
+
+        模型: nvidia/llama-nemotron-rerank-1b-v2
+        URL: .../v1/retrieval/nvidia/llama-nemotron-rerank-1b-v2/reranking
+        """
+
+        model_path = "nvidia"
+        logger.debug(f"[NVIDIA Rerank] Building endpoint for model: {self.model}")
+        if "/" in self.model:
+            """遵循NVIDIA API的URL规则,替换模型名中特殊字符"""
+            model_path = self.model.strip("/").replace(".", "_")
+        endpoint = self.model_endpoint.lstrip("/")
+        return f"{self.base_url}/{model_path}/{endpoint}"
+
+    def _build_payload(self, query: str, documents: list[str]) -> dict:
+        """构建请求载荷"""
+        payload = {
+            "model": self.model,
+            "query": {"text": query},
+            "passages": [{"text": doc} for doc in documents],
+        }
+        if self.truncate:
+            payload["truncate"] = self.truncate
+        return payload
+
+    def _parse_results(
+        self, response_data: dict, top_n: int | None
+    ) -> list[RerankResult]:
+        """解析响应数据"""
+        results = response_data.get("rankings", [])
+        if not results:
+            logger.warning(f"[NVIDIA Rerank] Empty response: {response_data}")
+            return []
+
+        rerank_results = []
+        for idx, item in enumerate(results):
+            try:
+                index = item.get("index", idx)
+                score = item.get("relevance_score", item.get("logit", 0.0))
+                rerank_results.append(
+                    RerankResult(index=index, relevance_score=float(score))
+                )
+            except Exception as e:
+                logger.warning(
+                    f"[NVIDIA Rerank] Result parsing error: {e}, Data={item}"
+                )
+
+        rerank_results.sort(key=lambda x: x.relevance_score, reverse=True)
+
+        if top_n is not None and top_n > 0:
+            return rerank_results[:top_n]
+        return rerank_results
+
+    def _log_usage(self, data: dict) -> None:
+        usage = data.get("usage", {})
+        total_tokens = usage.get("total_tokens", 0)
+        if total_tokens > 0:
+            logger.debug(f"[NVIDIA Rerank] Token Usage: {total_tokens}")
+
+    async def rerank(
+        self,
+        query: str,
+        documents: list[str],
+        top_n: int | None = None,
+    ) -> list[RerankResult]:
+        client = await self._get_client()
+        if not client or client.closed:
+            logger.error("[NVIDIA Rerank] Client session not initialized or closed")
+            return []
+
+        if not documents or not query.strip():
+            logger.warning(
+                "[NVIDIA Rerank] Input data is invalid, query or documents are empty"
+            )
+            return []
+
+        try:
+            payload = self._build_payload(query, documents)
+            request_url = self._get_endpoint()
+
+            async with client.post(request_url, json=payload) as response:
+                if response.status != 200:
+                    try:
+                        response_data = await response.json()
+                        error_detail = response_data.get(
+                            "detail", response_data.get("message", "Unknown Error")
+                        )
+
+                    except Exception:
+                        error_detail = await response.text()
+                        response_data = {"message": error_detail}
+
+                    logger.error(f"[NVIDIA Rerank] API Error Response: {response_data}")
+                    raise Exception(f"HTTP {response.status} - {error_detail}")
+
+                response_data = await response.json()
+                logger.debug(f"[NVIDIA Rerank] API Response: {response_data}")
+                results = self._parse_results(response_data, top_n)
+                self._log_usage(response_data)
+                return results
+
+        except aiohttp.ClientError as e:
+            logger.error(f"[NVIDIA Rerank] Network error: {e}")
+            raise Exception(f"Network error: {e}") from e
+        except Exception as e:
+            logger.error(f"[NVIDIA Rerank] Error: {e}")
+            raise Exception(f"Rerank error: {e}") from e
+
+    async def terminate(self) -> None:
+        if self.client and not self.client.closed:
+            await self.client.close()
+            self.client = None
diff --git a/dashboard/src/components/provider/AddNewProvider.vue b/dashboard/src/components/provider/AddNewProvider.vue
index dfef836ab3..ac9f5e4412 100644
--- a/dashboard/src/components/provider/AddNewProvider.vue
+++ b/dashboard/src/components/provider/AddNewProvider.vue
@@ -1,5 +1,5 @@