Skip to content

Commit 25ecaa0

Browse files
committed
feat(reranker): support alibaba qwen3-rerank
1 parent 8c6be6d commit 25ecaa0

2 files changed

Lines changed: 155 additions & 1 deletion

File tree

hindsight-api-slim/hindsight_api/config.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,10 @@ def normalize_config_dict(config: dict[str, Any]) -> dict[str, Any]:
268268
ENV_RERANKER_SILICONFLOW_MODEL = "HINDSIGHT_API_RERANKER_SILICONFLOW_MODEL"
269269
ENV_RERANKER_SILICONFLOW_BASE_URL = "HINDSIGHT_API_RERANKER_SILICONFLOW_BASE_URL"
270270

271+
# Alibaba Cloud DashScope configuration (reranker only)
272+
ENV_RERANKER_ALIBABA_API_KEY = "HINDSIGHT_API_RERANKER_ALIBABA_API_KEY"
273+
ENV_RERANKER_ALIBABA_MODEL = "HINDSIGHT_API_RERANKER_ALIBABA_MODEL"
274+
271275
# Google Discovery Engine reranker configuration
272276
ENV_RERANKER_GOOGLE_MODEL = "HINDSIGHT_API_RERANKER_GOOGLE_MODEL"
273277
ENV_RERANKER_GOOGLE_PROJECT_ID = "HINDSIGHT_API_RERANKER_GOOGLE_PROJECT_ID"
@@ -535,6 +539,8 @@ def normalize_config_dict(config: dict[str, Any]) -> dict[str, Any]:
535539
DEFAULT_RERANKER_SILICONFLOW_MODEL = "BAAI/bge-reranker-v2-m3"
536540
DEFAULT_RERANKER_SILICONFLOW_BASE_URL = "https://api.siliconflow.cn/v1"
537541

542+
DEFAULT_RERANKER_ALIBABA_MODEL = "qwen3-rerank"
543+
538544
DEFAULT_RERANKER_GOOGLE_MODEL = "semantic-ranker-default-004"
539545

540546
# Vector extension (pgvector, vchord, pgvectorscale, or AlloyDB ScaNN)
@@ -973,6 +979,8 @@ class HindsightConfig:
973979
reranker_siliconflow_api_key: str | None
974980
reranker_siliconflow_model: str
975981
reranker_siliconflow_base_url: str
982+
reranker_alibaba_api_key: str | None
983+
reranker_alibaba_model: str
976984
reranker_google_model: str
977985
reranker_google_project_id: str | None
978986
reranker_google_service_account_key: str | None
@@ -1600,6 +1608,9 @@ def from_env(cls) -> "HindsightConfig":
16001608
reranker_siliconflow_base_url=os.getenv(
16011609
ENV_RERANKER_SILICONFLOW_BASE_URL, DEFAULT_RERANKER_SILICONFLOW_BASE_URL
16021610
),
1611+
# Alibaba Cloud DashScope reranker
1612+
reranker_alibaba_api_key=os.getenv(ENV_RERANKER_ALIBABA_API_KEY),
1613+
reranker_alibaba_model=os.getenv(ENV_RERANKER_ALIBABA_MODEL, DEFAULT_RERANKER_ALIBABA_MODEL),
16031614
# Google Discovery Engine reranker (with fallback to LLM Vertex AI keys)
16041615
reranker_google_model=os.getenv(ENV_RERANKER_GOOGLE_MODEL, DEFAULT_RERANKER_GOOGLE_MODEL),
16051616
reranker_google_project_id=os.getenv(ENV_RERANKER_GOOGLE_PROJECT_ID)

hindsight-api-slim/hindsight_api/engine/cross_encoder.py

