|
1 | 1 | """Base Custom LLM class for evaluation framework.""" |
2 | 2 |
|
3 | | -import os |
4 | 3 | import logging |
5 | | -import threading |
| 4 | +import os |
6 | 5 | from typing import Any, Optional, Union |
7 | 6 |
|
8 | 7 | import litellm |
|
12 | 11 |
|
13 | 12 | logger = logging.getLogger(__name__) |
14 | 13 |
|
15 | | -# Thread-local storage for active TokenTracker |
16 | | -_active_tracker: threading.local = threading.local() |
17 | | - |
18 | | - |
19 | | -class TokenTracker: |
20 | | - """Tracks token usage from LLM calls using direct response extraction. |
21 | | -
|
22 | | - Uses thread-local storage to track the active tracker. Tokens are captured |
23 | | - directly from litellm response in BaseCustomLLM.call() - no callbacks, |
24 | | - no timeouts, no race conditions. |
25 | | -
|
26 | | - Usage: |
27 | | - tracker = TokenTracker() |
28 | | - tracker.start() # Set as active tracker for this thread |
29 | | - # ... make LLM calls (tokens captured automatically) ... |
30 | | - tracker.stop() # Unset as active tracker |
31 | | - input_tokens, output_tokens = tracker.get_counts() |
32 | | - """ |
33 | | - |
34 | | - def __init__(self) -> None: |
35 | | - """Initialize token tracker.""" |
36 | | - self.input_tokens = 0 |
37 | | - self.output_tokens = 0 |
38 | | - self._lock = threading.Lock() # Instance lock for token counter updates |
39 | | - |
40 | | - def add_tokens(self, prompt_tokens: int, completion_tokens: int) -> None: |
41 | | - """Add token counts (thread-safe). |
42 | | -
|
43 | | - Called by BaseCustomLLM.call() to record tokens from LLM response. |
44 | | -
|
45 | | - Args: |
46 | | - prompt_tokens: Number of input/prompt tokens. |
47 | | - completion_tokens: Number of output/completion tokens. |
48 | | - """ |
49 | | - with self._lock: |
50 | | - self.input_tokens += prompt_tokens |
51 | | - self.output_tokens += completion_tokens |
52 | | - |
53 | | - def start(self) -> None: |
54 | | - """Set this tracker as active for the current thread.""" |
55 | | - _active_tracker.tracker = self |
56 | | - |
57 | | - def stop(self) -> None: |
58 | | - """Unset this tracker as active for the current thread.""" |
59 | | - if getattr(_active_tracker, "tracker", None) is self: |
60 | | - _active_tracker.tracker = None |
61 | | - |
62 | | - def get_counts(self) -> tuple[int, int]: |
63 | | - """Get accumulated token counts. |
64 | | -
|
65 | | - Returns: |
66 | | - Tuple of (input_tokens, output_tokens) |
67 | | - """ |
68 | | - with self._lock: |
69 | | - return self.input_tokens, self.output_tokens |
70 | | - |
71 | | - def reset(self) -> None: |
72 | | - """Reset token counts to zero.""" |
73 | | - with self._lock: |
74 | | - self.input_tokens = 0 |
75 | | - self.output_tokens = 0 |
76 | | - |
77 | | - @staticmethod |
78 | | - def get_active() -> Optional["TokenTracker"]: |
79 | | - """Get the active tracker for the current thread. |
80 | | -
|
81 | | - Returns: |
82 | | - The active TokenTracker, or None if no tracker is active. |
83 | | - """ |
84 | | - return getattr(_active_tracker, "tracker", None) |
85 | | - |
86 | 14 |
|
87 | 15 | class BaseCustomLLM: # pylint: disable=too-few-public-methods |
88 | 16 | """Base LLM class with core calling functionality.""" |
@@ -178,24 +106,3 @@ def call( |
178 | 106 |
|
179 | 107 | except Exception as e: |
180 | 108 | raise LLMError(f"LLM call failed: {str(e)}") from e |
181 | | - |
182 | | - finally: |
183 | | - # Track tokens even if the call failed - tokens may have been consumed |
184 | | - self._track_tokens(response) |
185 | | - |
186 | | - def _track_tokens(self, response: Any) -> None: |
187 | | - """Track JudgeLLM tokens if a tracker is active.""" |
188 | | - # Only track token counts if response exists and is NOT from cache |
189 | | - tracker = TokenTracker.get_active() |
190 | | - if tracker and response is not None: |
191 | | - cache_hit = getattr( |
192 | | - response, "_hidden_params", {} |
193 | | - ).get( # pylint: disable=protected-access |
194 | | - "cache_hit", False |
195 | | - ) |
196 | | - # Only add tokens if this response was not retrieved from cache |
197 | | - if not cache_hit and hasattr(response, "usage") and response.usage: |
198 | | - tracker.add_tokens( |
199 | | - getattr(response.usage, "prompt_tokens", 0), |
200 | | - getattr(response.usage, "completion_tokens", 0), |
201 | | - ) |
0 commit comments