Skip to content

Commit f6d5f8f

Browse files
authored
Merge pull request #180 from xmican10/LEADS-241-JudgeLLM-token-counter-for-Deepeval
[LEADS-241] judge llm token counter for deepeval
2 parents 5e8be40 + d09b606 commit f6d5f8f

9 files changed

Lines changed: 414 additions & 211 deletions

File tree

src/lightspeed_evaluation/core/llm/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,15 @@
22

33
from typing import TYPE_CHECKING
44

5+
# Apply litellm patching globally before any litellm usage in this package
6+
import lightspeed_evaluation.core.llm.litellm_patch # noqa: F401
7+
58
from lightspeed_evaluation.core.system.lazy_import import create_lazy_getattr
69

710
if TYPE_CHECKING:
811
# ruff: noqa: F401
9-
from lightspeed_evaluation.core.llm.custom import BaseCustomLLM, TokenTracker
12+
from lightspeed_evaluation.core.llm.custom import BaseCustomLLM
13+
from lightspeed_evaluation.core.llm.token_tracker import TokenTracker
1014
from lightspeed_evaluation.core.llm.deepeval import DeepEvalLLMManager
1115
from lightspeed_evaluation.core.llm.manager import LLMManager
1216
from lightspeed_evaluation.core.llm.ragas import RagasLLMManager
@@ -19,7 +23,7 @@
1923
"LLMError": ("lightspeed_evaluation.core.system.exceptions", "LLMError"),
2024
"LLMManager": ("lightspeed_evaluation.core.llm.manager", "LLMManager"),
2125
"BaseCustomLLM": ("lightspeed_evaluation.core.llm.custom", "BaseCustomLLM"),
22-
"TokenTracker": ("lightspeed_evaluation.core.llm.custom", "TokenTracker"),
26+
"TokenTracker": ("lightspeed_evaluation.core.llm.token_tracker", "TokenTracker"),
2327
"DeepEvalLLMManager": (
2428
"lightspeed_evaluation.core.llm.deepeval",
2529
"DeepEvalLLMManager",
src/lightspeed_evaluation/core/llm/custom.py

Lines changed: 1 addition & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
"""Base Custom LLM class for evaluation framework."""
22

3-
import os
43
import logging
5-
import threading
4+
import os
65
from typing import Any, Optional, Union
76

87
import litellm
@@ -12,77 +11,6 @@
1211

1312
logger = logging.getLogger(__name__)
1413

15-
# Thread-local storage for active TokenTracker
16-
_active_tracker: threading.local = threading.local()
17-
18-
19-
class TokenTracker:
20-
"""Tracks token usage from LLM calls using direct response extraction.
21-
22-
Uses thread-local storage to track the active tracker. Tokens are captured
23-
directly from litellm response in BaseCustomLLM.call() - no callbacks,
24-
no timeouts, no race conditions.
25-
26-
Usage:
27-
tracker = TokenTracker()
28-
tracker.start() # Set as active tracker for this thread
29-
# ... make LLM calls (tokens captured automatically) ...
30-
tracker.stop() # Unset as active tracker
31-
input_tokens, output_tokens = tracker.get_counts()
32-
"""
33-
34-
def __init__(self) -> None:
35-
"""Initialize token tracker."""
36-
self.input_tokens = 0
37-
self.output_tokens = 0
38-
self._lock = threading.Lock() # Instance lock for token counter updates
39-
40-
def add_tokens(self, prompt_tokens: int, completion_tokens: int) -> None:
41-
"""Add token counts (thread-safe).
42-
43-
Called by BaseCustomLLM.call() to record tokens from LLM response.
44-
45-
Args:
46-
prompt_tokens: Number of input/prompt tokens.
47-
completion_tokens: Number of output/completion tokens.
48-
"""
49-
with self._lock:
50-
self.input_tokens += prompt_tokens
51-
self.output_tokens += completion_tokens
52-
53-
def start(self) -> None:
54-
"""Set this tracker as active for the current thread."""
55-
_active_tracker.tracker = self
56-
57-
def stop(self) -> None:
58-
"""Unset this tracker as active for the current thread."""
59-
if getattr(_active_tracker, "tracker", None) is self:
60-
_active_tracker.tracker = None
61-
62-
def get_counts(self) -> tuple[int, int]:
63-
"""Get accumulated token counts.
64-
65-
Returns:
66-
Tuple of (input_tokens, output_tokens)
67-
"""
68-
with self._lock:
69-
return self.input_tokens, self.output_tokens
70-
71-
def reset(self) -> None:
72-
"""Reset token counts to zero."""
73-
with self._lock:
74-
self.input_tokens = 0
75-
self.output_tokens = 0
76-
77-
@staticmethod
78-
def get_active() -> Optional["TokenTracker"]:
79-
"""Get the active tracker for the current thread.
80-
81-
Returns:
82-
The active TokenTracker, or None if no tracker is active.
83-
"""
84-
return getattr(_active_tracker, "tracker", None)
85-
8614

8715
class BaseCustomLLM: # pylint: disable=too-few-public-methods
8816
"""Base LLM class with core calling functionality."""
@@ -178,24 +106,3 @@ def call(
178106

179107
except Exception as e:
180108
raise LLMError(f"LLM call failed: {str(e)}") from e
181-
182-
finally:
183-
# Track tokens even if the call failed - tokens may have been consumed
184-
self._track_tokens(response)
185-
186-
def _track_tokens(self, response: Any) -> None:
187-
"""Track JudgeLLM tokens if a tracker is active."""
188-
# Only track token counts if response exists and is NOT from cache
189-
tracker = TokenTracker.get_active()
190-
if tracker and response is not None:
191-
cache_hit = getattr(
192-
response, "_hidden_params", {}
193-
).get( # pylint: disable=protected-access
194-
"cache_hit", False
195-
)
196-
# Only add tokens if this response was not retrieved from cache
197-
if not cache_hit and hasattr(response, "usage") and response.usage:
198-
tracker.add_tokens(
199-
getattr(response.usage, "prompt_tokens", 0),
200-
getattr(response.usage, "completion_tokens", 0),
201-
)

src/lightspeed_evaluation/core/llm/deepeval.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
"""DeepEval LLM Manager - DeepEval-specific LLM wrapper."""
1+
"""DeepEval LLM Manager - DeepEval-specific LLM wrapper.
2+
3+
Note: litellm patching is applied at package level (__init__.py) before any imports.
4+
This ensures DeepEval's LiteLLMModel uses the patched completion functions.
5+
"""
26

37
import os
48
from typing import Any
@@ -10,7 +14,8 @@
1014
class DeepEvalLLMManager:
1115
"""DeepEval LLM Manager - Takes LLM parameters directly.
1216
13-
This manager focuses solely on DeepEval-specific LLM integration.
17+
This manager focuses solely on DeepEval-specific LLM integration
18+
with token tracking support.
1419
"""
1520

1621
def __init__(self, model_name: str, llm_params: dict[str, Any]):
@@ -23,7 +28,10 @@ def __init__(self, model_name: str, llm_params: dict[str, Any]):
2328
# Always drop unsupported parameters for cross-provider compatibility
2429
litellm.drop_params = True
2530

26-
# Create DeepEval's LLM model with provided parameters
31+
# Note: Token tracking is handled by the patched litellm.completion/acompletion
32+
# No additional setup needed - the patch was applied at module import time
33+
34+
# Create standard LiteLLMModel - it will use our patched completion functions
2735
self.llm_model = LiteLLMModel(
2836
model=self.model_name,
2937
temperature=llm_params.get("temperature", 0.0),
src/lightspeed_evaluation/core/llm/litellm_patch.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""Global litellm patching for token tracking.
2+
3+
It patches litellm.completion and litellm.acompletion to automatically track tokens
4+
for all LLM calls throughout the application.
5+
"""
6+
7+
import logging
8+
from functools import wraps
9+
from typing import Any
10+
11+
import litellm
12+
13+
from lightspeed_evaluation.core.llm.token_tracker import track_tokens
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
# Store original functions before patching
19+
_original_completion = litellm.completion
20+
_original_acompletion = litellm.acompletion
21+
22+
23+
@wraps(_original_completion)
24+
def _completion_with_token_tracking(*args: Any, **kwargs: Any) -> Any:
25+
"""Wrapper around litellm.completion that tracks tokens."""
26+
response = _original_completion(*args, **kwargs)
27+
try:
28+
track_tokens(response)
29+
except Exception as e: # pylint: disable=broad-exception-caught
30+
logger.exception("Failed to track tokens for completion: %s", e)
31+
return response
32+
33+
34+
@wraps(_original_acompletion)
35+
async def _acompletion_with_token_tracking(*args: Any, **kwargs: Any) -> Any:
36+
"""Wrapper around litellm.acompletion that tracks tokens."""
37+
response = await _original_acompletion(*args, **kwargs)
38+
try:
39+
track_tokens(response)
40+
except Exception as e: # pylint: disable=broad-exception-caught
41+
logger.exception("Failed to track tokens for acompletion: %s", e)
42+
return response
43+
44+
45+
# Patch litellm's completion functions to include token tracking
46+
litellm.completion = _completion_with_token_tracking
47+
litellm.acompletion = _acompletion_with_token_tracking
src/lightspeed_evaluation/core/llm/token_tracker.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
"""TokenTracker for tracking LLM token usage with direct response extraction."""
2+
3+
import logging
4+
import threading
5+
from typing import Any, Optional
6+
7+
logger = logging.getLogger(__name__)
8+
9+
# Thread-local storage for active TokenTracker
10+
_active_tracker: threading.local = threading.local()
11+
12+
13+
class TokenTracker:
14+
"""Tracks token usage from LLM calls using direct response extraction.
15+
16+
Uses thread-local storage to track the active tracker. Tokens are captured
17+
directly from the litellm response by the patched litellm.completion/acompletion - no callbacks,
18+
no timeouts, no race conditions.
19+
20+
Usage:
21+
tracker = TokenTracker()
22+
tracker.start() # Set as active tracker for this thread
23+
# ... make LLM calls (tokens captured automatically) ...
24+
tracker.stop() # Unset as active tracker
25+
input_tokens, output_tokens = tracker.get_counts()
26+
"""
27+
28+
def __init__(self) -> None:
29+
"""Initialize token tracker."""
30+
self.input_tokens = 0
31+
self.output_tokens = 0
32+
self._lock = threading.Lock() # Instance lock for token counter updates
33+
34+
def add_tokens(self, prompt_tokens: int, completion_tokens: int) -> None:
35+
"""Add token counts (thread-safe).
36+
37+
Called by track_tokens() to record tokens from an LLM response.
38+
39+
Args:
40+
prompt_tokens: Number of input/prompt tokens.
41+
completion_tokens: Number of output/completion tokens.
42+
"""
43+
with self._lock:
44+
self.input_tokens += prompt_tokens
45+
self.output_tokens += completion_tokens
46+
47+
def start(self) -> None:
48+
"""Set this tracker as active for the current thread."""
49+
_active_tracker.tracker = self
50+
51+
def stop(self) -> None:
52+
"""Unset this tracker as active for the current thread."""
53+
if getattr(_active_tracker, "tracker", None) is self:
54+
_active_tracker.tracker = None
55+
56+
def get_counts(self) -> tuple[int, int]:
57+
"""Get accumulated token counts.
58+
59+
Returns:
60+
Tuple of (input_tokens, output_tokens)
61+
"""
62+
with self._lock:
63+
return self.input_tokens, self.output_tokens
64+
65+
def reset(self) -> None:
66+
"""Reset token counts to zero."""
67+
with self._lock:
68+
self.input_tokens = 0
69+
self.output_tokens = 0
70+
71+
@staticmethod
72+
def get_active() -> Optional["TokenTracker"]:
73+
"""Get the active tracker for the current thread.
74+
75+
Returns:
76+
The active TokenTracker, or None if no tracker is active.
77+
"""
78+
return getattr(_active_tracker, "tracker", None)
79+
80+
81+
def track_tokens(response: Any) -> None:
82+
"""Track JudgeLLM tokens if a tracker is active.
83+
84+
Called by the litellm patch (see litellm_patch.py) after each LLM call.
85+
Skips tracking for cached responses to avoid counting tokens that weren't actually consumed.
86+
"""
87+
# Only track token counts if response exists and is NOT from cache
88+
tracker = TokenTracker.get_active()
89+
if tracker and response is not None:
90+
cache_hit = getattr(
91+
response, "_hidden_params", {}
92+
).get( # pylint: disable=protected-access
93+
"cache_hit", False
94+
)
95+
# Only add tokens if this response was not retrieved from cache
96+
if not cache_hit and hasattr(response, "usage") and response.usage:
97+
prompt_tokens = int(getattr(response.usage, "prompt_tokens", 0))
98+
completion_tokens = int(getattr(response.usage, "completion_tokens", 0))
99+
tracker.add_tokens(
100+
prompt_tokens,
101+
completion_tokens,
102+
)

src/lightspeed_evaluation/pipeline/evaluation/evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Any, Optional
77

88
from lightspeed_evaluation.core.embedding.manager import EmbeddingManager
9-
from lightspeed_evaluation.core.llm.custom import TokenTracker
9+
from lightspeed_evaluation.core.llm.token_tracker import TokenTracker
1010
from lightspeed_evaluation.core.llm.manager import LLMManager
1111
from lightspeed_evaluation.core.metrics.custom import CustomMetrics
1212
from lightspeed_evaluation.core.metrics.deepeval import DeepEvalMetrics

0 commit comments

Comments
 (0)