Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions hindsight-api-slim/hindsight_api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,10 @@ def normalize_config_dict(config: dict[str, Any]) -> dict[str, Any]:
ENV_RERANKER_SILICONFLOW_MODEL = "HINDSIGHT_API_RERANKER_SILICONFLOW_MODEL"
ENV_RERANKER_SILICONFLOW_BASE_URL = "HINDSIGHT_API_RERANKER_SILICONFLOW_BASE_URL"

# Alibaba Cloud DashScope configuration (reranker only)
# API key for the DashScope reranker; read from the environment with no default.
ENV_RERANKER_ALIBABA_API_KEY = "HINDSIGHT_API_RERANKER_ALIBABA_API_KEY"
# Rerank model name; DEFAULT_RERANKER_ALIBABA_MODEL is used when this is unset.
ENV_RERANKER_ALIBABA_MODEL = "HINDSIGHT_API_RERANKER_ALIBABA_MODEL"

# Google Discovery Engine reranker configuration
ENV_RERANKER_GOOGLE_MODEL = "HINDSIGHT_API_RERANKER_GOOGLE_MODEL"
ENV_RERANKER_GOOGLE_PROJECT_ID = "HINDSIGHT_API_RERANKER_GOOGLE_PROJECT_ID"
Expand Down Expand Up @@ -535,6 +539,8 @@ def normalize_config_dict(config: dict[str, Any]) -> dict[str, Any]:
DEFAULT_RERANKER_SILICONFLOW_MODEL = "BAAI/bge-reranker-v2-m3"
DEFAULT_RERANKER_SILICONFLOW_BASE_URL = "https://api.siliconflow.cn/v1"

# Alibaba Cloud DashScope rerank model used when HINDSIGHT_API_RERANKER_ALIBABA_MODEL is unset
DEFAULT_RERANKER_ALIBABA_MODEL = "qwen3-rerank"

DEFAULT_RERANKER_GOOGLE_MODEL = "semantic-ranker-default-004"

