Skip to content

Commit 9a29f5a

Browse files
committed
test for litellm
1 parent f915964 commit 9a29f5a

1 file changed

Lines changed: 229 additions & 0 deletions

File tree

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
"""Unit tests for LiteLLM instrumentation."""
2+
3+
import unittest
4+
from unittest.mock import Mock, patch, MagicMock
5+
import sys
6+
7+
# Mock litellm before importing our instrumentation
8+
sys.modules["litellm"] = MagicMock()
9+
10+
from agentops.instrumentation.providers.litellm import LiteLLMInstrumentor
11+
from agentops.instrumentation.providers.litellm.callback_handler import AgentOpsLiteLLMCallback
12+
from agentops.instrumentation.providers.litellm.utils import (
13+
detect_provider_from_model,
14+
extract_model_info,
15+
is_streaming_response,
16+
parse_litellm_error,
17+
)
18+
from agentops.instrumentation.providers.litellm.stream_wrapper import StreamWrapper, ChunkAggregator
19+
20+
21+
class TestLiteLLMUtils(unittest.TestCase):
    """Checks for the standalone helper functions of the LiteLLM utils module."""

    def test_detect_provider_from_model(self):
        """Every listed model identifier must resolve to its expected provider."""
        expected_by_model = {
            "gpt-4": "openai",
            "gpt-3.5-turbo": "openai",
            "claude-3-opus-20240229": "anthropic",
            "claude-2.1": "anthropic",
            "command-nightly": "cohere",
            "gemini-pro": "vertex_ai",
            "llama-2-70b": "unknown",
            "azure/gpt-4": "azure",
            "bedrock/anthropic.claude-v2": "bedrock",
            "unknown-model": "unknown",
        }

        # One sub-test per model so a single mismatch does not hide the others.
        for model_name, provider in expected_by_model.items():
            with self.subTest(model=model_name):
                self.assertEqual(detect_provider_from_model(model_name), provider)

    def test_extract_model_info(self):
        """The family/version/size fields must be split out of the model name."""
        gpt_info = extract_model_info("gpt-4-turbo-32k")
        for field, value in (("family", "gpt-4"), ("version", "turbo"), ("size", "32k")):
            self.assertEqual(gpt_info[field], value)

        claude_info = extract_model_info("claude-3-opus")
        for field, value in (("family", "claude-3"), ("version", "opus")):
            self.assertEqual(claude_info[field], value)

    def test_is_streaming_response(self):
        """An iterator object counts as a stream; a dict and a str do not."""

        class MockStream:
            """Minimal, immediately exhausted iterator standing in for a stream."""

            def __iter__(self):
                return self

            def __next__(self):
                raise StopIteration

        stream_like = MockStream()
        self.assertTrue(is_streaming_response(stream_like))

        for not_a_stream in ({"choices": []}, "not a stream"):
            self.assertFalse(is_streaming_response(not_a_stream))

    def test_parse_litellm_error(self):
        """Parsed errors must expose type, category, status code and provider."""
        # A plain Exception decorated with the attributes the parser is expected to read.
        rate_limit_error = Exception("Rate limit exceeded")
        setattr(rate_limit_error, "status_code", 429)
        setattr(rate_limit_error, "llm_provider", "openai")

        parsed = parse_litellm_error(rate_limit_error)

        expected_fields = (
            ("type", "Exception"),
            ("error_category", "rate_limit"),
            ("status_code", 429),
            ("llm_provider", "openai"),
        )
        for field, value in expected_fields:
            self.assertEqual(parsed[field], value)
82+
83+
84+
class TestChunkAggregator(unittest.TestCase):
    """Checks that streamed chunks are stitched back together correctly."""

    def test_aggregate_content(self):
        """Content deltas must be concatenated in arrival order."""
        aggregator = ChunkAggregator()

        # Each chunk carries a single choice whose delta holds one text piece.
        for piece in ("Hello", " world", "!"):
            delta = Mock(content=piece)
            aggregator.add_chunk(Mock(choices=[Mock(delta=delta)]))

        combined = aggregator.get_aggregated_content()
        self.assertEqual(combined, "Hello world!")

    def test_aggregate_function_calls(self):
        """Function-call argument fragments must be joined into one string."""
        aggregator = ChunkAggregator()

        # The argument JSON arrives split across two chunks.
        for fragment in ('{"location":', ' "San Francisco"}'):
            delta = Mock(function_call=Mock(arguments=fragment))
            aggregator.add_chunk(Mock(choices=[Mock(delta=delta)]))

        combined = aggregator.get_aggregated_function_call()
        self.assertEqual(combined, '{"location": "San Francisco"}')
