Skip to content
12 changes: 12 additions & 0 deletions config/system.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,18 @@ llm_pool:
# Add, remove or override model specific parameters
temperature: null # Removes temperature from default
max_completion_tokens: 2048 # Overrides default
# Vertex AI example with per-model region:
# judge_vertex_gemini:
# provider: vertex_ai
# model: gemini-2.0-flash
# parameters:
# vertex_location: us-central1 # Region for this model
# # vertex_project: my-gcp-project # Optional: override GCP project
# judge_vertex_llama:
# provider: vertex_ai
# model: meta/llama-3.3-70b-instruct-maas
# parameters:
# vertex_location: europe-west1 # Different region for this model

# Judge Panel: multiple judges from the pool
# Combine their scores. First judge in judges is the fallback when the full panel is not used for a metric.
Expand Down
120 changes: 101 additions & 19 deletions src/lightspeed_evaluation/core/llm/litellm_patch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""LiteLLM configuration for token tracking and Ragas 0.4 compatibility.
"""LiteLLM configuration for token tracking, Ragas 0.4 compatibility, and Vertex AI support.

This module configures litellm for two purposes:
This module configures litellm for three purposes:

1. TOKEN TRACKING: Wraps litellm.completion, litellm.acompletion, litellm.embedding,
and litellm.aembedding to track token usage for all LLM and embedding calls.
Expand All @@ -14,14 +14,22 @@

We replace the LoggingWorker with a no-op implementation to avoid this.
This is safe because we don't use litellm's built-in observability features.

3. VERTEX AI PER-MODEL REGION SUPPORT: litellm.drop_params=True (set by
DeepEval) silently strips vertex_project and vertex_location from
completion kwargs. The completion wrappers intercept these params and
temporarily set them as litellm module-level attributes, which litellm
checks as a fallback in its vertex_ai handler.
"""

import asyncio
import logging
import os
import threading
import warnings
from contextlib import asynccontextmanager, contextmanager
from functools import wraps
from typing import Any
from typing import Any, AsyncGenerator, Generator

import litellm

Expand Down Expand Up @@ -89,6 +97,90 @@ def clear_queue(self) -> None:
litellm.suppress_debug_info = True


# =============================================================================
# GLOBAL STATE LOCK
# =============================================================================
# Single lock for ALL litellm global state mutations (cache, ssl_verify,
# vertex_project, vertex_location). Import this lock in any module that
# reads/writes litellm global state to prevent race conditions between
# concurrent pipelines. Both sync and async code paths share this lock;
# async callers use asyncio.to_thread so the event loop is never blocked.
litellm_state_lock = threading.Lock()


# =============================================================================
# VERTEX AI PER-MODEL REGION SUPPORT
# =============================================================================
# litellm.drop_params=True (set by DeepEval) silently strips vertex_project
# and vertex_location from completion kwargs. We intercept these params and
# temporarily set them as litellm module-level attributes, which litellm
# checks as a fallback in its vertex_ai handler.


@contextmanager
def _vertex_override(kwargs: dict[str, Any]) -> Generator[None, None, None]:
"""Pop vertex_project/vertex_location from kwargs and set as litellm module attrs.

Always acquires litellm_state_lock to prevent concurrent reads of partially
updated globals, even when no vertex params are present in kwargs.
"""
with litellm_state_lock:
vp = kwargs.pop("vertex_project", None)
vl = kwargs.pop("vertex_location", None)
if vp is None and vl is None:
yield
return
old_vp = getattr(litellm, "vertex_project", None)
old_vl = getattr(litellm, "vertex_location", None)
try:
if vp is not None:
litellm.vertex_project = vp
if vl is not None:
litellm.vertex_location = vl
yield
finally:
litellm.vertex_project = old_vp
litellm.vertex_location = old_vl


@asynccontextmanager
async def _vertex_override_async(
kwargs: dict[str, Any],
) -> AsyncGenerator[None, None]:
"""Async version of _vertex_override using asyncio.to_thread for acquire.

Acquires litellm_state_lock before mutating globals and holds it across the
yield so no concurrent caller can see partially-updated state. Lock
acquire uses asyncio.to_thread to avoid blocking the event loop; release
is called directly since it is non-blocking.
Uses the same lock as the synchronous path to prevent races between sync
and async callers.
"""
await asyncio.to_thread(litellm_state_lock.acquire)
try:
vp = kwargs.pop("vertex_project", None)
vl = kwargs.pop("vertex_location", None)
if vp is None and vl is None:
litellm_state_lock.release()
yield
return
old_vp = getattr(litellm, "vertex_project", None)
old_vl = getattr(litellm, "vertex_location", None)
if vp is not None:
litellm.vertex_project = vp
if vl is not None:
litellm.vertex_location = vl
except BaseException:
litellm_state_lock.release()
raise
try:
yield
finally:
litellm.vertex_project = old_vp
litellm.vertex_location = old_vl
litellm_state_lock.release()


# =============================================================================
# TOKEN TRACKING: Wrap completion and embedding functions
# =============================================================================
Expand All @@ -101,11 +193,11 @@ def clear_queue(self) -> None:
_original_aembedding = litellm.aembedding


# Patch litellm's completion functions to include token tracking
@wraps(_original_completion)
def _completion_with_token_tracking(*args: Any, **kwargs: Any) -> Any:
"""Wrapper around litellm.completion that tracks tokens."""
response = _original_completion(*args, **kwargs)
"""Wrapper around litellm.completion that tracks tokens and handles Vertex params."""
with _vertex_override(kwargs):
response = _original_completion(*args, **kwargs)
try:
track_judge_tokens(response)
except Exception as e: # pylint: disable=broad-exception-caught
Expand All @@ -115,16 +207,16 @@ def _completion_with_token_tracking(*args: Any, **kwargs: Any) -> Any:

@wraps(_original_acompletion)
async def _acompletion_with_token_tracking(*args: Any, **kwargs: Any) -> Any:
"""Wrapper around litellm.acompletion that tracks tokens."""
response = await _original_acompletion(*args, **kwargs)
"""Wrapper around litellm.acompletion that tracks tokens and handles Vertex params."""
async with _vertex_override_async(kwargs):
response = await _original_acompletion(*args, **kwargs)
try:
track_judge_tokens(response)
except Exception as e: # pylint: disable=broad-exception-caught
logger.exception("Failed to track tokens for acompletion: %s", e)
return response


# Patch litellm's embedding functions to include token tracking
@wraps(_original_embedding)
def _embedding_with_token_tracking(*args: Any, **kwargs: Any) -> Any:
"""Wrapper around litellm.embedding that tracks tokens."""
Expand All @@ -147,22 +239,12 @@ async def _aembedding_with_token_tracking(*args: Any, **kwargs: Any) -> Any:
return response


# Patch litellm's completion and embedding functions to include token tracking
litellm.completion = _completion_with_token_tracking
litellm.acompletion = _acompletion_with_token_tracking
litellm.embedding = _embedding_with_token_tracking
litellm.aembedding = _aembedding_with_token_tracking


# =============================================================================
# GLOBAL STATE LOCK
# =============================================================================
# Single lock for ALL litellm global state mutations (cache, ssl_verify).
# Import this lock in any module that reads/writes litellm.cache or
# litellm.ssl_verify to prevent race conditions between concurrent pipelines.
litellm_state_lock = threading.Lock()


# =============================================================================
# SSL CONFIGURATION UTILITY
# =============================================================================
Expand Down
Loading
Loading