|
17 | 17 |
|
18 | 18 | from ..config import ( |
19 | 19 | DEFAULT_LITELLM_API_BASE, |
| 20 | + DEFAULT_RERANKER_ALIBABA_MODEL, |
20 | 21 | DEFAULT_RERANKER_COHERE_MODEL, |
21 | 22 | DEFAULT_RERANKER_FLASHRANK_CACHE_DIR, |
22 | 23 | DEFAULT_RERANKER_FLASHRANK_CPU_MEM_ARENA, |
|
37 | 38 | DEFAULT_RERANKER_TEI_HTTP_TIMEOUT, |
38 | 39 | DEFAULT_RERANKER_TEI_MAX_CONCURRENT, |
39 | 40 | DEFAULT_RERANKER_ZEROENTROPY_MODEL, |
| 41 | + ENV_RERANKER_ALIBABA_API_KEY, |
40 | 42 | ENV_RERANKER_COHERE_API_KEY, |
41 | 43 | ENV_RERANKER_COHERE_MODEL, |
42 | 44 | ENV_RERANKER_FLASHRANK_CACHE_DIR, |
@@ -1534,6 +1536,137 @@ async def predict(self, pairs: list[tuple[str, str]]) -> list[float]: |
1534 | 1536 | return await loop.run_in_executor(None, self._predict_sync, pairs) |
1535 | 1537 |
|
1536 | 1538 |
|
class AlibabaCloudCrossEncoder(CrossEncoderModel):
    """
    Alibaba Cloud DashScope text reranking API.

    Authentication via DashScope API key (DASHSCOPE_API_KEY or HINDSIGHT_API_RERANKER_ALIBABA_API_KEY).
    See: https://help.aliyun.com/zh/model-studio/text-rerank-api

    Two request formats are supported. Models listed in ``_COMPATIBLE_MODELS``
    are sent to ``COMPATIBLE_API_URL`` with a flat (Cohere-style) payload; every
    other model is sent to ``NATIVE_API_URL`` with the nested
    ``input`` / ``parameters`` payload.
    """

    COMPATIBLE_API_URL = "https://dashscope.aliyuncs.com/compatible-api/v1/reranks"
    NATIVE_API_URL = "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank"

    # Models using the Cohere-compatible API format
    _COMPATIBLE_MODELS: frozenset[str] = frozenset({"qwen3-rerank"})

    def __init__(
        self,
        api_key: str,
        model: str = DEFAULT_RERANKER_ALIBABA_MODEL,
        timeout: float = 60.0,
    ):
        """
        Initialize Alibaba Cloud DashScope reranker client.

        No network resources are created here; the HTTP client is built lazily
        by ``initialize()``.

        Args:
            api_key: DashScope API key
            model: Model name. Supported: qwen3-rerank (default), gte-rerank-v2
            timeout: Request timeout in seconds (default: 60.0)
        """
        self.api_key = api_key
        self.model = model
        self.timeout = timeout
        self._async_client: httpx.AsyncClient | None = None

    @property
    def provider_name(self) -> str:
        """Short provider identifier for this reranker."""
        return "alibaba"

    async def initialize(self) -> None:
        """Initialize the async HTTP client (no-op if already initialized)."""
        if self._async_client is not None:
            return
        logger.info("Reranker: initializing Alibaba Cloud provider with model %s", self.model)
        self._async_client = httpx.AsyncClient(
            timeout=self.timeout,
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
        )
        logger.info("Reranker: Alibaba Cloud provider initialized")

    async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
        """
        Score query-document pairs using the DashScope reranking API.

        Pairs are grouped by query so that each distinct query costs a single
        API request; the scores are then written back to the positions the
        pairs had in the input list.

        Args:
            pairs: List of (query, document) tuples to score

        Returns:
            List of relevance scores (0.0–1.0, higher = more relevant), aligned
            with ``pairs``. Documents the API returned no valid score for keep 0.0.

        Raises:
            RuntimeError: If ``initialize()`` has not been called.
            httpx.HTTPStatusError: If the API answers with an error status.
        """
        if self._async_client is None:
            raise RuntimeError("Reranker not initialized. Call initialize() first.")

        if not pairs:
            return []

        # Group documents by query, remembering each document's original position.
        query_groups: dict[str, list[tuple[int, str]]] = {}
        for idx, (query, text) in enumerate(pairs):
            query_groups.setdefault(query, []).append((idx, text))

        # The endpoint depends only on the configured model, so pick it once.
        rerank = self._rerank_compatible if self.model in self._COMPATIBLE_MODELS else self._rerank_native

        all_scores = [0.0] * len(pairs)
        for query, indexed_texts in query_groups.items():
            indices = [idx for idx, _ in indexed_texts]
            texts = [text for _, text in indexed_texts]
            local_scores = await rerank(query, texts)
            for original_idx, score in zip(indices, local_scores):
                all_scores[original_idx] = score

        return all_scores

    async def _post_json(self, url: str, payload: dict) -> dict:
        """POST ``payload`` as JSON to ``url`` and return the decoded JSON body."""
        if self._async_client is None:
            raise RuntimeError("Reranker not initialized. Call initialize() first.")
        response = await self._async_client.post(url, json=payload)
        response.raise_for_status()
        return response.json()

    @staticmethod
    def _scores_from_results(results: list[dict] | None, expected: int) -> list[float]:
        """Map API result entries onto a score list of length ``expected``."""
        scores = [0.0] * expected
        for item in results or []:
            index = item.get("index") if isinstance(item, dict) else None
            score = item.get("relevance_score") if isinstance(item, dict) else None
            # Reject negative indices explicitly: Python would otherwise accept
            # them and silently overwrite the score of a different document.
            if not isinstance(index, int) or not 0 <= index < expected or score is None:
                logger.warning("Reranker: ignoring malformed Alibaba Cloud result entry: %r", item)
                continue
            scores[index] = float(score)
        return scores

    async def _rerank_compatible(self, query: str, texts: list[str]) -> list[float]:
        """Call the Cohere-compatible endpoint (qwen3-rerank)."""
        result = await self._post_json(
            self.COMPATIBLE_API_URL,
            {
                "model": self.model,
                "query": query,
                "documents": texts,
                # Ask for every document back so each one gets a score.
                "top_n": len(texts),
            },
        )
        return self._scores_from_results(result.get("results"), len(texts))

    async def _rerank_native(self, query: str, texts: list[str]) -> list[float]:
        """Call the DashScope native endpoint (gte-rerank-v2)."""
        result = await self._post_json(
            self.NATIVE_API_URL,
            {
                "model": self.model,
                "input": {
                    "query": query,
                    "documents": texts,
                },
                "parameters": {
                    # Ask for every document back so each one gets a score.
                    "top_n": len(texts),
                    "return_documents": False,
                },
            },
        )
        # "output" may be absent or null in the response; treat both as "no results".
        output = result.get("output") or {}
        return self._scores_from_results(output.get("results"), len(texts))
| 1668 | + |
| 1669 | + |
1537 | 1670 | def create_cross_encoder_from_env() -> CrossEncoderModel: |
1538 | 1671 | """ |
1539 | 1672 | Create a CrossEncoderModel instance based on configuration. |
@@ -1648,11 +1781,21 @@ def create_cross_encoder_from_env() -> CrossEncoderModel: |
1648 | 1781 | model=config.reranker_google_model, |
1649 | 1782 | service_account_key=config.reranker_google_service_account_key, |
1650 | 1783 | ) |
| 1784 | + elif provider == "alibaba": |
| 1785 | + api_key = config.reranker_alibaba_api_key |
| 1786 | + if not api_key: |
| 1787 | + raise ValueError( |
| 1788 | + f"{ENV_RERANKER_ALIBABA_API_KEY} is required when {ENV_RERANKER_PROVIDER} is 'alibaba'" |
| 1789 | + ) |
| 1790 | + return AlibabaCloudCrossEncoder( |
| 1791 | + api_key=api_key, |
| 1792 | + model=config.reranker_alibaba_model, |
| 1793 | + ) |
1651 | 1794 | elif provider == "rrf": |
1652 | 1795 | return RRFPassthroughCrossEncoder() |
1653 | 1796 | elif provider == "jina-mlx": |
1654 | 1797 | return JinaMLXCrossEncoder() |
1655 | 1798 | else: |
1656 | 1799 | raise ValueError( |
1657 | | - f"Unknown reranker provider: {provider}. Supported: 'local', 'tei', 'cohere', 'zeroentropy', 'siliconflow', 'google', 'flashrank', 'litellm', 'litellm-sdk', 'rrf', 'jina-mlx'" |
| 1800 | + f"Unknown reranker provider: {provider}. Supported: 'local', 'tei', 'cohere', 'zeroentropy', 'siliconflow', 'alibaba', 'google', 'flashrank', 'litellm', 'litellm-sdk', 'rrf', 'jina-mlx'" |
1658 | 1801 | ) |
0 commit comments