117+
118+
119+
class TestCallbackHandler(unittest.TestCase):
    """Checks the span attributes recorded by the AgentOps LiteLLM callback."""

    def setUp(self):
        """Build a handler bound to a mocked instrumentor."""
        self.instrumentor = Mock()
        self.handler = AgentOpsLiteLLMCallback(self.instrumentor)

    @patch("agentops.instrumentation.providers.litellm.callback_handler.trace")
    def test_log_pre_api_call(self, mock_trace):
        """Request parameters must be recorded on the current span."""
        # Make the patched trace module hand back a span that reports as recording.
        span = Mock()
        span.is_recording.return_value = True
        mock_trace.get_current_span.return_value = span

        conversation = [
            {"role": "system", "content": "You are helpful"},
            {"role": "user", "content": "Hello"},
        ]
        call_kwargs = {"temperature": 0.7, "max_tokens": 100, "litellm_call_id": "test-123"}

        self.handler.log_pre_api_call("gpt-3.5-turbo", conversation, call_kwargs)

        expected_attributes = (
            ("llm.vendor", "litellm"),
            ("llm.request.model", "gpt-3.5-turbo"),
            ("llm.request.messages_count", 2),
            ("llm.request.temperature", 0.7),
            ("llm.request.max_tokens", 100),
        )
        for key, value in expected_attributes:
            span.set_attribute.assert_any_call(key, value)

    @patch("agentops.instrumentation.providers.litellm.callback_handler.trace")
    def test_log_success_event(self, mock_trace):
        """Response metadata and token usage must be recorded on the current span."""
        span = Mock()
        span.is_recording.return_value = True
        mock_trace.get_current_span.return_value = span

        # Stand-in for a completion response: one choice plus a usage summary.
        only_choice = Mock(message=Mock(content="Hello there!"), finish_reason="stop")
        token_usage = Mock(prompt_tokens=10, completion_tokens=5, total_tokens=15)
        response = Mock(
            id="chatcmpl-123",
            model="gpt-3.5-turbo-0613",
            choices=[only_choice],
            usage=token_usage,
        )

        # Start/end arguments of 1.0 and 2.0; the test expects a recorded duration of 1.0.
        self.handler.log_success_event({"litellm_call_id": "test-123"}, response, 1.0, 2.0)

        expected_attributes = (
            ("llm.response.duration_seconds", 1.0),
            ("llm.response.id", "chatcmpl-123"),
            ("llm.response.choices_count", 1),
            ("llm.usage.prompt_tokens", 10),
            ("llm.usage.completion_tokens", 5),
            ("llm.usage.total_tokens", 15),
        )
        for key, value in expected_attributes:
            span.set_attribute.assert_any_call(key, value)
171+
172+
173+
class TestStreamWrapper(unittest.TestCase):
    """Checks that StreamWrapper passes chunks through while recording them."""

    def test_stream_wrapper_basic(self):
        """Wrapping an iterator must yield every chunk and record first-token timing."""
        expected_chunks = ["chunk1", "chunk2", "chunk3"]
        span = Mock()

        # Wrap a plain iterator over the chunks together with the mocked span.
        wrapper = StreamWrapper(iter(expected_chunks), span)

        # Draining the wrapper must hand back exactly what the source produced.
        received = list(wrapper)
        self.assertEqual(received, expected_chunks)
        self.assertEqual(len(wrapper.chunks), 3)

        # The span must have been given the wrapper's first-chunk timestamp.
        span.set_attribute.assert_any_call(
            "llm.response.time_to_first_token",
            wrapper.first_chunk_time,
        )
196+
197+
198+
class TestLiteLLMInstrumentor(unittest.TestCase):
    """Test the main instrumentor class."""

    def setUp(self):
        """Create a fresh instrumentor for every test."""
        self.instrumentor = LiteLLMInstrumentor()

    @patch("agentops.instrumentation.providers.litellm.instrumentor.logger")
    def test_instrument_not_available(self, mock_logger):
        """instrument() must return a falsy result when LiteLLM is reported unavailable.

        The instrumentor module's logger is patched out for the duration of
        the test; the injected ``mock_logger`` itself is not inspected.
        """
        with patch.object(self.instrumentor, "_check_library_available", return_value=False):
            result = self.instrumentor.instrument()
            self.assertFalse(result)

    # patch.dict overlays the "litellm" entry on the real module table and
    # restores the table afterwards. Replacing ``sys.modules`` wholesale with
    # patch() would swap in a one-entry dict, hiding every other loaded module
    # from any import performed while the test runs.
    @patch.dict("sys.modules", {"litellm": Mock()})
    def test_register_callbacks(self):
        """_register_callbacks() must add "agentops" to every callback list.

        The callback attributes start out as ``None``, so this also checks
        that registration copes with callbacks that were never initialised.
        """
        mock_litellm = Mock()
        mock_litellm.success_callback = None
        mock_litellm.failure_callback = None
        mock_litellm.start_callback = None

        self.instrumentor._register_callbacks(mock_litellm)

        # Verify callbacks were registered
        self.assertIn("agentops", mock_litellm.success_callback)
        self.assertIn("agentops", mock_litellm.failure_callback)
        self.assertIn("agentops", mock_litellm.start_callback)
226+
227+
228+
if __name__ == "__main__":
229+
unittest.main()

0 commit comments

Comments
 (0)