-
-
Notifications
You must be signed in to change notification settings - Fork 2.3k
feat: add NVIDIA rerank provider support #7227
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+294
−3
Merged
Changes from 6 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
6799c21
feat: add Rerank API support for NVIDIA NIM
WenqiOfficial 8b9960d
chore: format code
WenqiOfficial e1088af
fix: replace illegal characters
WenqiOfficial 6a4c26d
fix: refactor client initialization method
WenqiOfficial 5cb7e01
fix: enhance response parsing
WenqiOfficial e0501c2
docs: add comment for model_path process
WenqiOfficial 96da8f9
docs: add russia translation
Soulter 76470e2
feat: update AddNewProvider component to support current provider typ…
Soulter File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,164 @@ | ||
| import aiohttp | ||
|
|
||
| from astrbot import logger | ||
|
|
||
| from ..entities import ProviderType, RerankResult | ||
| from ..provider import RerankProvider | ||
| from ..register import register_provider_adapter | ||
|
|
||
|
|
||
| @register_provider_adapter( | ||
| "nvidia_rerank", "NVIDIA Rerank 适配器", provider_type=ProviderType.RERANK | ||
| ) | ||
| class NvidiaRerankProvider(RerankProvider): | ||
| def __init__(self, provider_config: dict, provider_settings: dict) -> None: | ||
| super().__init__(provider_config, provider_settings) | ||
| self.api_key = provider_config.get("nvidia_rerank_api_key", "") | ||
| self.base_url = provider_config.get( | ||
| "nvidia_rerank_api_base", "https://ai.api.nvidia.com/v1/retrieval" | ||
| ).rstrip("/") | ||
| self.timeout = provider_config.get("timeout", 20) | ||
| self.model = provider_config.get( | ||
| "nvidia_rerank_model", "nv-rerank-qa-mistral-4b:1" | ||
| ) | ||
| self.model_endpoint = provider_config.get( | ||
| "nvidia_rerank_model_endpoint", "/reranking" | ||
| ) | ||
| self.truncate = provider_config.get("nvidia_rerank_truncate", "") | ||
|
|
||
| self.client = None | ||
| self.set_model(self.model) | ||
|
|
||
| async def _get_client(self): | ||
| if self.client is None or self.client.closed: | ||
| headers = { | ||
| "Authorization": f"Bearer {self.api_key}", | ||
| "Content-Type": "application/json", | ||
| "Accept": "application/json", | ||
| } | ||
| self.client = aiohttp.ClientSession( | ||
| headers=headers, timeout=aiohttp.ClientTimeout(total=self.timeout) | ||
| ) | ||
| return self.client | ||
|
|
||
| def _get_endpoint(self) -> str: | ||
| """ | ||
| 构建完整API URL。 | ||
|
|
||
| 根据 Nvidia Rerank API 文档来看,当前URL存在不同模型格式不一致的问题。 | ||
| 这里针对模型名做一个基础判断用以适配,后续要等Nvidia统一API格式后再做调整。 | ||
|
|
||
| 例: | ||
| 模型: nv-rerank-qa-mistral-4b:1 | ||
| URL: .../v1/retrieval/nvidia/reranking | ||
|
|
||
| 模型: nvidia/llama-nemotron-rerank-1b-v2 | ||
| URL: .../v1/retrieval/nvidia/llama-nemotron-rerank-1b-v2/reranking | ||
| """ | ||
|
|
||
| model_path = "nvidia" | ||
| logger.debug(f"[NVIDIA Rerank] Building endpoint for model: {self.model}") | ||
| if "/" in self.model: | ||
| """遵循NVIDIA API的URL规则,替换模型名中特殊字符""" | ||
| model_path = self.model.strip("/").replace(".", "_") | ||
| endpoint = self.model_endpoint.lstrip("/") | ||
| return f"{self.base_url}/{model_path}/{endpoint}" | ||
|
|
||
| def _build_payload(self, query: str, documents: list[str]) -> dict: | ||
| """构建请求载荷""" | ||
| payload = { | ||
| "model": self.model, | ||
| "query": {"text": query}, | ||
| "passages": [{"text": doc} for doc in documents], | ||
| } | ||
| if self.truncate: | ||
| payload["truncate"] = self.truncate | ||
| return payload | ||
|
|
||
| def _parse_results( | ||
| self, response_data: dict, top_n: int | None | ||
| ) -> list[RerankResult]: | ||
| """解析响应数据""" | ||
| results = response_data.get("rankings", []) | ||
| if not results: | ||
| logger.warning(f"[NVIDIA Rerank] Empty response: {response_data}") | ||
| return [] | ||
|
|
||
| rerank_results = [] | ||
| for idx, item in enumerate(results): | ||
| try: | ||
| index = item.get("index", idx) | ||
| score = item.get("relevance_score", item.get("logit", 0.0)) | ||
| rerank_results.append( | ||
| RerankResult(index=index, relevance_score=float(score)) | ||
| ) | ||
| except Exception as e: | ||
| logger.warning( | ||
| f"[NVIDIA Rerank] Result parsing error: {e}, Data={item}" | ||
| ) | ||
|
|
||
| rerank_results.sort(key=lambda x: x.relevance_score, reverse=True) | ||
|
|
||
| if top_n is not None and top_n > 0: | ||
| return rerank_results[:top_n] | ||
| return rerank_results | ||
|
|
||
| def _log_usage(self, data: dict) -> None: | ||
| usage = data.get("usage", {}) | ||
| total_tokens = usage.get("total_tokens", 0) | ||
| if total_tokens > 0: | ||
| logger.debug(f"[NVIDIA Rerank] Token Usage: {total_tokens}") | ||
|
|
||
| async def rerank( | ||
| self, | ||
| query: str, | ||
| documents: list[str], | ||
| top_n: int | None = None, | ||
| ) -> list[RerankResult]: | ||
| client = await self._get_client() | ||
| if not client or client.closed: | ||
| logger.error("[NVIDIA Rerank] Client session not initialized or closed") | ||
| return [] | ||
|
|
||
| if not documents or not query.strip(): | ||
| logger.warning( | ||
| "[NVIDIA Rerank] Input data is invalid, query or documents are empty" | ||
| ) | ||
| return [] | ||
|
|
||
| try: | ||
| payload = self._build_payload(query, documents) | ||
| request_url = self._get_endpoint() | ||
|
|
||
| async with client.post(request_url, json=payload) as response: | ||
| if response.status != 200: | ||
| try: | ||
| response_data = await response.json() | ||
| error_detail = response_data.get( | ||
| "detail", response_data.get("message", "Unknown Error") | ||
| ) | ||
|
|
||
| except Exception: | ||
| error_detail = await response.text() | ||
| response_data = {"message": error_detail} | ||
|
|
||
| logger.error(f"[NVIDIA Rerank] API Error Response: {response_data}") | ||
| raise Exception(f"HTTP {response.status} - {error_detail}") | ||
|
WenqiOfficial marked this conversation as resolved.
|
||
|
|
||
| response_data = await response.json() | ||
| logger.debug(f"[NVIDIA Rerank] API Response: {response_data}") | ||
| results = self._parse_results(response_data, top_n) | ||
| self._log_usage(response_data) | ||
| return results | ||
|
|
||
| except aiohttp.ClientError as e: | ||
| logger.error(f"[NVIDIA Rerank] Network error: {e}") | ||
| raise Exception(f"Network error: {e}") from e | ||
| except Exception as e: | ||
| logger.error(f"[NVIDIA Rerank] Error: {e}") | ||
| raise Exception(f"Rerank error: {e}") from e | ||
|
|
||
| async def terminate(self) -> None: | ||
| if self.client and not self.client.closed: | ||
| await self.client.close() | ||
| self.client = None | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.