|
2 | 2 |
|
3 | 3 | """Unit tests for custom LLM classes.""" |
4 | 4 |
|
| 5 | +import threading |
| 6 | + |
5 | 7 | import pytest |
6 | 8 | from pytest_mock import MockerFixture |
7 | 9 |
|
|
12 | 14 | class TestTokenTracker: |
13 | 15 | """Tests for TokenTracker.""" |
14 | 16 |
|
15 | | - def test_token_callback_accumulates_tokens(self, mocker: MockerFixture) -> None: |
16 | | - """Test that token callback accumulates token counts.""" |
| 17 | + def test_add_tokens_accumulates(self) -> None: |
| 18 | + """Test that add_tokens accumulates token counts.""" |
17 | 19 | tracker = TokenTracker() |
18 | 20 |
|
19 | | - # Mock completion response with usage |
20 | | - mock_response = mocker.Mock() |
21 | | - mock_response.usage = mocker.Mock() |
22 | | - mock_response.usage.prompt_tokens = 10 |
23 | | - mock_response.usage.completion_tokens = 20 |
| 21 | + tracker.add_tokens(10, 20) |
| 22 | + tracker.add_tokens(5, 15) |
| 23 | + |
| 24 | + input_tokens, output_tokens = tracker.get_counts() |
| 25 | + assert input_tokens == 15 |
| 26 | + assert output_tokens == 35 |
| 27 | + |
| 28 | + def test_reset_clears_counts(self) -> None: |
| 29 | + """Test that reset clears token counts.""" |
| 30 | + tracker = TokenTracker() |
| 31 | + tracker.add_tokens(100, 200) |
24 | 32 |
|
25 | | - tracker._token_callback({}, mock_response, 0.0, 0.0) |
| 33 | + tracker.reset() |
26 | 34 |
|
27 | 35 | input_tokens, output_tokens = tracker.get_counts() |
28 | | - assert input_tokens == 10 |
29 | | - assert output_tokens == 20 |
| 36 | + assert input_tokens == 0 |
| 37 | + assert output_tokens == 0 |
| 38 | + |
| 39 | + def test_start_sets_active_tracker(self) -> None: |
| 40 | + """Test that start sets the tracker as active for current thread.""" |
| 41 | + tracker = TokenTracker() |
| 42 | + tracker.start() |
| 43 | + |
| 44 | + try: |
| 45 | + assert TokenTracker.get_active() is tracker |
| 46 | + finally: |
| 47 | + tracker.stop() |
| 48 | + |
| 49 | + def test_stop_clears_active_tracker(self) -> None: |
| 50 | + """Test that stop clears the active tracker.""" |
| 51 | + tracker = TokenTracker() |
| 52 | + tracker.start() |
| 53 | + tracker.stop() |
| 54 | + |
| 55 | + assert TokenTracker.get_active() is None |
| 56 | + |
| 57 | + def test_get_active_returns_none_when_no_tracker(self) -> None: |
| 58 | + """Test that get_active returns None when no tracker is active.""" |
| 59 | + # Ensure clean state by starting and stopping a tracker |
| 60 | + temp = TokenTracker() |
| 61 | + temp.start() |
| 62 | + temp.stop() |
| 63 | + |
| 64 | + assert TokenTracker.get_active() is None |
| 65 | + |
| 66 | + def test_thread_local_isolation(self) -> None: |
| 67 | + """Test that each thread has its own active tracker.""" |
| 68 | + tracker1 = TokenTracker() |
| 69 | + tracker2 = TokenTracker() |
| 70 | + results: dict[str, TokenTracker | None] = {} |
| 71 | + |
| 72 | + def thread_work(name: str, tracker: TokenTracker) -> None: |
| 73 | + tracker.start() |
| 74 | + results[name] = TokenTracker.get_active() |
| 75 | + # Deliberately don't stop to check isolation |
| 76 | + |
| 77 | + # Start tracker1 in main thread |
| 78 | + tracker1.start() |
| 79 | + |
| 80 | + # Start tracker2 in another thread |
| 81 | + thread = threading.Thread(target=thread_work, args=("thread2", tracker2)) |
| 82 | + thread.start() |
| 83 | + thread.join() |
| 84 | + |
| 85 | + # Main thread should still have tracker1 |
| 86 | + assert TokenTracker.get_active() is tracker1 |
| 87 | + # Other thread had tracker2 |
| 88 | + assert results["thread2"] is tracker2 |
| 89 | + |
| 90 | + tracker1.stop() |
| 91 | + |
| 92 | + def test_add_tokens_thread_safe(self) -> None: |
| 93 | + """Test that add_tokens is thread-safe under concurrent access.""" |
| 94 | + tracker = TokenTracker() |
| 95 | + num_threads = 10 |
| 96 | + tokens_per_thread = 100 |
| 97 | + |
| 98 | + def add_tokens_worker() -> None: |
| 99 | + for _ in range(tokens_per_thread): |
| 100 | + tracker.add_tokens(1, 2) |
| 101 | + |
| 102 | + threads = [ |
| 103 | + threading.Thread(target=add_tokens_worker) for _ in range(num_threads) |
| 104 | + ] |
| 105 | + for t in threads: |
| 106 | + t.start() |
| 107 | + for t in threads: |
| 108 | + t.join() |
| 109 | + |
| 110 | + input_tokens, output_tokens = tracker.get_counts() |
| 111 | + assert input_tokens == num_threads * tokens_per_thread |
| 112 | + assert output_tokens == num_threads * tokens_per_thread * 2 |
30 | 113 |
|
31 | 114 |
|
32 | 115 | class TestBaseCustomLLM: |
@@ -103,3 +186,63 @@ def test_call_raises_llm_error_on_failure(self, mocker: MockerFixture) -> None: |
103 | 186 |
|
104 | 187 | with pytest.raises(LLMError, match="LLM call failed"): |
105 | 188 | llm.call("test prompt") |
| 189 | + |
| 190 | + def test_call_captures_tokens_with_active_tracker( |
| 191 | + self, mocker: MockerFixture |
| 192 | + ) -> None: |
| 193 | + """Test call captures tokens when a TokenTracker is active.""" |
| 194 | + mock_litellm = mocker.patch("lightspeed_evaluation.core.llm.custom.litellm") |
| 195 | + mocker.patch.dict("os.environ", {}) |
| 196 | + |
| 197 | + # Mock response with usage |
| 198 | + mock_choice = mocker.Mock() |
| 199 | + mock_choice.message.content = "Test response" |
| 200 | + mock_response = mocker.Mock() |
| 201 | + mock_response.choices = [mock_choice] |
| 202 | + mock_response.usage = mocker.Mock() |
| 203 | + mock_response.usage.prompt_tokens = 50 |
| 204 | + mock_response.usage.completion_tokens = 100 |
| 205 | + mock_litellm.completion.return_value = mock_response |
| 206 | + |
| 207 | + # Start a tracker |
| 208 | + tracker = TokenTracker() |
| 209 | + tracker.start() |
| 210 | + |
| 211 | + try: |
| 212 | + llm = BaseCustomLLM("gpt-4", {"temperature": 0.0}) |
| 213 | + llm.call("test prompt") |
| 214 | + |
| 215 | + # Tokens should be captured |
| 216 | + input_tokens, output_tokens = tracker.get_counts() |
| 217 | + assert input_tokens == 50 |
| 218 | + assert output_tokens == 100 |
| 219 | + finally: |
| 220 | + tracker.stop() |
| 221 | + |
| 222 | + def test_call_does_not_capture_tokens_without_active_tracker( |
| 223 | + self, mocker: MockerFixture |
| 224 | + ) -> None: |
| 225 | + """Test call does not fail when no TokenTracker is active.""" |
| 226 | + mock_litellm = mocker.patch("lightspeed_evaluation.core.llm.custom.litellm") |
| 227 | + mocker.patch.dict("os.environ", {}) |
| 228 | + |
| 229 | + # Mock response with usage |
| 230 | + mock_choice = mocker.Mock() |
| 231 | + mock_choice.message.content = "Test response" |
| 232 | + mock_response = mocker.Mock() |
| 233 | + mock_response.choices = [mock_choice] |
| 234 | + mock_response.usage = mocker.Mock() |
| 235 | + mock_response.usage.prompt_tokens = 50 |
| 236 | + mock_response.usage.completion_tokens = 100 |
| 237 | + mock_litellm.completion.return_value = mock_response |
| 238 | + |
| 239 | + # Ensure no tracker is active |
| 240 | + temp = TokenTracker() |
| 241 | + temp.start() |
| 242 | + temp.stop() |
| 243 | + |
| 244 | + llm = BaseCustomLLM("gpt-4", {"temperature": 0.0}) |
| 245 | + result = llm.call("test prompt") |
| 246 | + |
| 247 | + # Should succeed without error |
| 248 | + assert result == "Test response" |
0 commit comments