Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions astrbot/core/config/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -1789,6 +1789,19 @@ class ChatProviderTemplate(TypedDict):
"return_documents": False,
"instruct": "",
},
"NVIDIA Rerank": {
"id": "nvidia_rerank",
"type": "nvidia_rerank",
"provider": "nvidia",
"provider_type": "rerank",
"enable": True,
"nvidia_rerank_api_key": "",
"nvidia_rerank_api_base": "https://ai.api.nvidia.com/v1/retrieval",
"nvidia_rerank_model": "nv-rerank-qa-mistral-4b:1",
"nvidia_rerank_model_endpoint": "/reranking",
"timeout": 20,
"nvidia_rerank_truncate": "",
},
"Xinference STT": {
"id": "xinference_stt",
"type": "xinference_stt",
Expand Down Expand Up @@ -1852,6 +1865,34 @@ class ChatProviderTemplate(TypedDict):
"type": "bool",
"hint": "如果模型当前未在 Xinference 服务中运行,是否尝试自动启动它。在生产环境中建议关闭。",
},
"nvidia_rerank_api_base": {
"description": "API Base URL",
"type": "string",
},
"nvidia_rerank_api_key": {
"description": "API Key",
"type": "string",
},
"nvidia_rerank_model": {
"description": "重排序模型名称",
"type": "string",
"hint": "请参照NVIDIA Docs中模型名称填写。",
},
"nvidia_rerank_model_endpoint": {
"description": "自定义模型端点",
"type": "string",
"hint": "自定义URL末尾端点,默认为 /reranking",
},
"nvidia_rerank_truncate": {
"description": "文本截断策略",
"type": "string",
"hint": "当输入文本过长时,是否截断输入以适应模型的最大上下文长度。",
"options": [
"",
"NONE",
"END",
],
},
"modalities": {
"description": "模型能力",
"type": "list",
Expand Down
4 changes: 4 additions & 0 deletions astrbot/core/provider/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,10 @@ def dynamic_import_provider(self, type: str) -> None:
from .sources.bailian_rerank_source import (
BailianRerankProvider as BailianRerankProvider,
)
case "nvidia_rerank":
from .sources.nvidia_rerank_source import (
NvidiaRerankProvider as NvidiaRerankProvider,
)

def get_merged_provider_config(self, provider_config: dict) -> dict:
"""获取 provider 配置和 provider_source 配置合并后的结果
Expand Down
164 changes: 164 additions & 0 deletions astrbot/core/provider/sources/nvidia_rerank_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import json

import aiohttp

from astrbot import logger

from ..entities import ProviderType, RerankResult
from ..provider import RerankProvider
from ..register import register_provider_adapter


@register_provider_adapter(
    "nvidia_rerank", "NVIDIA Rerank 适配器", provider_type=ProviderType.RERANK
)
class NvidiaRerankProvider(RerankProvider):
    """Rerank provider adapter for NVIDIA's Retrieval (reranking) API."""

    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
        super().__init__(provider_config, provider_settings)
        self.api_key: str = provider_config.get("nvidia_rerank_api_key", "")
        self.base_url: str = provider_config.get(
            "nvidia_rerank_api_base", "https://ai.api.nvidia.com/v1/retrieval"
        ).rstrip("/")
        # Total request timeout in seconds for the aiohttp session.
        self.timeout = provider_config.get("timeout", 20)
        self.model: str = provider_config.get(
            "nvidia_rerank_model", "nv-rerank-qa-mistral-4b:1"
        )
        self.model_endpoint: str = provider_config.get(
            "nvidia_rerank_model_endpoint", "/reranking"
        )
        # Truncation strategy; empty string means "omit the field from the payload".
        self.truncate: str = provider_config.get("nvidia_rerank_truncate", "")

        # HTTP session is created lazily on first use; see _get_client().
        self.client: aiohttp.ClientSession | None = None
        self.set_model(self.model)

    async def _get_client(self) -> aiohttp.ClientSession:
        """Return a usable client session, (re)creating it if missing or closed."""
        if self.client is None or self.client.closed:
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
                "Accept": "application/json",
            }
            self.client = aiohttp.ClientSession(
                headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout)
            )
        return self.client

    def _get_endpoint(self) -> str:
        """
        Build the full API URL.

        Per the NVIDIA Rerank API docs, the URL format is currently
        inconsistent across models, so we branch on the model name here.
        Revisit once NVIDIA unifies the API format.

        Examples:
            model: nv-rerank-qa-mistral-4b:1
            URL:   .../v1/retrieval/nvidia/reranking

            model: nvidia/llama-nemotron-rerank-1b-v2
            URL:   .../v1/retrieval/nvidia/llama-nemotron-rerank-1b-v2/reranking
        """
        model_path = "nvidia"
        logger.debug(f"[NVIDIA Rerank] Building endpoint for model: {self.model}")
        if "/" in self.model:
            # Follow NVIDIA's URL rules: sanitize special characters in the
            # model name when it is used as a path segment.
            model_path = self.model.strip("/").replace(".", "_")
        endpoint = self.model_endpoint.lstrip("/")
        return f"{self.base_url}/{model_path}/{endpoint}"

    def _build_payload(self, query: str, documents: list[str]) -> dict:
        """Build the JSON request payload for the rerank endpoint."""
        payload = {
            "model": self.model,
            "query": {"text": query},
            "passages": [{"text": doc} for doc in documents],
        }
        if self.truncate:
            payload["truncate"] = self.truncate
        return payload

    def _parse_results(
        self, response_data: dict, top_n: int | None
    ) -> list[RerankResult]:
        """Parse the response into RerankResult objects, sorted best-first.

        Items that cannot be parsed are logged and skipped rather than
        failing the whole request. If ``top_n`` is a positive int, only the
        top ``top_n`` results are returned.
        """
        results = response_data.get("rankings", [])
        if not results:
            logger.warning(f"[NVIDIA Rerank] Empty response: {response_data}")
            return []

        rerank_results = []
        for idx, item in enumerate(results):
            try:
                index = item.get("index", idx)
                # Some models report "relevance_score", others "logit".
                score = item.get("relevance_score", item.get("logit", 0.0))
                rerank_results.append(
                    RerankResult(index=index, relevance_score=float(score))
                )
            except Exception as e:
                logger.warning(
                    f"[NVIDIA Rerank] Result parsing error: {e}, Data={item}"
                )

        rerank_results.sort(key=lambda x: x.relevance_score, reverse=True)

        if top_n is not None and top_n > 0:
            return rerank_results[:top_n]
        return rerank_results

    def _log_usage(self, data: dict) -> None:
        """Log token usage if the response reports any."""
        usage = data.get("usage", {})
        total_tokens = usage.get("total_tokens", 0)
        if total_tokens > 0:
            logger.debug(f"[NVIDIA Rerank] Token Usage: {total_tokens}")

    async def rerank(
        self,
        query: str,
        documents: list[str],
        top_n: int | None = None,
    ) -> list[RerankResult]:
        """Rerank ``documents`` against ``query``.

        Returns results sorted by relevance score (descending), optionally
        truncated to ``top_n``. Returns ``[]`` for empty input.

        Raises:
            Exception: on HTTP error status, network failure, or an
                unparsable response body.
        """
        client = await self._get_client()
        if not client or client.closed:
            logger.error("[NVIDIA Rerank] Client session not initialized or closed")
            return []

        if not documents or not query.strip():
            logger.warning(
                "[NVIDIA Rerank] Input data is invalid, query or documents are empty"
            )
            return []

        try:
            payload = self._build_payload(query, documents)
            request_url = self._get_endpoint()

            async with client.post(request_url, json=payload) as response:
                # Read the body exactly once, then parse it ourselves.
                # Calling response.json() and falling back to response.text()
                # in the error path risks consuming the body twice.
                raw_body = await response.text()

                if response.status != 200:
                    try:
                        response_data = json.loads(raw_body)
                        error_detail = response_data.get(
                            "detail", response_data.get("message", "Unknown Error")
                        )
                    except Exception:
                        error_detail = raw_body
                        response_data = {"message": error_detail}

                    logger.error(f"[NVIDIA Rerank] API Error Response: {response_data}")
                    raise Exception(f"HTTP {response.status} - {error_detail}")

                response_data = json.loads(raw_body)
                logger.debug(f"[NVIDIA Rerank] API Response: {response_data}")
                results = self._parse_results(response_data, top_n)
                self._log_usage(response_data)
                return results

        except aiohttp.ClientError as e:
            logger.error(f"[NVIDIA Rerank] Network error: {e}")
            raise Exception(f"Network error: {e}") from e
        except Exception as e:
            logger.error(f"[NVIDIA Rerank] Error: {e}")
            raise Exception(f"Rerank error: {e}") from e

    async def terminate(self) -> None:
        """Close the HTTP session (if open) and drop the reference."""
        if self.client and not self.client.closed:
            await self.client.close()
        self.client = None
