Skip to content

Commit f399767

Browse files
authored
Merge pull request lightspeed-core#159 from bsatapat-jpg/dev
[LEADS-208] Fix TokenTracker double-counting in multi-thread evaluation
2 parents 3fe356c + 47293b9 commit f399767

3 files changed

Lines changed: 215 additions & 49 deletions

File tree

src/lightspeed_evaluation/core/llm/custom.py

Lines changed: 57 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import os
44
import logging
5+
import threading
56
from typing import Any, Optional, Union
67

78
import litellm
@@ -11,62 +12,76 @@
1112

1213
logger = logging.getLogger(__name__)
1314

15+
# Thread-local storage for active TokenTracker
16+
_active_tracker: threading.local = threading.local()
17+
1418

1519
class TokenTracker:
16-
"""Tracks token usage from LiteLLM calls via callbacks.
20+
"""Tracks token usage from LLM calls using direct response extraction.
21+
22+
Uses thread-local storage to track the active tracker. Tokens are captured
23+
directly from litellm response in BaseCustomLLM.call() - no callbacks,
24+
no timeouts, no race conditions.
1725
1826
Usage:
1927
tracker = TokenTracker()
20-
tracker.start() # Register callback
21-
# ... make LLM calls ...
22-
tracker.stop() # Unregister callback
28+
tracker.start() # Set as active tracker for this thread
29+
# ... make LLM calls (tokens captured automatically) ...
30+
tracker.stop() # Unset as active tracker
2331
input_tokens, output_tokens = tracker.get_counts()
2432
"""
2533

2634
def __init__(self) -> None:
2735
"""Initialize token tracker."""
2836
self.input_tokens = 0
2937
self.output_tokens = 0
30-
self._callback_registered = False
38+
self._lock = threading.Lock() # Instance lock for token counter updates
3139

32-
def _token_callback(
33-
self,
34-
_kwargs: dict[str, Any],
35-
completion_response: Any,
36-
_start_time: float,
37-
_end_time: float,
38-
) -> None:
39-
"""Capture token usage from LiteLLM completion response."""
40-
if hasattr(completion_response, "usage") and completion_response.usage:
41-
usage = completion_response.usage
42-
self.input_tokens += getattr(usage, "prompt_tokens", 0)
43-
self.output_tokens += getattr(usage, "completion_tokens", 0)
40+
def add_tokens(self, prompt_tokens: int, completion_tokens: int) -> None:
41+
"""Add token counts (thread-safe).
42+
43+
Called by BaseCustomLLM.call() to record tokens from LLM response.
44+
45+
Args:
46+
prompt_tokens: Number of input/prompt tokens.
47+
completion_tokens: Number of output/completion tokens.
48+
"""
49+
with self._lock:
50+
self.input_tokens += prompt_tokens
51+
self.output_tokens += completion_tokens
4452

4553
def start(self) -> None:
46-
"""Register the token tracking callback."""
47-
if self._callback_registered:
48-
return
49-
if not hasattr(litellm, "success_callback") or litellm.success_callback is None:
50-
litellm.success_callback = []
51-
litellm.success_callback.append(self._token_callback)
52-
self._callback_registered = True
54+
"""Set this tracker as active for the current thread."""
55+
_active_tracker.tracker = self
5356

5457
def stop(self) -> None:
55-
"""Unregister the token tracking callback."""
56-
if not self._callback_registered:
57-
return
58-
if self._token_callback in litellm.success_callback:
59-
litellm.success_callback.remove(self._token_callback)
60-
self._callback_registered = False
58+
"""Unset this tracker as active for the current thread."""
59+
if getattr(_active_tracker, "tracker", None) is self:
60+
_active_tracker.tracker = None
6161

6262
def get_counts(self) -> tuple[int, int]:
63-
"""Get accumulated token counts."""
64-
return self.input_tokens, self.output_tokens
63+
"""Get accumulated token counts.
64+
65+
Returns:
66+
Tuple of (input_tokens, output_tokens)
67+
"""
68+
with self._lock:
69+
return self.input_tokens, self.output_tokens
6570

6671
def reset(self) -> None:
6772
"""Reset token counts to zero."""
68-
self.input_tokens = 0
69-
self.output_tokens = 0
73+
with self._lock:
74+
self.input_tokens = 0
75+
self.output_tokens = 0
76+
77+
@staticmethod
78+
def get_active() -> Optional["TokenTracker"]:
79+
"""Get the active tracker for the current thread.
80+
81+
Returns:
82+
The active TokenTracker, or None if no tracker is active.
83+
"""
84+
return getattr(_active_tracker, "tracker", None)
7085

7186

7287
class BaseCustomLLM: # pylint: disable=too-few-public-methods
@@ -133,6 +148,14 @@ def call(
133148
try:
134149
response = litellm.completion(**call_params)
135150

151+
# Direct token extraction - capture tokens synchronously from response
152+
tracker = TokenTracker.get_active()
153+
if tracker and hasattr(response, "usage") and response.usage:
154+
tracker.add_tokens(
155+
getattr(response.usage, "prompt_tokens", 0),
156+
getattr(response.usage, "completion_tokens", 0),
157+
)
158+
136159
# Extract content from all choices
137160
results = []
138161
for choice in response.choices: # type: ignore

tests/unit/core/llm/test_custom.py

Lines changed: 153 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
"""Unit tests for custom LLM classes."""
44

5+
import threading
6+
57
import pytest
68
from pytest_mock import MockerFixture
79

@@ -12,21 +14,102 @@
1214
class TestTokenTracker:
1315
"""Tests for TokenTracker."""
1416

15-
def test_token_callback_accumulates_tokens(self, mocker: MockerFixture) -> None:
16-
"""Test that token callback accumulates token counts."""
17+
def test_add_tokens_accumulates(self) -> None:
18+
"""Test that add_tokens accumulates token counts."""
1719
tracker = TokenTracker()
1820

19-
# Mock completion response with usage
20-
mock_response = mocker.Mock()
21-
mock_response.usage = mocker.Mock()
22-
mock_response.usage.prompt_tokens = 10
23-
mock_response.usage.completion_tokens = 20
21+
tracker.add_tokens(10, 20)
22+
tracker.add_tokens(5, 15)
23+
24+
input_tokens, output_tokens = tracker.get_counts()
25+
assert input_tokens == 15
26+
assert output_tokens == 35
27+
28+
def test_reset_clears_counts(self) -> None:
29+
"""Test that reset clears token counts."""
30+
tracker = TokenTracker()
31+
tracker.add_tokens(100, 200)
2432

25-
tracker._token_callback({}, mock_response, 0.0, 0.0)
33+
tracker.reset()
2634

2735
input_tokens, output_tokens = tracker.get_counts()
28-
assert input_tokens == 10
29-
assert output_tokens == 20
36+
assert input_tokens == 0
37+
assert output_tokens == 0
38+
39+
def test_start_sets_active_tracker(self) -> None:
40+
"""Test that start sets the tracker as active for current thread."""
41+
tracker = TokenTracker()
42+
tracker.start()
43+
44+
try:
45+
assert TokenTracker.get_active() is tracker
46+
finally:
47+
tracker.stop()
48+
49+
def test_stop_clears_active_tracker(self) -> None:
50+
"""Test that stop clears the active tracker."""
51+
tracker = TokenTracker()
52+
tracker.start()
53+
tracker.stop()
54+
55+
assert TokenTracker.get_active() is None
56+
57+
def test_get_active_returns_none_when_no_tracker(self) -> None:
58+
"""Test that get_active returns None when no tracker is active."""
59+
# Ensure clean state: start() replaces any tracker left active on this thread, then stop() clears it
60+
temp = TokenTracker()
61+
temp.start()
62+
temp.stop()
63+
64+
assert TokenTracker.get_active() is None
65+
66+
def test_thread_local_isolation(self) -> None:
67+
"""Test that each thread has its own active tracker."""
68+
tracker1 = TokenTracker()
69+
tracker2 = TokenTracker()
70+
results: dict[str, TokenTracker | None] = {}
71+
72+
def thread_work(name: str, tracker: TokenTracker) -> None:
73+
tracker.start()
74+
results[name] = TokenTracker.get_active()
75+
# Deliberately skip stop() here: the main thread then verifies its own active tracker was not affected
76+
77+
# Start tracker1 in main thread
78+
tracker1.start()
79+
80+
# Start tracker2 in another thread
81+
thread = threading.Thread(target=thread_work, args=("thread2", tracker2))
82+
thread.start()
83+
thread.join()
84+
85+
# Main thread should still have tracker1
86+
assert TokenTracker.get_active() is tracker1
87+
# Other thread had tracker2
88+
assert results["thread2"] is tracker2
89+
90+
tracker1.stop()
91+
92+
def test_add_tokens_thread_safe(self) -> None:
93+
"""Test that add_tokens is thread-safe under concurrent access."""
94+
tracker = TokenTracker()
95+
num_threads = 10
96+
tokens_per_thread = 100
97+
98+
def add_tokens_worker() -> None:
99+
for _ in range(tokens_per_thread):
100+
tracker.add_tokens(1, 2)
101+
102+
threads = [
103+
threading.Thread(target=add_tokens_worker) for _ in range(num_threads)
104+
]
105+
for t in threads:
106+
t.start()
107+
for t in threads:
108+
t.join()
109+
110+
input_tokens, output_tokens = tracker.get_counts()
111+
assert input_tokens == num_threads * tokens_per_thread
112+
assert output_tokens == num_threads * tokens_per_thread * 2
30113

31114

32115
class TestBaseCustomLLM:
@@ -103,3 +186,63 @@ def test_call_raises_llm_error_on_failure(self, mocker: MockerFixture) -> None:
103186

104187
with pytest.raises(LLMError, match="LLM call failed"):
105188
llm.call("test prompt")
189+
190+
def test_call_captures_tokens_with_active_tracker(
191+
self, mocker: MockerFixture
192+
) -> None:
193+
"""Test call captures tokens when a TokenTracker is active."""
194+
mock_litellm = mocker.patch("lightspeed_evaluation.core.llm.custom.litellm")
195+
mocker.patch.dict("os.environ", {})
196+
197+
# Mock response with usage
198+
mock_choice = mocker.Mock()
199+
mock_choice.message.content = "Test response"
200+
mock_response = mocker.Mock()
201+
mock_response.choices = [mock_choice]
202+
mock_response.usage = mocker.Mock()
203+
mock_response.usage.prompt_tokens = 50
204+
mock_response.usage.completion_tokens = 100
205+
mock_litellm.completion.return_value = mock_response
206+
207+
# Start a tracker
208+
tracker = TokenTracker()
209+
tracker.start()
210+
211+
try:
212+
llm = BaseCustomLLM("gpt-4", {"temperature": 0.0})
213+
llm.call("test prompt")
214+
215+
# Tokens should be captured
216+
input_tokens, output_tokens = tracker.get_counts()
217+
assert input_tokens == 50
218+
assert output_tokens == 100
219+
finally:
220+
tracker.stop()
221+
222+
def test_call_does_not_capture_tokens_without_active_tracker(
223+
self, mocker: MockerFixture
224+
) -> None:
225+
"""Test call does not fail when no TokenTracker is active."""
226+
mock_litellm = mocker.patch("lightspeed_evaluation.core.llm.custom.litellm")
227+
mocker.patch.dict("os.environ", {})
228+
229+
# Mock response with usage
230+
mock_choice = mocker.Mock()
231+
mock_choice.message.content = "Test response"
232+
mock_response = mocker.Mock()
233+
mock_response.choices = [mock_choice]
234+
mock_response.usage = mocker.Mock()
235+
mock_response.usage.prompt_tokens = 50
236+
mock_response.usage.completion_tokens = 100
237+
mock_litellm.completion.return_value = mock_response
238+
239+
# Ensure no tracker is active: start() replaces any leftover tracker on this thread, then stop() clears it
240+
temp = TokenTracker()
241+
temp.start()
242+
temp.stop()
243+
244+
llm = BaseCustomLLM("gpt-4", {"temperature": 0.0})
245+
result = llm.call("test prompt")
246+
247+
# Should succeed without error
248+
assert result == "Test response"

tests/unit/pipeline/evaluation/test_evaluator.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -882,16 +882,16 @@ def test_token_tracker_start_stop(self) -> None:
882882
"""Test start and stop methods."""
883883
tracker = TokenTracker()
884884
tracker.start()
885-
assert tracker._callback_registered is True
885+
assert TokenTracker.get_active() is tracker
886886
tracker.stop()
887-
assert tracker._callback_registered is False
887+
assert TokenTracker.get_active() is None
888888

889889
def test_token_tracker_double_start(self) -> None:
890-
"""Test calling start twice doesn't register callback twice."""
890+
"""Test calling start twice doesn't fail."""
891891
tracker = TokenTracker()
892892
tracker.start()
893893
tracker.start() # Should not fail
894-
assert tracker._callback_registered is True
894+
assert TokenTracker.get_active() is tracker
895895
tracker.stop()
896896

897897
def test_token_tracker_double_stop(self) -> None:
@@ -900,7 +900,7 @@ def test_token_tracker_double_stop(self) -> None:
900900
tracker.start()
901901
tracker.stop()
902902
tracker.stop() # Should not fail
903-
assert tracker._callback_registered is False
903+
assert TokenTracker.get_active() is None
904904

905905
def test_token_tracker_independent_instances(self) -> None:
906906
"""Test multiple TokenTracker instances are independent."""

0 commit comments

Comments
 (0)