Skip to content

Commit b0b6816

Browse files
feat: add NVIDIA rerank provider support (#7227)
* feat: add Rerank API support for NVIDIA NIM - Add Rerank API support for NVIDIA NIM - Add related i18n support in en-US zh-CN * chore: format code * fix: replace illegal characters Replace illegal characters when building request model path. * fix: refactor client initialization method * fix: enhance response parsing * docs: add comment for model_path process * docs: add Russian translation * feat: update AddNewProvider component to support current provider type and enhance provider icon mapping --------- Co-authored-by: Soulter <905617992@qq.com>
1 parent 224287e commit b0b6816

File tree

9 files changed

+294
-3
lines changed

9 files changed

+294
-3
lines changed

astrbot/core/config/default.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1802,6 +1802,19 @@ class ChatProviderTemplate(TypedDict):
18021802
"return_documents": False,
18031803
"instruct": "",
18041804
},
1805+
"NVIDIA Rerank": {
1806+
"id": "nvidia_rerank",
1807+
"type": "nvidia_rerank",
1808+
"provider": "nvidia",
1809+
"provider_type": "rerank",
1810+
"enable": True,
1811+
"nvidia_rerank_api_key": "",
1812+
"nvidia_rerank_api_base": "https://ai.api.nvidia.com/v1/retrieval",
1813+
"nvidia_rerank_model": "nv-rerank-qa-mistral-4b:1",
1814+
"nvidia_rerank_model_endpoint": "/reranking",
1815+
"timeout": 20,
1816+
"nvidia_rerank_truncate": "",
1817+
},
18051818
"Xinference STT": {
18061819
"id": "xinference_stt",
18071820
"type": "xinference_stt",
@@ -1870,6 +1883,34 @@ class ChatProviderTemplate(TypedDict):
18701883
"type": "bool",
18711884
"hint": "如果模型当前未在 Xinference 服务中运行,是否尝试自动启动它。在生产环境中建议关闭。",
18721885
},
1886+
"nvidia_rerank_api_base": {
1887+
"description": "API Base URL",
1888+
"type": "string",
1889+
},
1890+
"nvidia_rerank_api_key": {
1891+
"description": "API Key",
1892+
"type": "string",
1893+
},
1894+
"nvidia_rerank_model": {
1895+
"description": "重排序模型名称",
1896+
"type": "string",
1897+
"hint": "请参照NVIDIA Docs中模型名称填写。",
1898+
},
1899+
"nvidia_rerank_model_endpoint": {
1900+
"description": "自定义模型端点",
1901+
"type": "string",
1902+
"hint": "自定义URL末尾端点,默认为 /reranking",
1903+
},
1904+
"nvidia_rerank_truncate": {
1905+
"description": "文本截断策略",
1906+
"type": "string",
1907+
"hint": "当输入文本过长时,是否截断输入以适应模型的最大上下文长度。",
1908+
"options": [
1909+
"",
1910+
"NONE",
1911+
"END",
1912+
],
1913+
},
18731914
"modalities": {
18741915
"description": "模型能力",
18751916
"type": "list",

astrbot/core/provider/manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,10 @@ def dynamic_import_provider(self, type: str) -> None:
477477
from .sources.bailian_rerank_source import (
478478
BailianRerankProvider as BailianRerankProvider,
479479
)
480+
case "nvidia_rerank":
481+
from .sources.nvidia_rerank_source import (
482+
NvidiaRerankProvider as NvidiaRerankProvider,
483+
)
480484

481485
def get_merged_provider_config(self, provider_config: dict) -> dict:
482486
"""获取 provider 配置和 provider_source 配置合并后的结果
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
import aiohttp
2+
3+
from astrbot import logger
4+
5+
from ..entities import ProviderType, RerankResult
6+
from ..provider import RerankProvider
7+
from ..register import register_provider_adapter
8+
9+
10+
@register_provider_adapter(
    "nvidia_rerank", "NVIDIA Rerank 适配器", provider_type=ProviderType.RERANK
)
class NvidiaRerankProvider(RerankProvider):
    """Rerank provider backed by the NVIDIA NIM retrieval (reranking) API.

    Sends a query plus candidate passages to the NVIDIA rerank endpoint and
    returns them as ``RerankResult`` objects ordered by relevance score.
    """

    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
        super().__init__(provider_config, provider_settings)
        self.api_key: str = provider_config.get("nvidia_rerank_api_key", "")
        # Trailing slashes are stripped so URL joining in _get_endpoint is
        # predictable regardless of how the base URL was configured.
        self.base_url: str = provider_config.get(
            "nvidia_rerank_api_base", "https://ai.api.nvidia.com/v1/retrieval"
        ).rstrip("/")
        self.timeout = provider_config.get("timeout", 20)
        self.model: str = provider_config.get(
            "nvidia_rerank_model", "nv-rerank-qa-mistral-4b:1"
        )
        self.model_endpoint: str = provider_config.get(
            "nvidia_rerank_model_endpoint", "/reranking"
        )
        # Optional truncation strategy ("", "NONE" or "END"); an empty string
        # means "use the server default" and is omitted from the payload.
        self.truncate: str = provider_config.get("nvidia_rerank_truncate", "")

        # The HTTP session is created lazily on first use (see _get_client).
        self.client: aiohttp.ClientSession | None = None
        self.set_model(self.model)

    async def _get_client(self) -> aiohttp.ClientSession:
        """Return an open ClientSession, creating a new one if needed.

        The returned session is guaranteed to be open: a new session is built
        whenever none exists yet or the previous one has been closed.
        """
        if self.client is None or self.client.closed:
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
                "Accept": "application/json",
            }
            self.client = aiohttp.ClientSession(
                headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
            )
        return self.client

    def _get_endpoint(self) -> str:
        """
        Build the full API URL.

        Per the NVIDIA Rerank API docs, the URL layout currently differs
        between models. We branch on the model name as a basic heuristic and
        will revisit once NVIDIA unifies the API format.

        Examples:
            model: nv-rerank-qa-mistral-4b:1
            URL:   .../v1/retrieval/nvidia/reranking

            model: nvidia/llama-nemotron-rerank-1b-v2
            URL:   .../v1/retrieval/nvidia/llama-nemotron-rerank-1b-v2/reranking
        """
        model_path = "nvidia"
        logger.debug(f"[NVIDIA Rerank] Building endpoint for model: {self.model}")
        if "/" in self.model:
            # Follow NVIDIA's URL rules: replace special characters in the
            # model name when it is embedded as a path segment.
            model_path = self.model.strip("/").replace(".", "_")
        endpoint = self.model_endpoint.lstrip("/")
        return f"{self.base_url}/{model_path}/{endpoint}"

    def _build_payload(self, query: str, documents: list[str]) -> dict:
        """Build the JSON request payload for the rerank endpoint."""
        payload = {
            "model": self.model,
            "query": {"text": query},
            "passages": [{"text": doc} for doc in documents],
        }
        if self.truncate:
            payload["truncate"] = self.truncate
        return payload

    def _parse_results(
        self, response_data: dict, top_n: int | None
    ) -> list[RerankResult]:
        """Parse the API response into a sorted list of RerankResult.

        Malformed entries are logged and skipped rather than failing the
        whole request. Results are sorted by descending relevance score and
        truncated to ``top_n`` when given (``None`` or non-positive keeps all).
        """
        results = response_data.get("rankings", [])
        if not results:
            logger.warning(f"[NVIDIA Rerank] Empty response: {response_data}")
            return []

        rerank_results = []
        for idx, item in enumerate(results):
            try:
                index = item.get("index", idx)
                # Some models report "relevance_score", others "logit".
                score = item.get("relevance_score", item.get("logit", 0.0))
                rerank_results.append(
                    RerankResult(index=index, relevance_score=float(score))
                )
            except Exception as e:
                logger.warning(
                    f"[NVIDIA Rerank] Result parsing error: {e}, Data={item}"
                )

        rerank_results.sort(key=lambda x: x.relevance_score, reverse=True)

        if top_n is not None and top_n > 0:
            return rerank_results[:top_n]
        return rerank_results

    def _log_usage(self, data: dict) -> None:
        """Log token usage reported by the API, if any."""
        usage = data.get("usage", {})
        total_tokens = usage.get("total_tokens", 0)
        if total_tokens > 0:
            logger.debug(f"[NVIDIA Rerank] Token Usage: {total_tokens}")

    async def rerank(
        self,
        query: str,
        documents: list[str],
        top_n: int | None = None,
    ) -> list[RerankResult]:
        """Rerank ``documents`` by relevance to ``query``.

        Returns an empty list for empty input. Raises Exception on network
        failures or non-200 API responses.
        """
        # Validate input BEFORE opening an HTTP session, so that invalid
        # calls do not create (and keep alive) a client needlessly.
        if not documents or not query.strip():
            logger.warning(
                "[NVIDIA Rerank] Input data is invalid, query or documents are empty"
            )
            return []

        # _get_client always returns an open session, so no extra
        # closed-session guard is needed here.
        client = await self._get_client()

        try:
            payload = self._build_payload(query, documents)
            request_url = self._get_endpoint()

            async with client.post(request_url, json=payload) as response:
                if response.status != 200:
                    try:
                        response_data = await response.json()
                        error_detail = response_data.get(
                            "detail", response_data.get("message", "Unknown Error")
                        )
                    except Exception:
                        # Body was not JSON; fall back to the raw text.
                        error_detail = await response.text()
                        response_data = {"message": error_detail}

                    logger.error(f"[NVIDIA Rerank] API Error Response: {response_data}")
                    raise Exception(f"HTTP {response.status} - {error_detail}")

                response_data = await response.json()
                logger.debug(f"[NVIDIA Rerank] API Response: {response_data}")
                results = self._parse_results(response_data, top_n)
                self._log_usage(response_data)
                return results

        except aiohttp.ClientError as e:
            logger.error(f"[NVIDIA Rerank] Network error: {e}")
            raise Exception(f"Network error: {e}") from e
        except Exception as e:
            logger.error(f"[NVIDIA Rerank] Error: {e}")
            raise Exception(f"Rerank error: {e}") from e

    async def terminate(self) -> None:
        """Close the HTTP session, if one is open, and drop the reference."""
        if self.client and not self.client.closed:
            await self.client.close()
        self.client = None

dashboard/src/components/provider/AddNewProvider.vue

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
<template>
2-
<v-dialog v-model="showDialog" max-width="1100px" min-height="95%">
2+
<v-dialog v-model="showDialog" max-width="1000px" >
33
<v-card :title="tm('dialogs.addProvider.title')">
44
<v-card-text style="overflow-y: auto;">
55
<v-tabs v-model="activeProviderTab" grow>
@@ -73,6 +73,8 @@
7373
import { useModuleI18n } from '@/i18n/composables';
7474
import { getProviderIcon, getProviderDescription } from '@/utils/providerUtils';
7575
76+
const AVAILABLE_PROVIDER_TABS = ['agent_runner', 'speech_to_text', 'text_to_speech', 'embedding', 'rerank'];
77+
7678
export default {
7779
name: 'AddNewProvider',
7880
props: {
@@ -83,6 +85,10 @@ export default {
8385
metadata: {
8486
type: Object,
8587
default: () => ({})
88+
},
89+
currentProviderType: {
90+
type: String,
91+
default: 'agent_runner'
8692
}
8793
},
8894
emits: ['update:show', 'select-template'],
@@ -92,7 +98,7 @@ export default {
9298
},
9399
data() {
94100
return {
95-
activeProviderTab: 'chat_completion'
101+
activeProviderTab: 'agent_runner'
96102
};
97103
},
98104
computed: {
@@ -105,7 +111,25 @@ export default {
105111
}
106112
},
107113
},
114+
watch: {
115+
show(value) {
116+
if (value) {
117+
this.syncActiveProviderTab();
118+
}
119+
},
120+
currentProviderType() {
121+
if (this.showDialog) {
122+
this.syncActiveProviderTab();
123+
}
124+
}
125+
},
108126
methods: {
127+
syncActiveProviderTab() {
128+
this.activeProviderTab = AVAILABLE_PROVIDER_TABS.includes(this.currentProviderType)
129+
? this.currentProviderType
130+
: 'agent_runner';
131+
},
132+
109133
closeDialog() {
110134
this.showDialog = false;
111135
},

dashboard/src/i18n/locales/en-US/features/config-metadata.json

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1096,6 +1096,24 @@
10961096
"description": "Custom rerank task description",
10971097
"hint": "Only effective for qwen3-rerank models. Recommended to write in English."
10981098
},
1099+
"nvidia_rerank_api_base": {
1100+
"description": "API Base URL"
1101+
},
1102+
"nvidia_rerank_api_key": {
1103+
"description": "API Key"
1104+
},
1105+
"nvidia_rerank_model": {
1106+
"description": "Rerank Model Name",
1107+
"hint": "Please refer to the NVIDIA Docs for the model name."
1108+
},
1109+
"nvidia_rerank_model_endpoint": {
1110+
"description": "Custom Model Endpoint",
1111+
"hint": "Custom URL suffix endpoint, defaults to /reranking."
1112+
},
1113+
"nvidia_rerank_truncate": {
1114+
"description": "Text Truncation Strategy",
1115+
"hint": "Whether to truncate the input to fit the model's maximum context length when the input text is too long."
1116+
},
10991117
"launch_model_if_not_running": {
11001118
"description": "Auto-start model if not running",
11011119
"hint": "If the model is not running in Xinference, attempt to start it automatically. Recommended to disable in production."

dashboard/src/i18n/locales/ru-RU/features/config-metadata.json

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1097,6 +1097,24 @@
10971097
"description": "Описание задачи для Rerank",
10981098
"hint": "Эффективно только для моделей qwen3-rerank. Рекомендуется писать на английском."
10991099
},
1100+
"nvidia_rerank_api_base": {
1101+
"description": "Базовый URL API"
1102+
},
1103+
"nvidia_rerank_api_key": {
1104+
"description": "API-ключ"
1105+
},
1106+
"nvidia_rerank_model": {
1107+
"description": "Название модели Rerank",
1108+
"hint": "Укажите название модели в соответствии с документацией NVIDIA."
1109+
},
1110+
"nvidia_rerank_model_endpoint": {
1111+
"description": "Пользовательский endpoint модели",
1112+
"hint": "Пользовательский суффикс URL endpoint, по умолчанию /reranking."
1113+
},
1114+
"nvidia_rerank_truncate": {
1115+
"description": "Стратегия усечения текста",
1116+
"hint": "Определяет, следует ли усекать входной текст, если он слишком длинный и не помещается в максимальную длину контекста модели."
1117+
},
11001118
"launch_model_if_not_running": {
11011119
"description": "Автозапуск модели",
11021120
"hint": "Если модель не запущена в Xinference, попытаться запустить её автоматически. Рекомендуется отключать в продакшене."

dashboard/src/i18n/locales/zh-CN/features/config-metadata.json

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,6 +1098,24 @@
10981098
"description": "自定义排序任务类型说明",
10991099
"hint": "仅在使用 qwen3-rerank 模型时生效。建议使用英文撰写。"
11001100
},
1101+
"nvidia_rerank_api_base": {
1102+
"description": "API Base URL"
1103+
},
1104+
"nvidia_rerank_api_key": {
1105+
"description": "API Key"
1106+
},
1107+
"nvidia_rerank_model": {
1108+
"description": "重排序模型名称",
1109+
"hint": "请参照NVIDIA Docs中模型名称填写。"
1110+
},
1111+
"nvidia_rerank_model_endpoint": {
1112+
"description": "自定义模型端点",
1113+
"hint": "自定义URL末尾端点,默认为 /reranking"
1114+
},
1115+
"nvidia_rerank_truncate": {
1116+
"description": "文本截断策略",
1117+
"hint": "当输入文本过长时,是否截断输入以适应模型的最大上下文长度。"
1118+
},
11011119
"launch_model_if_not_running": {
11021120
"description": "模型未运行时自动启动",
11031121
"hint": "如果模型当前未在 Xinference 服务中运行,是否尝试自动启动它。在生产环境中建议关闭。"

dashboard/src/utils/providerUtils.js

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,10 @@ export function getProviderIcon(type) {
4141
'aihubmix': 'https://cdn.jsdelivr.net/npm/@lobehub/icons-static-svg@latest/icons/aihubmix-color.svg',
4242
'openrouter': 'https://cdn.jsdelivr.net/npm/@lobehub/icons-static-svg@latest/icons/openrouter.svg',
4343
"tokenpony": "https://tokenpony.cn/tokenpony-web/logo.png",
44-
"compshare": "https://compshare.cn/favicon.ico"
44+
"compshare": "https://compshare.cn/favicon.ico",
45+
"xinference": "https://cdn.jsdelivr.net/npm/@lobehub/icons-static-svg@latest/icons/xinference-color.svg",
46+
"bailian": "https://cdn.jsdelivr.net/npm/@lobehub/icons-static-svg@latest/icons/bailian-color.svg",
47+
"volcengine": 'https://cdn.jsdelivr.net/npm/@lobehub/icons-static-svg@latest/icons/volcengine-color.svg',
4548
};
4649
return icons[type] || '';
4750
}

dashboard/src/views/ProviderPage.vue

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@
175175

176176
<!-- 添加提供商对话框 -->
177177
<AddNewProvider v-model:show="showAddProviderDialog" :metadata="configSchema"
178+
:current-provider-type="selectedProviderType"
178179
@select-template="selectProviderTemplate" />
179180

180181
<!-- 手动添加模型对话框 -->

0 commit comments

Comments
 (0)