28 changes: 26 additions & 2 deletions dashboard/src/components/provider/AddNewProvider.vue
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<template>
<v-dialog v-model="showDialog" max-width="1100px" min-height="95%">
<v-dialog v-model="showDialog" max-width="1000px" >
<v-card :title="tm('dialogs.addProvider.title')">
<v-card-text style="overflow-y: auto;">
<v-tabs v-model="activeProviderTab" grow>
Expand Down Expand Up @@ -73,6 +73,8 @@
import { useModuleI18n } from '@/i18n/composables';
import { getProviderIcon, getProviderDescription } from '@/utils/providerUtils';

const AVAILABLE_PROVIDER_TABS = ['agent_runner', 'speech_to_text', 'text_to_speech', 'embedding', 'rerank'];

export default {
name: 'AddNewProvider',
props: {
Expand All @@ -83,6 +85,10 @@ export default {
metadata: {
type: Object,
default: () => ({})
},
currentProviderType: {
type: String,
default: 'agent_runner'
}
},
emits: ['update:show', 'select-template'],
Expand All @@ -92,7 +98,7 @@ export default {
},
data() {
return {
activeProviderTab: 'chat_completion'
activeProviderTab: 'agent_runner'
};
},
computed: {
Expand All @@ -105,7 +111,25 @@ export default {
}
},
},
watch: {
show(value) {
if (value) {
this.syncActiveProviderTab();
}
},
currentProviderType() {
if (this.showDialog) {
this.syncActiveProviderTab();
}
}
},
methods: {
syncActiveProviderTab() {
this.activeProviderTab = AVAILABLE_PROVIDER_TABS.includes(this.currentProviderType)
? this.currentProviderType
: 'agent_runner';
},

closeDialog() {
this.showDialog = false;
},
Expand Down
18 changes: 18 additions & 0 deletions dashboard/src/i18n/locales/en-US/features/config-metadata.json
Original file line number Diff line number Diff line change
Expand Up @@ -1092,6 +1092,24 @@
"description": "Custom rerank task description",
"hint": "Only effective for qwen3-rerank models. Recommended to write in English."
},
"nvidia_rerank_api_base": {
"description": "API Base URL"
},
"nvidia_rerank_api_key": {
"description": "API Key"
},
"nvidia_rerank_model": {
"description": "Rerank Model Name",
"hint": "Please refer to the NVIDIA Docs for the model name."
},
"nvidia_rerank_model_endpoint": {
"description": "Custom Model Endpoint",
"hint": "Custom URL suffix endpoint, defaults to /reranking."
},
"nvidia_rerank_truncate": {
"description": "Text Truncation Strategy",
"hint": "Whether to truncate the input to fit the model's maximum context length when the input text is too long."
},
"launch_model_if_not_running": {
"description": "Auto-start model if not running",
"hint": "If the model is not running in Xinference, attempt to start it automatically. Recommended to disable in production."
Expand Down
18 changes: 18 additions & 0 deletions dashboard/src/i18n/locales/ru-RU/features/config-metadata.json
Original file line number Diff line number Diff line change
Expand Up @@ -1093,6 +1093,24 @@
"description": "Описание задачи для Rerank",
"hint": "Эффективно только для моделей qwen3-rerank. Рекомендуется писать на английском."
},
"nvidia_rerank_api_base": {
"description": "Базовый URL API"
},
"nvidia_rerank_api_key": {
"description": "API-ключ"
},
"nvidia_rerank_model": {
"description": "Название модели Rerank",
"hint": "Укажите название модели в соответствии с документацией NVIDIA."
},
"nvidia_rerank_model_endpoint": {
"description": "Пользовательский endpoint модели",
"hint": "Пользовательский суффикс URL endpoint, по умолчанию /reranking."
},
"nvidia_rerank_truncate": {
"description": "Стратегия усечения текста",
"hint": "Определяет, следует ли усекать входной текст, если он слишком длинный и не помещается в максимальную длину контекста модели."
},
"launch_model_if_not_running": {
"description": "Автозапуск модели",
"hint": "Если модель не запущена в Xinference, попытаться запустить её автоматически. Рекомендуется отключать в продакшене."
Expand Down
18 changes: 18 additions & 0 deletions dashboard/src/i18n/locales/zh-CN/features/config-metadata.json
Original file line number Diff line number Diff line change
Expand Up @@ -1094,6 +1094,24 @@
"description": "自定义排序任务类型说明",
"hint": "仅在使用 qwen3-rerank 模型时生效。建议使用英文撰写。"
},
"nvidia_rerank_api_base": {
"description": "API Base URL"
},
"nvidia_rerank_api_key": {
"description": "API Key"
},
"nvidia_rerank_model": {
"description": "重排序模型名称",
"hint": "请参照NVIDIA Docs中模型名称填写。"
},
"nvidia_rerank_model_endpoint": {
"description": "自定义模型端点",
"hint": "自定义URL末尾端点,默认为 /reranking"
},
"nvidia_rerank_truncate": {
"description": "文本截断策略",
"hint": "当输入文本过长时,是否截断输入以适应模型的最大上下文长度。"
},
"launch_model_if_not_running": {
"description": "模型未运行时自动启动",
"hint": "如果模型当前未在 Xinference 服务中运行,是否尝试自动启动它。在生产环境中建议关闭。"
Expand Down
5 changes: 4 additions & 1 deletion dashboard/src/utils/providerUtils.js
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ export function getProviderIcon(type) {
'aihubmix': 'https://cdn.jsdelivr.net/npm/@lobehub/icons-static-svg@latest/icons/aihubmix-color.svg',
'openrouter': 'https://cdn.jsdelivr.net/npm/@lobehub/icons-static-svg@latest/icons/openrouter.svg',
"tokenpony": "https://tokenpony.cn/tokenpony-web/logo.png",
"compshare": "https://compshare.cn/favicon.ico"
"compshare": "https://compshare.cn/favicon.ico",
"xinference": "https://cdn.jsdelivr.net/npm/@lobehub/icons-static-svg@latest/icons/xinference-color.svg",
"bailian": "https://cdn.jsdelivr.net/npm/@lobehub/icons-static-svg@latest/icons/bailian-color.svg",
"volcengine": 'https://cdn.jsdelivr.net/npm/@lobehub/icons-static-svg@latest/icons/volcengine-color.svg',
};
return icons[type] || '';
}
Expand Down
1 change: 1 addition & 0 deletions dashboard/src/views/ProviderPage.vue
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@

<!-- 添加提供商对话框 -->
<AddNewProvider v-model:show="showAddProviderDialog" :metadata="configSchema"
:current-provider-type="selectedProviderType"
@select-template="selectProviderTemplate" />

<!-- 手动添加模型对话框 -->
Expand Down
Loading