-
-
Notifications
You must be signed in to change notification settings - Fork 2k
feat: add NVIDIA rerank provider support #7227
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6799c21
8b9960d
e1088af
6a4c26d
5cb7e01
e0501c2
96da8f9
76470e2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,164 @@ | ||
| import aiohttp | ||
|
|
||
| from astrbot import logger | ||
|
|
||
| from ..entities import ProviderType, RerankResult | ||
| from ..provider import RerankProvider | ||
| from ..register import register_provider_adapter | ||
|
|
||
|
|
||
| @register_provider_adapter( | ||
| "nvidia_rerank", "NVIDIA Rerank 适配器", provider_type=ProviderType.RERANK | ||
| ) | ||
| class NvidiaRerankProvider(RerankProvider): | ||
| def __init__(self, provider_config: dict, provider_settings: dict) -> None: | ||
| super().__init__(provider_config, provider_settings) | ||
| self.api_key = provider_config.get("nvidia_rerank_api_key", "") | ||
| self.base_url = provider_config.get( | ||
| "nvidia_rerank_api_base", "https://ai.api.nvidia.com/v1/retrieval" | ||
| ).rstrip("/") | ||
| self.timeout = provider_config.get("timeout", 20) | ||
| self.model = provider_config.get( | ||
| "nvidia_rerank_model", "nv-rerank-qa-mistral-4b:1" | ||
| ) | ||
| self.model_endpoint = provider_config.get( | ||
| "nvidia_rerank_model_endpoint", "/reranking" | ||
| ) | ||
| self.truncate = provider_config.get("nvidia_rerank_truncate", "") | ||
|
|
||
| self.client = None | ||
| self.set_model(self.model) | ||
|
|
||
| async def _get_client(self): | ||
| if self.client is None or self.client.closed: | ||
| headers = { | ||
| "Authorization": f"Bearer {self.api_key}", | ||
| "Content-Type": "application/json", | ||
| "Accept": "application/json", | ||
| } | ||
| self.client = aiohttp.ClientSession( | ||
| headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout) | ||
| ) | ||
| return self.client | ||
|
|
||
| def _get_endpoint(self) -> str: | ||
| """ | ||
| 构建完整API URL。 | ||
|
|
||
| 根据 Nvidia Rerank API 文档来看,当前URL存在不同模型格式不一致的问题。 | ||
| 这里针对模型名做一个基础判断用以适配,后续要等Nvidia统一API格式后再做调整。 | ||
|
|
||
| 例: | ||
| 模型: nv-rerank-qa-mistral-4b:1 | ||
| URL: .../v1/retrieval/nvidia/reranking | ||
|
|
||
| 模型: nvidia/llama-nemotron-rerank-1b-v2 | ||
| URL: .../v1/retrieval/nvidia/llama-nemotron-rerank-1b-v2/reranking | ||
| """ | ||
|
|
||
| model_path = "nvidia" | ||
| logger.debug(f"[NVIDIA Rerank] Building endpoint for model: {self.model}") | ||
| if "/" in self.model: | ||
| """遵循NVIDIA API的URL规则,替换模型名中特殊字符""" | ||
| model_path = self.model.strip("/").replace(".", "_") | ||
| endpoint = self.model_endpoint.lstrip("/") | ||
| return f"{self.base_url}/{model_path}/{endpoint}" | ||
|
|
||
| def _build_payload(self, query: str, documents: list[str]) -> dict: | ||
| """构建请求载荷""" | ||
| payload = { | ||
| "model": self.model, | ||
| "query": {"text": query}, | ||
| "passages": [{"text": doc} for doc in documents], | ||
| } | ||
| if self.truncate: | ||
| payload["truncate"] = self.truncate | ||
| return payload | ||
|
|
||
| def _parse_results( | ||
| self, response_data: dict, top_n: int | None | ||
| ) -> list[RerankResult]: | ||
| """解析响应数据""" | ||
| results = response_data.get("rankings", []) | ||
| if not results: | ||
| logger.warning(f"[NVIDIA Rerank] Empty response: {response_data}") | ||
| return [] | ||
|
|
||
| rerank_results = [] | ||
| for idx, item in enumerate(results): | ||
| try: | ||
| index = item.get("index", idx) | ||
| score = item.get("relevance_score", item.get("logit", 0.0)) | ||
| rerank_results.append( | ||
| RerankResult(index=index, relevance_score=float(score)) | ||
| ) | ||
| except Exception as e: | ||
| logger.warning( | ||
| f"[NVIDIA Rerank] Result parsing error: {e}, Data={item}" | ||
| ) | ||
|
|
||
| rerank_results.sort(key=lambda x: x.relevance_score, reverse=True) | ||
|
|
||
| if top_n is not None and top_n > 0: | ||
| return rerank_results[:top_n] | ||
| return rerank_results | ||
|
|
||
| def _log_usage(self, data: dict) -> None: | ||
| usage = data.get("usage", {}) | ||
| total_tokens = usage.get("total_tokens", 0) | ||
| if total_tokens > 0: | ||
| logger.debug(f"[NVIDIA Rerank] Token Usage: {total_tokens}") | ||
|
|
||
| async def rerank( | ||
| self, | ||
| query: str, | ||
| documents: list[str], | ||
| top_n: int | None = None, | ||
| ) -> list[RerankResult]: | ||
| client = await self._get_client() | ||
| if not client or client.closed: | ||
| logger.error("[NVIDIA Rerank] Client session not initialized or closed") | ||
| return [] | ||
|
|
||
| if not documents or not query.strip(): | ||
| logger.warning( | ||
| "[NVIDIA Rerank] Input data is invalid, query or documents are empty" | ||
| ) | ||
| return [] | ||
|
|
||
| try: | ||
| payload = self._build_payload(query, documents) | ||
| request_url = self._get_endpoint() | ||
|
|
||
| async with client.post(request_url, json=payload) as response: | ||
| if response.status != 200: | ||
| try: | ||
| response_data = await response.json() | ||
| error_detail = response_data.get( | ||
| "detail", response_data.get("message", "Unknown Error") | ||
| ) | ||
|
|
||
| except Exception: | ||
| error_detail = await response.text() | ||
| response_data = {"message": error_detail} | ||
|
|
||
| logger.error(f"[NVIDIA Rerank] API Error Response: {response_data}") | ||
| raise Exception(f"HTTP {response.status} - {error_detail}") | ||
|
Comment on lines
+137
to
+146
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The error handling logic here might fail if the response body is consumed twice. When |
||
|
|
||
| response_data = await response.json() | ||
| logger.debug(f"[NVIDIA Rerank] API Response: {response_data}") | ||
| results = self._parse_results(response_data, top_n) | ||
| self._log_usage(response_data) | ||
| return results | ||
|
|
||
| except aiohttp.ClientError as e: | ||
| logger.error(f"[NVIDIA Rerank] Network error: {e}") | ||
| raise Exception(f"Network error: {e}") from e | ||
| except Exception as e: | ||
| logger.error(f"[NVIDIA Rerank] Error: {e}") | ||
| raise Exception(f"Rerank error: {e}") from e | ||
|
|
||
| async def terminate(self) -> None: | ||
| if self.client and not self.client.closed: | ||
| await self.client.close() | ||
| self.client = None | ||
Uh oh!
There was an error while loading. Please reload this page.