Lines changed: 144 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from ..config import (
1919
DEFAULT_LITELLM_API_BASE,
20+
DEFAULT_RERANKER_ALIBABA_MODEL,
2021
DEFAULT_RERANKER_COHERE_MODEL,
2122
DEFAULT_RERANKER_FLASHRANK_CACHE_DIR,
2223
DEFAULT_RERANKER_FLASHRANK_CPU_MEM_ARENA,
@@ -37,6 +38,7 @@
3738
DEFAULT_RERANKER_TEI_HTTP_TIMEOUT,
3839
DEFAULT_RERANKER_TEI_MAX_CONCURRENT,
3940
DEFAULT_RERANKER_ZEROENTROPY_MODEL,
41+
ENV_RERANKER_ALIBABA_API_KEY,
4042
ENV_RERANKER_COHERE_API_KEY,
4143
ENV_RERANKER_COHERE_MODEL,
4244
ENV_RERANKER_FLASHRANK_CACHE_DIR,
@@ -1534,6 +1536,137 @@ async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
15341536
return await loop.run_in_executor(None, self._predict_sync, pairs)
15351537

15361538

1539+
class AlibabaCloudCrossEncoder(CrossEncoderModel):
1540+
"""
1541+
Alibaba Cloud DashScope text reranking API.
1542+
1543+
Authentication via DashScope API key (DASHSCOPE_API_KEY or HINDSIGHT_API_RERANKER_ALIBABA_API_KEY).
1544+
See: https://help.aliyun.com/zh/model-studio/text-rerank-api
1545+
"""
1546+
1547+
COMPATIBLE_API_URL = "https://dashscope.aliyuncs.com/compatible-api/v1/reranks"
1548+
NATIVE_API_URL = "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank"
1549+
1550+
# Models using the Cohere-compatible API format
1551+
_COMPATIBLE_MODELS: frozenset[str] = frozenset({"qwen3-rerank"})
1552+
1553+
def __init__(
1554+
self,
1555+
api_key: str,
1556+
model: str = DEFAULT_RERANKER_ALIBABA_MODEL,
1557+
timeout: float = 60.0,
1558+
):
1559+
"""
1560+
Initialize Alibaba Cloud DashScope reranker client.
1561+
1562+
Args:
1563+
api_key: DashScope API key
1564+
model: Model name. Supported: qwen3-rerank (default), gte-rerank-v2
1565+
timeout: Request timeout in seconds (default: 60.0)
1566+
"""
1567+
self.api_key = api_key
1568+
self.model = model
1569+
self.timeout = timeout
1570+
self._async_client: httpx.AsyncClient | None = None
1571+
1572+
@property
1573+
def provider_name(self) -> str:
1574+
return "alibaba"
1575+
1576+
async def initialize(self) -> None:
1577+
"""Initialize the async HTTP client."""
1578+
if self._async_client is not None:
1579+
return
1580+
logger.info(f"Reranker: initializing Alibaba Cloud provider with model {self.model}")
1581+
self._async_client = httpx.AsyncClient(
1582+
timeout=self.timeout,
1583+
headers={
1584+
"Authorization": f"Bearer {self.api_key}",
1585+
"Content-Type": "application/json",
1586+
},
1587+
)
1588+
logger.info("Reranker: Alibaba Cloud provider initialized")
1589+
1590+
async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
1591+
"""
1592+
Score query-document pairs using the DashScope reranking API.
1593+
1594+
Args:
1595+
pairs: List of (query, document) tuples to score
1596+
1597+
Returns:
1598+
List of relevance scores (0.0–1.0, higher = more relevant)
1599+
"""
1600+
if self._async_client is None:
1601+
raise RuntimeError("Reranker not initialized. Call initialize() first.")
1602+
1603+
if not pairs:
1604+
return []
1605+
1606+
query_groups: dict[str, list[tuple[int, str]]] = {}
1607+
for idx, (query, text) in enumerate(pairs):
1608+
query_groups.setdefault(query, []).append((idx, text))
1609+
1610+
all_scores = [0.0] * len(pairs)
1611+
1612+
for query, indexed_texts in query_groups.items():
1613+
texts = [text for _, text in indexed_texts]
1614+
indices = [idx for idx, _ in indexed_texts]
1615+
1616+
if self.model in self._COMPATIBLE_MODELS:
1617+
local_scores = await self._rerank_compatible(query, texts)
1618+
else:
1619+
local_scores = await self._rerank_native(query, texts)
1620+
1621+
for local_idx, score in enumerate(local_scores):
1622+
all_scores[indices[local_idx]] = score
1623+
1624+
return all_scores
1625+
1626+
async def _rerank_compatible(self, query: str, texts: list[str]) -> list[float]:
1627+
"""Call the Cohere-compatible endpoint (qwen3-rerank)."""
1628+
response = await self._async_client.post(
1629+
self.COMPATIBLE_API_URL,
1630+
json={
1631+
"model": self.model,
1632+
"query": query,
1633+
"documents": texts,
1634+
"top_n": len(texts),
1635+
},
1636+
)
1637+
response.raise_for_status()
1638+
result = response.json()
1639+
1640+
scores = [0.0] * len(texts)
1641+
for item in result.get("results", []):
1642+
scores[item["index"]] = item["relevance_score"]
1643+
return scores
1644+
1645+
async def _rerank_native(self, query: str, texts: list[str]) -> list[float]:
1646+
"""Call the DashScope native endpoint (gte-rerank-v2)."""
1647+
response = await self._async_client.post(
1648+
self.NATIVE_API_URL,
1649+
json={
1650+
"model": self.model,
1651+
"input": {
1652+
"query": query,
1653+
"documents": texts,
1654+
},
1655+
"parameters": {
1656+
"top_n": len(texts),
1657+
"return_documents": False,
1658+
},
1659+
},
1660+
)
1661+
response.raise_for_status()
1662+
result = response.json()
1663+
1664+
scores = [0.0] * len(texts)
1665+
for item in result.get("output", {}).get("results", []):
1666+
scores[item["index"]] = item["relevance_score"]
1667+
return scores
1668+
1669+
15371670
def create_cross_encoder_from_env() -> CrossEncoderModel:
15381671
"""
15391672
Create a CrossEncoderModel instance based on configuration.
@@ -1648,11 +1781,21 @@ def create_cross_encoder_from_env() -> CrossEncoderModel:
16481781
model=config.reranker_google_model,
16491782
service_account_key=config.reranker_google_service_account_key,
16501783
)
1784+
elif provider == "alibaba":
1785+
api_key = config.reranker_alibaba_api_key
1786+
if not api_key:
1787+
raise ValueError(
1788+
f"{ENV_RERANKER_ALIBABA_API_KEY} is required when {ENV_RERANKER_PROVIDER} is 'alibaba'"
1789+
)
1790+
return AlibabaCloudCrossEncoder(
1791+
api_key=api_key,
1792+
model=config.reranker_alibaba_model,
1793+
)
16511794
elif provider == "rrf":
16521795
return RRFPassthroughCrossEncoder()
16531796
elif provider == "jina-mlx":
16541797
return JinaMLXCrossEncoder()
16551798
else:
16561799
raise ValueError(
1657-
f"Unknown reranker provider: {provider}. Supported: 'local', 'tei', 'cohere', 'zeroentropy', 'siliconflow', 'google', 'flashrank', 'litellm', 'litellm-sdk', 'rrf', 'jina-mlx'"
1800+
f"Unknown reranker provider: {provider}. Supported: 'local', 'tei', 'cohere', 'zeroentropy', 'siliconflow', 'alibaba', 'google', 'flashrank', 'litellm', 'litellm-sdk', 'rrf', 'jina-mlx'"
16581801
)

0 commit comments

Comments
 (0)