# Vector extension (pgvector, vchord, pgvectorscale, or AlloyDB ScaNN)
Expand Down Expand Up @@ -973,6 +979,8 @@ class HindsightConfig:
reranker_siliconflow_api_key: str | None
reranker_siliconflow_model: str
reranker_siliconflow_base_url: str
reranker_alibaba_api_key: str | None
reranker_alibaba_model: str
reranker_google_model: str
reranker_google_project_id: str | None
reranker_google_service_account_key: str | None
Expand Down Expand Up @@ -1600,6 +1608,9 @@ def from_env(cls) -> "HindsightConfig":
reranker_siliconflow_base_url=os.getenv(
ENV_RERANKER_SILICONFLOW_BASE_URL, DEFAULT_RERANKER_SILICONFLOW_BASE_URL
),
# Alibaba Cloud DashScope reranker
reranker_alibaba_api_key=os.getenv(ENV_RERANKER_ALIBABA_API_KEY),
reranker_alibaba_model=os.getenv(ENV_RERANKER_ALIBABA_MODEL, DEFAULT_RERANKER_ALIBABA_MODEL),
# Google Discovery Engine reranker (with fallback to LLM Vertex AI keys)
reranker_google_model=os.getenv(ENV_RERANKER_GOOGLE_MODEL, DEFAULT_RERANKER_GOOGLE_MODEL),
reranker_google_project_id=os.getenv(ENV_RERANKER_GOOGLE_PROJECT_ID)
Expand Down
145 changes: 144 additions & 1 deletion hindsight-api-slim/hindsight_api/engine/cross_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from ..config import (
DEFAULT_LITELLM_API_BASE,
DEFAULT_RERANKER_ALIBABA_MODEL,
DEFAULT_RERANKER_COHERE_MODEL,
DEFAULT_RERANKER_FLASHRANK_CACHE_DIR,
DEFAULT_RERANKER_FLASHRANK_CPU_MEM_ARENA,
Expand All @@ -37,6 +38,7 @@
DEFAULT_RERANKER_TEI_HTTP_TIMEOUT,
DEFAULT_RERANKER_TEI_MAX_CONCURRENT,
DEFAULT_RERANKER_ZEROENTROPY_MODEL,
ENV_RERANKER_ALIBABA_API_KEY,
ENV_RERANKER_COHERE_API_KEY,
ENV_RERANKER_COHERE_MODEL,
ENV_RERANKER_FLASHRANK_CACHE_DIR,
Expand Down Expand Up @@ -1534,6 +1536,137 @@ async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
return await loop.run_in_executor(None, self._predict_sync, pairs)


class AlibabaCloudCrossEncoder(CrossEncoderModel):
    """
    Reranker backed by the Alibaba Cloud DashScope text reranking HTTP API.

    Every request is authenticated with a DashScope API key sent as a Bearer
    token. Models named in ``_COMPATIBLE_MODELS`` are posted to
    ``COMPATIBLE_API_URL``; any other model name is posted to ``NATIVE_API_URL``.

    See: https://help.aliyun.com/zh/model-studio/text-rerank-api
    """

    COMPATIBLE_API_URL = "https://dashscope.aliyuncs.com/compatible-api/v1/reranks"
    NATIVE_API_URL = "https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank"

    # Model names routed to the compatible endpoint; all others use the native one.
    _COMPATIBLE_MODELS: frozenset[str] = frozenset({"qwen3-rerank"})

    def __init__(
        self,
        api_key: str,
        model: str = DEFAULT_RERANKER_ALIBABA_MODEL,
        timeout: float = 60.0,
    ):
        """
        Store the connection settings; no network activity happens here.

        Args:
            api_key: DashScope API key, sent as a Bearer token on each request
            model: Rerank model name (e.g. qwen3-rerank, gte-rerank-v2)
            timeout: Per-request timeout in seconds handed to the HTTP client
        """
        # The HTTP client is created lazily by initialize().
        self._async_client: httpx.AsyncClient | None = None
        self.timeout = timeout
        self.model = model
        self.api_key = api_key

    @property
    def provider_name(self) -> str:
        return "alibaba"

    async def initialize(self) -> None:
        """Create the shared async HTTP client; later calls are no-ops."""
        if self._async_client is not None:
            return

        logger.info(f"Reranker: initializing Alibaba Cloud provider with model {self.model}")
        default_headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        self._async_client = httpx.AsyncClient(timeout=self.timeout, headers=default_headers)
        logger.info("Reranker: Alibaba Cloud provider initialized")

    async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
        """
        Score (query, document) pairs with the DashScope reranking API.

        Documents sharing the same query are sent together in one request, and
        requests are issued one after another in order of each query's first
        appearance in ``pairs``.

        Args:
            pairs: List of (query, document) tuples to score

        Returns:
            One relevance score per input pair, in input order. Any document
            the API response does not mention keeps the placeholder 0.0.

        Raises:
            RuntimeError: If initialize() has not been called yet.
        """
        if self._async_client is None:
            raise RuntimeError("Reranker not initialized. Call initialize() first.")

        if not pairs:
            return []

        # Remember, per distinct query, which positions of ``pairs`` belong to it.
        positions_by_query: dict[str, list[int]] = {}
        for position, (query, _) in enumerate(pairs):
            positions_by_query.setdefault(query, []).append(position)

        scores = [0.0] * len(pairs)
        for query, positions in positions_by_query.items():
            documents = [pairs[position][1] for position in positions]

            if self.model in self._COMPATIBLE_MODELS:
                group_scores = await self._rerank_compatible(query, documents)
            else:
                group_scores = await self._rerank_native(query, documents)

            # Scatter the per-request scores back to their original positions.
            for position, score in zip(positions, group_scores):
                scores[position] = score

        return scores

    @staticmethod
    def _scores_by_index(items: list[dict], size: int) -> list[float]:
        """Place each item's ``relevance_score`` at its ``index`` in a list of ``size`` zeros."""
        ordered = [0.0] * size
        for entry in items:
            ordered[entry["index"]] = entry["relevance_score"]
        return ordered

    async def _rerank_compatible(self, query: str, texts: list[str]) -> list[float]:
        """Rerank ``texts`` against ``query`` via the compatible endpoint (qwen3-rerank)."""
        payload = {
            "model": self.model,
            "query": query,
            "documents": texts,
            # Ask for every document back so each one receives a score.
            "top_n": len(texts),
        }
        response = await self._async_client.post(self.COMPATIBLE_API_URL, json=payload)
        response.raise_for_status()
        body = response.json()

        # This endpoint returns the result list at the top level of the body.
        return self._scores_by_index(body.get("results", []), len(texts))

    async def _rerank_native(self, query: str, texts: list[str]) -> list[float]:
        """Rerank ``texts`` against ``query`` via the DashScope native endpoint (gte-rerank-v2)."""
        payload = {
            "model": self.model,
            "input": {
                "query": query,
                "documents": texts,
            },
            "parameters": {
                # Ask for every document back so each one receives a score.
                "top_n": len(texts),
                "return_documents": False,
            },
        }
        response = await self._async_client.post(self.NATIVE_API_URL, json=payload)
        response.raise_for_status()
        body = response.json()

        # This endpoint nests the result list under the "output" key.
        return self._scores_by_index(body.get("output", {}).get("results", []), len(texts))


def create_cross_encoder_from_env() -> CrossEncoderModel:
"""
Create a CrossEncoderModel instance based on configuration.
Expand Down Expand Up @@ -1648,11 +1781,21 @@ def create_cross_encoder_from_env() -> CrossEncoderModel:
model=config.reranker_google_model,
service_account_key=config.reranker_google_service_account_key,
)
elif provider == "alibaba":
api_key = config.reranker_alibaba_api_key
if not api_key:
raise ValueError(
f"{ENV_RERANKER_ALIBABA_API_KEY} is required when {ENV_RERANKER_PROVIDER} is 'alibaba'"
)
return AlibabaCloudCrossEncoder(
api_key=api_key,
model=config.reranker_alibaba_model,
)
elif provider == "rrf":
return RRFPassthroughCrossEncoder()
elif provider == "jina-mlx":
return JinaMLXCrossEncoder()
else:
raise ValueError(
f"Unknown reranker provider: {provider}. Supported: 'local', 'tei', 'cohere', 'zeroentropy', 'siliconflow', 'google', 'flashrank', 'litellm', 'litellm-sdk', 'rrf', 'jina-mlx'"
f"Unknown reranker provider: {provider}. Supported: 'local', 'tei', 'cohere', 'zeroentropy', 'siliconflow', 'alibaba', 'google', 'flashrank', 'litellm', 'litellm-sdk', 'rrf', 'jina-mlx'"
)