From 6799c21dd36c5d85a076c6ce18849ae6fca91cc2 Mon Sep 17 00:00:00 2001 From: Chang Li <2747297361@qq.com> Date: Tue, 31 Mar 2026 12:20:05 +0800 Subject: [PATCH 1/8] feat: add Rerank API support for NVIDIA NIM - Add Rerank API support for NVIDIA NIM - Add related i18n support in en-US zh-CN --- astrbot/core/config/default.py | 41 ++++++ astrbot/core/provider/manager.py | 4 + .../provider/sources/nvidia_rerank_source.py | 139 ++++++++++++++++++ .../en-US/features/config-metadata.json | 18 +++ .../zh-CN/features/config-metadata.json | 18 +++ 5 files changed, 220 insertions(+) create mode 100644 astrbot/core/provider/sources/nvidia_rerank_source.py diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 846b1a350e..930dc78da5 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -1789,6 +1789,19 @@ class ChatProviderTemplate(TypedDict): "return_documents": False, "instruct": "", }, + "Nvidia Rerank": { + "id": "nvidia_rerank", + "type": "nvidia_rerank", + "provider": "nvidia", + "provider_type": "rerank", + "enable": True, + "nvidia_rerank_api_key": "", + "nvidia_rerank_api_base": "https://ai.api.nvidia.com/v1/retrieval", + "nvidia_rerank_model": "nv-rerank-qa-mistral-4b:1", + "nvidia_rerank_model_endpoint": "/reranking", + "timeout": 20, + "nvidia_rerank_truncate": "", + }, "Xinference STT": { "id": "xinference_stt", "type": "xinference_stt", @@ -1852,6 +1865,34 @@ class ChatProviderTemplate(TypedDict): "type": "bool", "hint": "如果模型当前未在 Xinference 服务中运行,是否尝试自动启动它。在生产环境中建议关闭。", }, + "nvidia_rerank_api_base": { + "description": "API Base URL", + "type": "string", + }, + "nvidia_rerank_api_key": { + "description": "API Key", + "type": "string", + }, + "nvidia_rerank_model": { + "description": "重排序模型名称", + "type": "string", + "hint": "请参照NVIDIA Docs中模型名称填写。", + }, + "nvidia_rerank_model_endpoint": { + "description": "自定义模型端点", + "type": "string", + "hint": "自定义URL末尾端点,默认为 /reranking", + }, + "nvidia_rerank_truncate": { + "description": "文本截断策略", + "type": "string", + "hint": "当输入文本过长时,是否截断输入以适应模型的最大上下文长度。", + "options": [ + "", + "NONE", + "END", + ], + }, "modalities": { "description": "模型能力", "type": "list", diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 7a3e1543a7..dfed10261d 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -471,6 +471,10 @@ def dynamic_import_provider(self, type: str) -> None: from .sources.bailian_rerank_source import ( BailianRerankProvider as BailianRerankProvider, ) + case "nvidia_rerank": + from .sources.nvidia_rerank_source import ( + NvidiaRerankProvider as NvidiaRerankProvider, + ) def get_merged_provider_config(self, provider_config: dict) -> dict: """获取 provider 配置和 provider_source 配置合并后的结果 diff --git a/astrbot/core/provider/sources/nvidia_rerank_source.py b/astrbot/core/provider/sources/nvidia_rerank_source.py new file mode 100644 index 0000000000..6103f76646 --- /dev/null +++ b/astrbot/core/provider/sources/nvidia_rerank_source.py @@ -0,0 +1,139 @@ +import aiohttp + +from astrbot import logger + +from ..entities import ProviderType, RerankResult +from ..provider import RerankProvider +from ..register import register_provider_adapter + + +@register_provider_adapter( + "nvidia_rerank", + "NVIDIA Rerank 适配器", + provider_type=ProviderType.RERANK +) +class NvidiaRerankProvider(RerankProvider): + + def __init__(self, provider_config: dict, provider_settings: dict) -> None: + super().__init__(provider_config, provider_settings) + self.api_key = provider_config.get("nvidia_rerank_api_key", "") + self.base_url = provider_config.get("nvidia_rerank_api_base", "https://ai.api.nvidia.com/v1/retrieval").rstrip("/") + self.timeout = provider_config.get("timeout", 20) + self.model = provider_config.get("nvidia_rerank_model", "nv-rerank-qa-mistral-4b:1") + self.model_endpoint = provider_config.get("nvidia_rerank_model_endpoint", "/reranking") + self.truncate = provider_config.get("nvidia_rerank_truncate", "") + + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + "Accept": "application/json", + } + + self.client = aiohttp.ClientSession( + headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout) + ) + + self.set_model(self.model) + + def _get_endpoint(self) -> str: + """ + 构建完整API URL。 + + 根据 Nvidia Rerank API 文档来看,当前URL存在不同模型格式不一致的问题。 + 这里针对模型名做一个基础判断用以适配,后续要等Nvidia统一API格式后再做调整。 + + 例: + 模型: nv-rerank-qa-mistral-4b:1 + URL: .../v1/retrieval/nvidia/reranking + + 模型: nvidia/llama-nemotron-rerank-1b-v2 + URL: .../v1/retrieval/nvidia/llama-nemotron-rerank-1b-v2/reranking + """ + + model_path = "nvidia" + logger.debug(f"[NVIDIA Rerank] Building endpoint for model: {self.model}") + if "/" in self.model: + model_path = self.model.strip("/") + endpoint = self.model_endpoint.lstrip("/") + return f"{self.base_url}/{model_path}/{endpoint}" + + def _build_payload(self, query: str, documents: list[str]) -> dict: + """构建请求载荷""" + payload = { + "model": self.model, + "query": {"text": query}, + "passages": [{"text": doc} for doc in documents], + } + if self.truncate: + payload["truncate"] = self.truncate + return payload + + def _parse_results(self, response_data: dict, top_n: int | None) -> list[RerankResult]: + """解析响应数据""" + results = response_data.get("rankings", []) + if not results: + logger.warning(f"[NVIDIA Rerank] Empty response: {response_data}") + return [] + + rerank_results = [] + for idx, item in enumerate(results): + try: + index = item.get("index", idx) + score = item.get("relevance_score", item.get("logit", 0.0)) + rerank_results.append(RerankResult(index=index, relevance_score=float(score))) + except Exception as e: + logger.warning(f"[NVIDIA Rerank] Result parsing error: {e}, Data={item}") + + rerank_results.sort(key=lambda x: x.relevance_score, reverse=True) + + if top_n is not None and top_n > 0: + return rerank_results[:top_n] + return rerank_results + + def _log_usage(self, data: dict) -> None: + usage = data.get("usage", {}) + total_tokens = usage.get("total_tokens", 0) + if total_tokens > 0: + logger.debug(f"[NVIDIA Rerank] Token Usage: {total_tokens}") + + async def rerank( + self, + query: str, + documents: list[str], + top_n: int | None = None, + ) -> list[RerankResult]: + if not self.client or self.client.closed: + logger.error("[NVIDIA Rerank] Client session not initialized or closed") + return [] + + if not documents or not query.strip(): + logger.warning("[NVIDIA Rerank] Input data is invalid, query or documents are empty") + return [] + + try: + payload = self._build_payload(query, documents) + request_url = self._get_endpoint() + + async with self.client.post(request_url, json=payload) as response: + response_data = await response.json() + logger.debug(f"[NVIDIA Rerank] API Response: {response_data}") + + if response.status != 200: + error_detail = response_data.get("detail", response_data.get("message", "Unknown Error")) + raise Exception(f"HTTP {response.status} - {error_detail}") + + results = self._parse_results(response_data, top_n) + self._log_usage(response_data) + return results + + except aiohttp.ClientError as e: + logger.error(f"[NVIDIA Rerank] Network error: {e}") + raise Exception(f"Network error: {e}") from e + except Exception as e: + logger.error(f"[NVIDIA Rerank] Error: {e}") + raise Exception(f"Rerank error: {e}") from e + + async def terminate(self) -> None: + if self.client and not self.client.closed: + await self.client.close() + self.client = None diff --git a/dashboard/src/i18n/locales/en-US/features/config-metadata.json b/dashboard/src/i18n/locales/en-US/features/config-metadata.json index 9ae8672826..37a5e37504 100644 --- a/dashboard/src/i18n/locales/en-US/features/config-metadata.json +++ b/dashboard/src/i18n/locales/en-US/features/config-metadata.json @@ -1092,6 +1092,24 @@ "description": "Custom rerank task description", "hint": "Only effective for qwen3-rerank models. Recommended to write in English." }, + "nvidia_rerank_api_base": { + "description": "API Base URL" + }, + "nvidia_rerank_api_key": { + "description": "API Key" + }, + "nvidia_rerank_model": { + "description": "Rerank Model Name", + "hint": "Please refer to the NVIDIA Docs for the model name." + }, + "nvidia_rerank_model_endpoint": { + "description": "Custom Model Endpoint", + "hint": "Custom URL suffix endpoint, defaults to /reranking." + }, + "nvidia_rerank_truncate": { + "description": "Text Truncation Strategy", + "hint": "Whether to truncate the input to fit the model's maximum context length when the input text is too long." + }, "launch_model_if_not_running": { "description": "Auto-start model if not running", "hint": "If the model is not running in Xinference, attempt to start it automatically. Recommended to disable in production." diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json index c04138402e..39eab21d13 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json +++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json @@ -1094,6 +1094,24 @@ "description": "自定义排序任务类型说明", "hint": "仅在使用 qwen3-rerank 模型时生效。建议使用英文撰写。" }, + "nvidia_rerank_api_base": { + "description": "API Base URL" + }, + "nvidia_rerank_api_key": { + "description": "API Key" + }, + "nvidia_rerank_model": { + "description": "重排序模型名称", + "hint": "请参照NVIDIA Docs中模型名称填写。" + }, + "nvidia_rerank_model_endpoint": { + "description": "自定义模型端点", + "hint": "自定义URL末尾端点,默认为 /reranking" + }, + "nvidia_rerank_truncate": { + "description": "文本截断策略", + "hint": "当输入文本过长时,是否截断输入以适应模型的最大上下文长度。" + }, "launch_model_if_not_running": { "description": "模型未运行时自动启动", "hint": "如果模型当前未在 Xinference 服务中运行,是否尝试自动启动它。在生产环境中建议关闭。" From 8b9960d1334de2a53aa9dc2a619d8182dd8c37d2 Mon Sep 17 00:00:00 2001 From: Chang Li <2747297361@qq.com> Date: Tue, 31 Mar 2026 15:27:28 +0800 Subject: [PATCH 2/8] chore: format code --- .../provider/sources/nvidia_rerank_source.py | 37 +++++++++++++------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/astrbot/core/provider/sources/nvidia_rerank_source.py b/astrbot/core/provider/sources/nvidia_rerank_source.py index 6103f76646..b6d348325c 100644 --- a/astrbot/core/provider/sources/nvidia_rerank_source.py +++ b/astrbot/core/provider/sources/nvidia_rerank_source.py @@ -8,19 +8,22 @@ @register_provider_adapter( - "nvidia_rerank", - "NVIDIA Rerank 适配器", - provider_type=ProviderType.RERANK + "nvidia_rerank", "NVIDIA Rerank 适配器", provider_type=ProviderType.RERANK ) class NvidiaRerankProvider(RerankProvider): - def __init__(self, provider_config: dict, provider_settings: dict) -> None: super().__init__(provider_config, provider_settings) self.api_key = provider_config.get("nvidia_rerank_api_key", "") - self.base_url = provider_config.get("nvidia_rerank_api_base", "https://ai.api.nvidia.com/v1/retrieval").rstrip("/") + self.base_url = provider_config.get( + "nvidia_rerank_api_base", "https://ai.api.nvidia.com/v1/retrieval" + ).rstrip("/") self.timeout = provider_config.get("timeout", 20) - self.model = provider_config.get("nvidia_rerank_model", "nv-rerank-qa-mistral-4b:1") - self.model_endpoint = provider_config.get("nvidia_rerank_model_endpoint", "/reranking") + self.model = provider_config.get( + "nvidia_rerank_model", "nv-rerank-qa-mistral-4b:1" + ) + self.model_endpoint = provider_config.get( + "nvidia_rerank_model_endpoint", "/reranking" + ) self.truncate = provider_config.get("nvidia_rerank_truncate", "") headers = { @@ -68,7 +71,9 @@ def _build_payload(self, query: str, documents: list[str]) -> dict: payload["truncate"] = self.truncate return payload - def _parse_results(self, response_data: dict, top_n: int | None) -> list[RerankResult]: + def _parse_results( + self, response_data: dict, top_n: int | None + ) -> list[RerankResult]: """解析响应数据""" results = response_data.get("rankings", []) if not results: @@ -80,9 +85,13 @@ def _parse_results(self, response_data: dict, top_n: int | None) -> list[RerankR try: index = item.get("index", idx) score = item.get("relevance_score", item.get("logit", 0.0)) - rerank_results.append(RerankResult(index=index, relevance_score=float(score))) + rerank_results.append( + RerankResult(index=index, relevance_score=float(score)) + ) except Exception as e: - logger.warning(f"[NVIDIA Rerank] Result parsing error: {e}, Data={item}") + logger.warning( + f"[NVIDIA Rerank] Result parsing error: {e}, Data={item}" + ) rerank_results.sort(key=lambda x: x.relevance_score, reverse=True) @@ -107,7 +116,9 @@ async def rerank( return [] if not documents or not query.strip(): - logger.warning("[NVIDIA Rerank] Input data is invalid, query or documents are empty") + logger.warning( + "[NVIDIA Rerank] Input data is invalid, query or documents are empty" + ) return [] try: @@ -119,7 +130,9 @@ async def rerank( logger.debug(f"[NVIDIA Rerank] API Response: {response_data}") if response.status != 200: - error_detail = response_data.get("detail", response_data.get("message", "Unknown Error")) + error_detail = response_data.get( + "detail", response_data.get("message", "Unknown Error") + ) raise Exception(f"HTTP {response.status} - {error_detail}") results = self._parse_results(response_data, top_n) From e1088afb93d397b510095f4f0a120d8c6137e3f9 Mon Sep 17 00:00:00 2001 From: Chang Li <2747297361@qq.com> Date: Tue, 31 Mar 2026 16:13:24 +0800 Subject: [PATCH 3/8] fix: replace illegal characters Replace illegal characters when building request model path. --- astrbot/core/provider/sources/nvidia_rerank_source.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/astrbot/core/provider/sources/nvidia_rerank_source.py b/astrbot/core/provider/sources/nvidia_rerank_source.py index b6d348325c..aefc7b9d52 100644 --- a/astrbot/core/provider/sources/nvidia_rerank_source.py +++ b/astrbot/core/provider/sources/nvidia_rerank_source.py @@ -56,7 +56,7 @@ def _get_endpoint(self) -> str: model_path = "nvidia" logger.debug(f"[NVIDIA Rerank] Building endpoint for model: {self.model}") if "/" in self.model: - model_path = self.model.strip("/") + model_path = self.model.strip("/").replace(".", "_") endpoint = self.model_endpoint.lstrip("/") return f"{self.base_url}/{model_path}/{endpoint}" From 6a4c26d6a892c583572787b8eeb52d99c29edc23 Mon Sep 17 00:00:00 2001 From: Chang Li <2747297361@qq.com> Date: Tue, 31 Mar 2026 17:26:42 +0800 Subject: [PATCH 4/8] fix: refactor client initialization method --- .../provider/sources/nvidia_rerank_source.py | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/astrbot/core/provider/sources/nvidia_rerank_source.py b/astrbot/core/provider/sources/nvidia_rerank_source.py index aefc7b9d52..d5c538ec70 100644 --- a/astrbot/core/provider/sources/nvidia_rerank_source.py +++ b/astrbot/core/provider/sources/nvidia_rerank_source.py @@ -26,18 +26,21 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: ) self.truncate = provider_config.get("nvidia_rerank_truncate", "") - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - "Accept": "application/json", - } - - self.client = aiohttp.ClientSession( - headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout) - ) - + self.client = None self.set_model(self.model) + async def _get_client(self): + if self.client is None or self.client.closed: + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + "Accept": "application/json", + } + self.client = aiohttp.ClientSession( + headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout) + ) + return self.client + def _get_endpoint(self) -> str: """ 构建完整API URL。 @@ -111,7 +114,8 @@ async def rerank( documents: list[str], top_n: int | None = None, ) -> list[RerankResult]: - if not self.client or self.client.closed: + client = await self._get_client() + if not client or client.closed: logger.error("[NVIDIA Rerank] Client session not initialized or closed") return [] @@ -125,7 +129,7 @@ async def rerank( payload = self._build_payload(query, documents) request_url = self._get_endpoint() - async with self.client.post(request_url, json=payload) as response: + async with client.post(request_url, json=payload) as response: response_data = await response.json() logger.debug(f"[NVIDIA Rerank] API Response: {response_data}") From 5cb7e012d46884059d1b032ad05f723e13f19d26 Mon Sep 17 00:00:00 2001 From: Chang Li <2747297361@qq.com> Date: Tue, 31 Mar 2026 17:29:26 +0800 Subject: [PATCH 5/8] fix: enhance response parsing --- .../provider/sources/nvidia_rerank_source.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/astrbot/core/provider/sources/nvidia_rerank_source.py b/astrbot/core/provider/sources/nvidia_rerank_source.py index d5c538ec70..19b415fabf 100644 --- a/astrbot/core/provider/sources/nvidia_rerank_source.py +++ b/astrbot/core/provider/sources/nvidia_rerank_source.py @@ -130,15 +130,22 @@ async def rerank( request_url = self._get_endpoint() async with client.post(request_url, json=payload) as response: - response_data = await response.json() - logger.debug(f"[NVIDIA Rerank] API Response: {response_data}") - if response.status != 200: - error_detail = response_data.get( - "detail", response_data.get("message", "Unknown Error") - ) + try: + response_data = await response.json() + error_detail = response_data.get( + "detail", response_data.get("message", "Unknown Error") + ) + + except Exception: + error_detail = await response.text() + response_data = {"message": error_detail} + + logger.error(f"[NVIDIA Rerank] API Error Response: {response_data}") raise Exception(f"HTTP {response.status} - {error_detail}") + response_data = await response.json() + logger.debug(f"[NVIDIA Rerank] API Response: {response_data}") results = self._parse_results(response_data, top_n) self._log_usage(response_data) return results From e0501c20cd6117dc1e4af15a567a1aa7a4a5bfa9 Mon Sep 17 00:00:00 2001 From: Chang Li <2747297361@qq.com> Date: Tue, 31 Mar 2026 17:37:58 +0800 Subject: [PATCH 6/8] docs: add comment for model_path process --- astrbot/core/provider/sources/nvidia_rerank_source.py | 1 + 1 file changed, 1 insertion(+) diff --git a/astrbot/core/provider/sources/nvidia_rerank_source.py b/astrbot/core/provider/sources/nvidia_rerank_source.py index 19b415fabf..c168da4a6e 100644 --- a/astrbot/core/provider/sources/nvidia_rerank_source.py +++ b/astrbot/core/provider/sources/nvidia_rerank_source.py @@ -59,6 +59,7 @@ def _get_endpoint(self) -> str: model_path = "nvidia" logger.debug(f"[NVIDIA Rerank] Building endpoint for model: {self.model}") if "/" in self.model: + """遵循NVIDIA API的URL规则,替换模型名中特殊字符""" model_path = self.model.strip("/").replace(".", "_") endpoint = self.model_endpoint.lstrip("/") return f"{self.base_url}/{model_path}/{endpoint}" From 96da8f9ae6abf20df365992fdfc2b1d5d23dc51f Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Mon, 6 Apr 2026 15:07:10 +0800 Subject: [PATCH 7/8] docs: add russia translation --- astrbot/core/config/default.py | 2 +- .../ru-RU/features/config-metadata.json | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 930dc78da5..003359064f 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -1789,7 +1789,7 @@ class ChatProviderTemplate(TypedDict): "return_documents": False, "instruct": "", }, - "Nvidia Rerank": { + "NVIDIA Rerank": { "id": "nvidia_rerank", "type": "nvidia_rerank", "provider": "nvidia", diff --git a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json index 0aa5c791ac..257bc68540 100644 --- a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json +++ b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json @@ -1093,6 +1093,24 @@ "description": "Описание задачи для Rerank", "hint": "Эффективно только для моделей qwen3-rerank. Рекомендуется писать на английском." }, + "nvidia_rerank_api_base": { + "description": "Базовый URL API" + }, + "nvidia_rerank_api_key": { + "description": "API-ключ" + }, + "nvidia_rerank_model": { + "description": "Название модели Rerank", + "hint": "Укажите название модели в соответствии с документацией NVIDIA." + }, + "nvidia_rerank_model_endpoint": { + "description": "Пользовательский endpoint модели", + "hint": "Пользовательский суффикс URL endpoint, по умолчанию /reranking." + }, + "nvidia_rerank_truncate": { + "description": "Стратегия усечения текста", + "hint": "Определяет, следует ли усекать входной текст, если он слишком длинный и не помещается в максимальную длину контекста модели." + }, "launch_model_if_not_running": { "description": "Автозапуск модели", "hint": "Если модель не запущена в Xinference, попытаться запустить её автоматически. Рекомендуется отключать в продакшене." From 76470e2126ca3de538433c79a32e1c651c493d4c Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Mon, 6 Apr 2026 15:18:42 +0800 Subject: [PATCH 8/8] feat: update AddNewProvider component to support current provider type and enhance provider icon mapping --- .../components/provider/AddNewProvider.vue | 28 +++++++++++++++++-- dashboard/src/utils/providerUtils.js | 5 +++- dashboard/src/views/ProviderPage.vue | 1 + 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/dashboard/src/components/provider/AddNewProvider.vue b/dashboard/src/components/provider/AddNewProvider.vue index dfef836ab3..ac9f5e4412 100644 --- a/dashboard/src/components/provider/AddNewProvider.vue +++ b/dashboard/src/components/provider/AddNewProvider.vue @@ -1,5 +1,5 @@