Skip to content

Commit e408b7e

Browse files
authored
feat: support litellm-sdk as reranker and embeddings (#357)
* feat: support litellm-sdk for reranker endpoint * feat: support litellm-sdk for reranker endpoint * fix: make litellm SDK cohere test fixture async function-scoped * fix: store litellm module reference during initialization to avoid import issues * feat: add LiteLLM SDK embeddings support - Add LiteLLMSDKEmbeddings class for direct API access without proxy - Support multiple providers: Cohere, OpenAI, Together AI, HuggingFace, Voyage AI - Automatic dimension detection via test embedding - Provider-specific API key mapping - Batch processing support (configurable batch size) - Comprehensive test coverage (17 unit tests) - Update documentation with configuration examples Implements embeddings in same PR as reranker per user request * fix: correct config mocking in embeddings factory tests - Mock get_config() from its source module (hindsight_api.config) - Fixes factory tests that were returning LocalSTEmbeddings instead of LiteLLMSDKEmbeddings - All 17 unit tests now passing * fix: skip Cohere integration tests when API key is invalid - Catch initialization errors and skip tests instead of failing - Prevents CI failures when COHERE_API_KEY is set but invalid - Integration tests now properly skip when authentication fails * fix: skip Cohere reranker integration tests when API key is invalid - Add same error handling as embeddings tests - Prevents CI failures when COHERE_API_KEY is set but invalid - Tests now properly skip when authentication fails * Revert "fix: skip Cohere reranker integration tests when API key is invalid" This reverts commit 655daca. * Revert "fix: skip Cohere integration tests when API key is invalid" This reverts commit 5d00548. * fix: pass API key directly to litellm SDK functions - Add api_key parameter to arerank(), rerank(), aembedding(), and embedding() calls - Prevents authentication issues in multi-process environments (pytest-xdist) - More reliable than relying solely on environment variables - Update test assertions to expect api_key parameter * feat: pass api_base parameter to litellm SDK calls and remove hasattr check * fix: raise errors instead of silently returning 0.0 scores * refactor: pass API keys directly in kwargs instead of setting env vars
1 parent d871c30 commit e408b7e

9 files changed

Lines changed: 1225 additions & 15 deletions

File tree

hindsight-api/hindsight_api/config.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,14 @@ def normalize_config_dict(config: dict[str, Any]) -> dict[str, Any]:
189189
ENV_RERANKER_LITELLM_API_KEY = "HINDSIGHT_API_RERANKER_LITELLM_API_KEY"
190190
ENV_RERANKER_LITELLM_MODEL = "HINDSIGHT_API_RERANKER_LITELLM_MODEL"
191191

192+
# LiteLLM SDK configuration (direct API access, no proxy needed)
193+
ENV_EMBEDDINGS_LITELLM_SDK_API_KEY = "HINDSIGHT_API_EMBEDDINGS_LITELLM_SDK_API_KEY"
194+
ENV_EMBEDDINGS_LITELLM_SDK_MODEL = "HINDSIGHT_API_EMBEDDINGS_LITELLM_SDK_MODEL"
195+
ENV_EMBEDDINGS_LITELLM_SDK_API_BASE = "HINDSIGHT_API_EMBEDDINGS_LITELLM_SDK_API_BASE"
196+
ENV_RERANKER_LITELLM_SDK_API_KEY = "HINDSIGHT_API_RERANKER_LITELLM_SDK_API_KEY"
197+
ENV_RERANKER_LITELLM_SDK_MODEL = "HINDSIGHT_API_RERANKER_LITELLM_SDK_MODEL"
198+
ENV_RERANKER_LITELLM_SDK_API_BASE = "HINDSIGHT_API_RERANKER_LITELLM_SDK_API_BASE"
199+
192200
# Deprecated: Legacy shared LiteLLM config (for backward compatibility)
193201
ENV_LITELLM_API_BASE = "HINDSIGHT_API_LITELLM_API_BASE"
194202
ENV_LITELLM_API_KEY = "HINDSIGHT_API_LITELLM_API_KEY"
@@ -337,6 +345,10 @@ def normalize_config_dict(config: dict[str, Any]) -> dict[str, Any]:
337345
DEFAULT_EMBEDDINGS_LITELLM_MODEL = "text-embedding-3-small"
338346
DEFAULT_RERANKER_LITELLM_MODEL = "cohere/rerank-english-v3.0"
339347

348+
# LiteLLM SDK defaults
349+
DEFAULT_EMBEDDINGS_LITELLM_SDK_MODEL = "cohere/embed-english-v3.0"
350+
DEFAULT_RERANKER_LITELLM_SDK_MODEL = "cohere/rerank-english-v3.0"
351+
340352
DEFAULT_HOST = "0.0.0.0"
341353
DEFAULT_PORT = 8888
342354
DEFAULT_BASE_PATH = "" # Empty string = root path
@@ -532,6 +544,9 @@ class HindsightConfig:
532544
embeddings_litellm_api_base: str
533545
embeddings_litellm_api_key: str | None
534546
embeddings_litellm_model: str
547+
embeddings_litellm_sdk_api_key: str | None
548+
embeddings_litellm_sdk_model: str
549+
embeddings_litellm_sdk_api_base: str | None
535550

536551
# Reranker
537552
reranker_provider: str
@@ -549,6 +564,9 @@ class HindsightConfig:
549564
reranker_litellm_api_base: str
550565
reranker_litellm_api_key: str | None
551566
reranker_litellm_model: str
567+
reranker_litellm_sdk_api_key: str | None
568+
reranker_litellm_sdk_model: str
569+
reranker_litellm_sdk_api_base: str | None
552570

553571
# Server
554572
host: str
@@ -847,6 +865,12 @@ def from_env(cls) -> "HindsightConfig":
847865
or os.getenv(ENV_LITELLM_API_BASE, DEFAULT_LITELLM_API_BASE),
848866
embeddings_litellm_api_key=os.getenv(ENV_EMBEDDINGS_LITELLM_API_KEY) or os.getenv(ENV_LITELLM_API_KEY),
849867
embeddings_litellm_model=os.getenv(ENV_EMBEDDINGS_LITELLM_MODEL, DEFAULT_EMBEDDINGS_LITELLM_MODEL),
868+
# LiteLLM SDK embeddings (direct API access)
869+
embeddings_litellm_sdk_api_key=os.getenv(ENV_EMBEDDINGS_LITELLM_SDK_API_KEY),
870+
embeddings_litellm_sdk_model=os.getenv(
871+
ENV_EMBEDDINGS_LITELLM_SDK_MODEL, DEFAULT_EMBEDDINGS_LITELLM_SDK_MODEL
872+
),
873+
embeddings_litellm_sdk_api_base=os.getenv(ENV_EMBEDDINGS_LITELLM_SDK_API_BASE) or None,
850874
# Reranker
851875
reranker_provider=os.getenv(ENV_RERANKER_PROVIDER, DEFAULT_RERANKER_PROVIDER),
852876
reranker_local_model=os.getenv(ENV_RERANKER_LOCAL_MODEL, DEFAULT_RERANKER_LOCAL_MODEL),
@@ -876,6 +900,10 @@ def from_env(cls) -> "HindsightConfig":
876900
or os.getenv(ENV_LITELLM_API_BASE, DEFAULT_LITELLM_API_BASE),
877901
reranker_litellm_api_key=os.getenv(ENV_RERANKER_LITELLM_API_KEY) or os.getenv(ENV_LITELLM_API_KEY),
878902
reranker_litellm_model=os.getenv(ENV_RERANKER_LITELLM_MODEL, DEFAULT_RERANKER_LITELLM_MODEL),
903+
# LiteLLM SDK reranker (direct API access)
904+
reranker_litellm_sdk_api_key=os.getenv(ENV_RERANKER_LITELLM_SDK_API_KEY),
905+
reranker_litellm_sdk_model=os.getenv(ENV_RERANKER_LITELLM_SDK_MODEL, DEFAULT_RERANKER_LITELLM_SDK_MODEL),
906+
reranker_litellm_sdk_api_base=os.getenv(ENV_RERANKER_LITELLM_SDK_API_BASE) or None,
879907
# Server
880908
host=os.getenv(ENV_HOST, DEFAULT_HOST),
881909
port=int(os.getenv(ENV_PORT, DEFAULT_PORT)),

hindsight-api/hindsight_api/engine/cross_encoder.py

Lines changed: 134 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
DEFAULT_RERANKER_FLASHRANK_CACHE_DIR,
2222
DEFAULT_RERANKER_FLASHRANK_MODEL,
2323
DEFAULT_RERANKER_LITELLM_MODEL,
24+
DEFAULT_RERANKER_LITELLM_SDK_MODEL,
2425
DEFAULT_RERANKER_LOCAL_FORCE_CPU,
2526
DEFAULT_RERANKER_LOCAL_MAX_CONCURRENT,
2627
DEFAULT_RERANKER_LOCAL_MODEL,
@@ -32,6 +33,7 @@
3233
ENV_RERANKER_COHERE_MODEL,
3334
ENV_RERANKER_FLASHRANK_CACHE_DIR,
3435
ENV_RERANKER_FLASHRANK_MODEL,
36+
ENV_RERANKER_LITELLM_SDK_API_KEY,
3537
ENV_RERANKER_LOCAL_FORCE_CPU,
3638
ENV_RERANKER_LOCAL_MAX_CONCURRENT,
3739
ENV_RERANKER_LOCAL_MODEL,
@@ -828,6 +830,126 @@ async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
828830
return all_scores
829831

830832

833+
class LiteLLMSDKCrossEncoder(CrossEncoderModel):
834+
"""
835+
LiteLLM SDK cross-encoder for direct API integration.
836+
837+
Supports reranking via LiteLLM SDK without requiring a proxy server.
838+
Supported providers: Cohere, DeepInfra, Together AI, HuggingFace, Jina AI, Voyage AI, AWS Bedrock.
839+
840+
Example model names:
841+
- cohere/rerank-english-v3.0
842+
- deepinfra/Qwen3-reranker-8B
843+
- together_ai/Salesforce/Llama-Rank-V1
844+
- huggingface/BAAI/bge-reranker-v2-m3
845+
"""
846+
847+
def __init__(
848+
self,
849+
api_key: str,
850+
model: str = DEFAULT_RERANKER_LITELLM_SDK_MODEL,
851+
api_base: str | None = None,
852+
timeout: float = 60.0,
853+
):
854+
"""
855+
Initialize LiteLLM SDK cross-encoder client.
856+
857+
Args:
858+
api_key: API key for the reranking provider
859+
model: Model name with provider prefix (e.g., "deepinfra/Qwen3-reranker-8B")
860+
api_base: Custom base URL for API (optional)
861+
timeout: Request timeout in seconds (default: 60.0)
862+
"""
863+
self.api_key = api_key
864+
self.model = model
865+
self.api_base = api_base
866+
self.timeout = timeout
867+
self._initialized = False
868+
self._litellm = None # Will be set during initialization
869+
870+
@property
871+
def provider_name(self) -> str:
872+
return "litellm-sdk"
873+
874+
async def initialize(self) -> None:
875+
"""Initialize the LiteLLM SDK client."""
876+
if self._initialized:
877+
return
878+
879+
try:
880+
import litellm
881+
882+
self._litellm = litellm # Store reference
883+
except ImportError:
884+
raise ImportError("litellm is required for LiteLLMSDKCrossEncoder. Install it with: pip install litellm")
885+
886+
api_base_msg = f" at {self.api_base}" if self.api_base else ""
887+
logger.info(f"Reranker: initializing LiteLLM SDK provider with model {self.model}{api_base_msg}")
888+
889+
self._initialized = True
890+
logger.info("Reranker: LiteLLM SDK provider initialized")
891+
892+
async def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
893+
"""
894+
Score query-document pairs using the LiteLLM SDK.
895+
896+
Args:
897+
pairs: List of (query, document) tuples to score
898+
899+
Returns:
900+
List of relevance scores
901+
"""
902+
if not self._initialized:
903+
raise RuntimeError("Reranker not initialized. Call initialize() first.")
904+
905+
if not pairs:
906+
return []
907+
908+
# Group pairs by query for efficient batching
909+
# LiteLLM rerank expects one query with multiple documents
910+
query_groups: dict[str, list[tuple[int, str]]] = {}
911+
for idx, (query, text) in enumerate(pairs):
912+
if query not in query_groups:
913+
query_groups[query] = []
914+
query_groups[query].append((idx, text))
915+
916+
all_scores = [0.0] * len(pairs)
917+
918+
for query, indexed_texts in query_groups.items():
919+
texts = [text for _, text in indexed_texts]
920+
indices = [idx for idx, _ in indexed_texts]
921+
922+
# Build kwargs for rerank call
923+
rerank_kwargs = {
924+
"model": self.model,
925+
"query": query,
926+
"documents": texts,
927+
"api_key": self.api_key,
928+
}
929+
if self.api_base:
930+
rerank_kwargs["api_base"] = self.api_base
931+
932+
response = await self._litellm.arerank(**rerank_kwargs)
933+
934+
# Map scores back to original positions
935+
# Response format: RerankResponse with results list
936+
# Each result is a TypedDict with "index" and "relevance_score"
937+
if hasattr(response, "results") and response.results:
938+
for result in response.results:
939+
# Results are TypedDicts, use dict-style access
940+
original_idx = result["index"]
941+
score = result.get("relevance_score", result.get("score", 0.0))
942+
all_scores[indices[original_idx]] = score
943+
elif isinstance(response, list):
944+
# Direct list of scores (unlikely but defensive)
945+
for i, score in enumerate(response):
946+
all_scores[indices[i]] = score
947+
else:
948+
logger.warning(f"Unexpected response format from LiteLLM rerank: {type(response)}")
949+
950+
return all_scores
951+
952+
831953
def create_cross_encoder_from_env() -> CrossEncoderModel:
832954
"""
833955
Create a CrossEncoderModel instance based on configuration.
@@ -877,9 +999,20 @@ def create_cross_encoder_from_env() -> CrossEncoderModel:
877999
api_key=config.reranker_litellm_api_key,
8781000
model=config.reranker_litellm_model,
8791001
)
1002+
elif provider == "litellm-sdk":
1003+
api_key = config.reranker_litellm_sdk_api_key
1004+
if not api_key:
1005+
raise ValueError(
1006+
f"{ENV_RERANKER_LITELLM_SDK_API_KEY} is required when {ENV_RERANKER_PROVIDER} is 'litellm-sdk'"
1007+
)
1008+
return LiteLLMSDKCrossEncoder(
1009+
api_key=api_key,
1010+
model=config.reranker_litellm_sdk_model,
1011+
api_base=config.reranker_litellm_sdk_api_base,
1012+
)
8801013
elif provider == "rrf":
8811014
return RRFPassthroughCrossEncoder()
8821015
else:
8831016
raise ValueError(
884-
f"Unknown reranker provider: {provider}. Supported: 'local', 'tei', 'cohere', 'flashrank', 'litellm', 'rrf'"
1017+
f"Unknown reranker provider: {provider}. Supported: 'local', 'tei', 'cohere', 'flashrank', 'litellm', 'litellm-sdk', 'rrf'"
8851018
)

0 commit comments

Comments
 (0)