Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions config/system.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,18 @@ llm_pool:
# Add, remove or override model specific parameters
temperature: null # Removes temperature from default
max_completion_tokens: 2048 # Overrides default
# Vertex AI example with per-model region:
# judge_vertex_gemini:
# provider: vertex_ai
# model: gemini-2.0-flash
# parameters:
# vertex_location: us-central1 # Region for this model
# # vertex_project: my-gcp-project # Optional: override GCP project
# judge_vertex_llama:
# provider: vertex_ai
# model: meta/llama-3.3-70b-instruct-maas
# parameters:
# vertex_location: europe-west1 # Different region for this model

# Judge Panel: multiple judges from the pool
# Combine their scores. First judge in judges is the fallback when the full panel is not used for a metric.
Expand Down
114 changes: 95 additions & 19 deletions src/lightspeed_evaluation/core/llm/litellm_patch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""LiteLLM configuration for token tracking and Ragas 0.4 compatibility.
"""LiteLLM configuration for token tracking, Ragas 0.4 compatibility, and Vertex AI support.

This module configures litellm for two purposes:
This module configures litellm for three purposes:

1. TOKEN TRACKING: Wraps litellm.completion, litellm.acompletion, litellm.embedding,
and litellm.aembedding to track token usage for all LLM and embedding calls.
Expand All @@ -14,14 +14,22 @@

We replace the LoggingWorker with a no-op implementation to avoid this.
This is safe because we don't use litellm's built-in observability features.

3. VERTEX AI PER-MODEL REGION SUPPORT: litellm.drop_params=True (set by
DeepEval) silently strips vertex_project and vertex_location from
completion kwargs. The completion wrappers intercept these params and
temporarily set them as litellm module-level attributes, which litellm
checks as a fallback in its vertex_ai handler.
"""

import asyncio
import logging
import os
import threading
import warnings
from contextlib import asynccontextmanager, contextmanager
from functools import wraps
from typing import Any
from typing import Any, AsyncGenerator, Generator

import litellm

Expand Down Expand Up @@ -89,6 +97,84 @@ def clear_queue(self) -> None:
litellm.suppress_debug_info = True


# =============================================================================
# GLOBAL STATE LOCKS
# =============================================================================
# Locks for ALL litellm global state mutations (cache, ssl_verify,
# vertex_project, vertex_location). Import the appropriate lock in any
# module that reads/writes litellm global state to prevent race conditions
# between concurrent pipelines.
# - litellm_state_lock: for synchronous code paths (threading.Lock)
# - litellm_state_async_lock: for asynchronous code paths (asyncio.Lock)
litellm_state_lock = threading.Lock()
litellm_state_async_lock = asyncio.Lock()
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated


# =============================================================================
# VERTEX AI PER-MODEL REGION SUPPORT
# =============================================================================
# litellm.drop_params=True (set by DeepEval) silently strips vertex_project
# and vertex_location from completion kwargs. We intercept these params and
# temporarily set them as litellm module-level attributes, which litellm
# checks as a fallback in its vertex_ai handler.


@contextmanager
def _vertex_override(kwargs: dict[str, Any]) -> Generator[None, None, None]:
"""Pop vertex_project/vertex_location from kwargs and set as litellm module attrs.

Always acquires litellm_state_lock to prevent concurrent reads of partially
updated globals, even when no vertex params are present in kwargs.
"""
with litellm_state_lock:
vp = kwargs.pop("vertex_project", None)
vl = kwargs.pop("vertex_location", None)
if vp is None and vl is None:
yield
return
old_vp = getattr(litellm, "vertex_project", None)
old_vl = getattr(litellm, "vertex_location", None)
try:
if vp is not None:
litellm.vertex_project = vp
if vl is not None:
litellm.vertex_location = vl
yield
finally:
litellm.vertex_project = old_vp
litellm.vertex_location = old_vl


@asynccontextmanager
async def _vertex_override_async(
kwargs: dict[str, Any],
) -> AsyncGenerator[None, None]:
"""Async version of _vertex_override using asyncio.Lock.

Uses litellm_state_async_lock instead of threading.Lock to avoid blocking
the event loop. The lock is held across the yield (including any awaited
completion call) to ensure globals remain consistent for the duration of
the request.
"""
async with litellm_state_async_lock:
vp = kwargs.pop("vertex_project", None)
vl = kwargs.pop("vertex_location", None)
if vp is None and vl is None:
yield
return
old_vp = getattr(litellm, "vertex_project", None)
old_vl = getattr(litellm, "vertex_location", None)
try:
if vp is not None:
litellm.vertex_project = vp
if vl is not None:
litellm.vertex_location = vl
yield
finally:
litellm.vertex_project = old_vp
litellm.vertex_location = old_vl
Comment thread
coderabbitai[bot] marked this conversation as resolved.


# =============================================================================
# TOKEN TRACKING: Wrap completion and embedding functions
# =============================================================================
Expand All @@ -101,11 +187,11 @@ def clear_queue(self) -> None:
_original_aembedding = litellm.aembedding


# Patch litellm's completion functions to include token tracking
@wraps(_original_completion)
def _completion_with_token_tracking(*args: Any, **kwargs: Any) -> Any:
"""Wrapper around litellm.completion that tracks tokens."""
response = _original_completion(*args, **kwargs)
"""Wrapper around litellm.completion that tracks tokens and handles Vertex params."""
with _vertex_override(kwargs):
response = _original_completion(*args, **kwargs)
try:
track_judge_tokens(response)
except Exception as e: # pylint: disable=broad-exception-caught
Expand All @@ -115,16 +201,16 @@ def _completion_with_token_tracking(*args: Any, **kwargs: Any) -> Any:

@wraps(_original_acompletion)
async def _acompletion_with_token_tracking(*args: Any, **kwargs: Any) -> Any:
"""Wrapper around litellm.acompletion that tracks tokens."""
response = await _original_acompletion(*args, **kwargs)
"""Wrapper around litellm.acompletion that tracks tokens and handles Vertex params."""
async with _vertex_override_async(kwargs):
response = await _original_acompletion(*args, **kwargs)
try:
track_judge_tokens(response)
except Exception as e: # pylint: disable=broad-exception-caught
logger.exception("Failed to track tokens for acompletion: %s", e)
return response


# Patch litellm's embedding functions to include token tracking
@wraps(_original_embedding)
def _embedding_with_token_tracking(*args: Any, **kwargs: Any) -> Any:
"""Wrapper around litellm.embedding that tracks tokens."""
Expand All @@ -147,22 +233,12 @@ async def _aembedding_with_token_tracking(*args: Any, **kwargs: Any) -> Any:
return response


# Patch litellm's completion and embedding functions to include token tracking
litellm.completion = _completion_with_token_tracking
litellm.acompletion = _acompletion_with_token_tracking
litellm.embedding = _embedding_with_token_tracking
litellm.aembedding = _aembedding_with_token_tracking


# =============================================================================
# GLOBAL STATE LOCK
# =============================================================================
# Single lock for ALL litellm global state mutations (cache, ssl_verify).
# Import this lock in any module that reads/writes litellm.cache or
# litellm.ssl_verify to prevent race conditions between concurrent pipelines.
litellm_state_lock = threading.Lock()


# =============================================================================
# SSL CONFIGURATION UTILITY
# =============================================================================
Expand Down
Loading
Loading