Skip to content

Commit 1e16a7f

Browse files
authored
LEADS-240: Token usage should be 0 for a re-run with successful cache (lightspeed-core#176)
* 0 JudgeLLM/API tokens are added when cache hit, added unit tests for such scenarios * Resolving rebase errors * Resolving pylint disable check
1 parent 0ede2b8 commit 1e16a7f

4 files changed

Lines changed: 101 additions & 9 deletions

File tree

src/lightspeed_evaluation/core/api/client.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,14 @@ def _get_cached_response(self, request: APIRequest) -> APIResponse | None:
321321
if self.cache is None:
322322
raise RuntimeError("cache is None, but used")
323323
key = self._get_cache_key(request)
324-
return cast(APIResponse | None, self.cache.get(key))
324+
cached_response = cast(APIResponse | None, self.cache.get(key))
325+
326+
# Zero out token counts for cached responses since no API call was made
327+
if cached_response is not None:
328+
cached_response.input_tokens = 0
329+
cached_response.output_tokens = 0
330+
331+
return cached_response
325332

326333
def close(self) -> None:
327334
"""Close API client."""

src/lightspeed_evaluation/core/llm/custom.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -145,17 +145,10 @@ def call(
145145
**kwargs,
146146
}
147147

148+
response = None
148149
try:
149150
response = litellm.completion(**call_params)
150151

151-
# Direct token extraction - capture tokens synchronously from response
152-
tracker = TokenTracker.get_active()
153-
if tracker and hasattr(response, "usage") and response.usage:
154-
tracker.add_tokens(
155-
getattr(response.usage, "prompt_tokens", 0),
156-
getattr(response.usage, "completion_tokens", 0),
157-
)
158-
159152
# Extract content from all choices
160153
results = []
161154
for choice in response.choices: # type: ignore
@@ -185,3 +178,24 @@ def call(
185178

186179
except Exception as e:
187180
raise LLMError(f"LLM call failed: {str(e)}") from e
181+
182+
finally:
183+
# Track tokens even if the call failed - tokens may have been consumed
184+
self._track_tokens(response)
185+
186+
def _track_tokens(self, response: Any) -> None:
187+
"""Track JudgeLLM tokens if a tracker is active."""
188+
# Only track token counts if response exists and is NOT from cache
189+
tracker = TokenTracker.get_active()
190+
if tracker and response is not None:
191+
cache_hit = getattr(
192+
response, "_hidden_params", {}
193+
).get( # pylint: disable=protected-access
194+
"cache_hit", False
195+
)
196+
# Only add tokens if this response was not retrieved from cache
197+
if not cache_hit and hasattr(response, "usage") and response.usage:
198+
tracker.add_tokens(
199+
getattr(response.usage, "prompt_tokens", 0),
200+
getattr(response.usage, "completion_tokens", 0),
201+
)

tests/unit/core/api/test_client.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,45 @@ def test_standard_endpoint_initialization(
496496

497497
assert client.config.endpoint_type == "query"
498498

499+
def test_get_cached_response_zeros_token_counts(
500+
self, basic_api_config_query_endpoint: APIConfig, mocker: MockerFixture
501+
) -> None:
502+
"""Test that _get_cached_response zeros out token counts."""
503+
basic_api_config_query_endpoint.cache_enabled = True
504+
505+
mocker.patch("lightspeed_evaluation.core.api.client.httpx.Client")
506+
507+
# Create a mock cache with a cached response that has token counts
508+
mock_cache = mocker.Mock()
509+
cached_response = APIResponse(
510+
response="Cached response",
511+
conversation_id="conv_123",
512+
input_tokens=50,
513+
output_tokens=100,
514+
)
515+
mock_cache.get.return_value = cached_response
516+
517+
mocker.patch(
518+
"lightspeed_evaluation.core.api.client.Cache", return_value=mock_cache
519+
)
520+
521+
client = APIClient(basic_api_config_query_endpoint)
522+
523+
# Prepare a request
524+
request = client._prepare_request("Test query")
525+
526+
# Get cached response
527+
result = client._get_cached_response(request)
528+
529+
# Verify token counts were zeroed
530+
assert result is not None
531+
assert result.input_tokens == 0
532+
assert result.output_tokens == 0
533+
534+
# Verify other fields remain unchanged
535+
assert result.response == "Cached response"
536+
assert result.conversation_id == "conv_123"
537+
499538

500539
class TestRetryLogic:
501540
"""Unit tests for retry logic in APIClient."""

tests/unit/core/llm/test_custom.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ def test_call_captures_tokens_with_active_tracker(
202202
mock_response.usage = mocker.Mock()
203203
mock_response.usage.prompt_tokens = 50
204204
mock_response.usage.completion_tokens = 100
205+
mock_response._hidden_params = {} # Ensure no cache hit
205206
mock_litellm.completion.return_value = mock_response
206207

207208
# Start a tracker
@@ -246,3 +247,34 @@ def test_call_does_not_capture_tokens_without_active_tracker(
246247

247248
# Should succeed without error
248249
assert result == "Test response"
250+
251+
def test_call_does_not_add_tokens_on_cache_hit(self, mocker: MockerFixture) -> None:
252+
"""Test call does not add tokens when response is from cache."""
253+
mock_litellm = mocker.patch("lightspeed_evaluation.core.llm.custom.litellm")
254+
mocker.patch.dict("os.environ", {})
255+
256+
# Mock response with cache hit
257+
mock_choice = mocker.Mock()
258+
mock_choice.message.content = "Cached response"
259+
mock_response = mocker.Mock()
260+
mock_response.choices = [mock_choice]
261+
mock_response.usage = mocker.Mock()
262+
mock_response.usage.prompt_tokens = 50
263+
mock_response.usage.completion_tokens = 100
264+
mock_response._hidden_params = {"cache_hit": True} # Cache hit
265+
mock_litellm.completion.return_value = mock_response
266+
267+
# Start a tracker
268+
tracker = TokenTracker()
269+
tracker.start()
270+
271+
try:
272+
llm = BaseCustomLLM("gpt-4", {"temperature": 0.0})
273+
llm.call("test prompt")
274+
275+
# Tokens should NOT be captured due to cache hit
276+
input_tokens, output_tokens = tracker.get_counts()
277+
assert input_tokens == 0
278+
assert output_tokens == 0
279+
finally:
280+
tracker.stop()

0 commit comments

Comments
 (0)