Skip to content

Commit 797dd65

Browse files
committed
Create new AIStatistics object with test cases
1 parent dac9412 commit 797dd65

2 files changed

Lines changed: 290 additions & 0 deletions

File tree

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import copy
2+
3+
4+
class AIStatistics:
5+
def __init__(self):
6+
self.calls = {}
7+
8+
def ensure_provider_stats(self, provider, model):
9+
key = get_provider_key(provider, model)
10+
11+
if key not in self.calls:
12+
self.calls[key] = {
13+
"provider": provider,
14+
"model": model,
15+
"calls": 0,
16+
"tokens": {
17+
"input": 0,
18+
"output": 0,
19+
"total": 0,
20+
},
21+
}
22+
23+
return self.calls[key]
24+
25+
def on_ai_call(self, provider, model, input_tokens, output_tokens):
26+
if not provider or not model:
27+
return
28+
29+
provider_stats = self.ensure_provider_stats(provider, model)
30+
provider_stats["calls"] += 1
31+
provider_stats["tokens"]["input"] += input_tokens
32+
provider_stats["tokens"]["output"] += output_tokens
33+
provider_stats["tokens"]["total"] += input_tokens + output_tokens
34+
35+
def get_stats(self):
36+
return [copy.deepcopy(stats) for stats in self.calls.values()]
37+
38+
def clear(self):
39+
self.calls.clear()
40+
41+
def is_empty(self):
42+
return len(self.calls) == 0
43+
44+
45+
def get_provider_key(provider, model):
46+
return f"{provider}:{model}"
Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
import pytest
2+
from .ai_statistics import AIStatistics
3+
4+
5+
@pytest.fixture
6+
def stats():
7+
return AIStatistics()
8+
9+
10+
def test_initializes_with_empty_state(stats):
11+
assert stats.get_stats() == []
12+
assert stats.is_empty() is True
13+
14+
15+
def test_tracks_basic_ai_calls(stats):
16+
stats.on_ai_call(
17+
provider="openai", model="gpt-4", input_tokens=100, output_tokens=50
18+
)
19+
20+
result = stats.get_stats()
21+
assert len(result) == 1
22+
assert result[0] == {
23+
"provider": "openai",
24+
"model": "gpt-4",
25+
"calls": 1,
26+
"tokens": {
27+
"input": 100,
28+
"output": 50,
29+
"total": 150,
30+
},
31+
}
32+
33+
assert stats.is_empty() is False
34+
35+
36+
def test_tracks_multiple_calls_to_same_provider_model(stats):
37+
stats.on_ai_call(
38+
provider="openai", model="gpt-4", input_tokens=100, output_tokens=50
39+
)
40+
41+
stats.on_ai_call(
42+
provider="openai", model="gpt-4", input_tokens=200, output_tokens=75
43+
)
44+
45+
result = stats.get_stats()
46+
assert len(result) == 1
47+
assert result[0] == {
48+
"provider": "openai",
49+
"model": "gpt-4",
50+
"calls": 2,
51+
"tokens": {
52+
"input": 300,
53+
"output": 125,
54+
"total": 425,
55+
},
56+
}
57+
58+
59+
def test_tracks_different_provider_model_combinations_separately(stats):
60+
stats.on_ai_call(
61+
provider="openai", model="gpt-4", input_tokens=100, output_tokens=50
62+
)
63+
64+
stats.on_ai_call(
65+
provider="openai", model="gpt-3.5-turbo", input_tokens=80, output_tokens=40
66+
)
67+
68+
stats.on_ai_call(
69+
provider="anthropic", model="claude-3", input_tokens=120, output_tokens=60
70+
)
71+
72+
result = stats.get_stats()
73+
assert len(result) == 3
74+
75+
# Sort by provider:model for consistent testing
76+
result.sort(key=lambda x: f"{x['provider']}:{x['model']}")
77+
78+
assert result[0] == {
79+
"provider": "anthropic",
80+
"model": "claude-3",
81+
"calls": 1,
82+
"tokens": {
83+
"input": 120,
84+
"output": 60,
85+
"total": 180,
86+
},
87+
}
88+
89+
assert result[1] == {
90+
"provider": "openai",
91+
"model": "gpt-3.5-turbo",
92+
"calls": 1,
93+
"tokens": {
94+
"input": 80,
95+
"output": 40,
96+
"total": 120,
97+
},
98+
}
99+
100+
assert result[2] == {
101+
"provider": "openai",
102+
"model": "gpt-4",
103+
"calls": 1,
104+
"tokens": {
105+
"input": 100,
106+
"output": 50,
107+
"total": 150,
108+
},
109+
}
110+
111+
112+
def test_resets_all_statistics(stats):
113+
stats.on_ai_call(
114+
provider="openai", model="gpt-4", input_tokens=100, output_tokens=50
115+
)
116+
117+
stats.on_ai_call(
118+
provider="anthropic", model="claude-3", input_tokens=120, output_tokens=60
119+
)
120+
121+
assert stats.is_empty() is False
122+
assert len(stats.get_stats()) == 2
123+
124+
stats.clear()
125+
126+
assert stats.is_empty() is True
127+
assert stats.get_stats() == []
128+
129+
130+
def test_handles_zero_token_inputs(stats):
131+
stats.on_ai_call(provider="openai", model="gpt-4", input_tokens=0, output_tokens=0)
132+
133+
result = stats.get_stats()
134+
assert len(result) == 1
135+
assert result[0]["tokens"] == {
136+
"input": 0,
137+
"output": 0,
138+
"total": 0,
139+
}
140+
141+
142+
def test_called_with_empty_provider(stats):
143+
stats.on_ai_call(provider="", model="gpt-4", input_tokens=100, output_tokens=50)
144+
145+
assert stats.is_empty() is True
146+
147+
148+
def test_called_with_empty_model(stats):
149+
stats.on_ai_call(provider="openai", model="", input_tokens=100, output_tokens=50)
150+
151+
assert stats.is_empty() is True
152+
153+
154+
def test_get_stats_returns_immutable_data(stats):
155+
stats.on_ai_call(
156+
provider="openai", model="gpt-4", input_tokens=100, output_tokens=50
157+
)
158+
159+
result = stats.get_stats()
160+
result[0]["calls"] = 100
161+
result[0]["tokens"]["input"] = 1000
162+
163+
# Verify that the internal state has not changed
164+
assert stats.get_stats()[0]["calls"] == 1
165+
166+
167+
def test_get_stats_returns_new_list(stats):
168+
stats.on_ai_call(
169+
provider="openai", model="gpt-4", input_tokens=100, output_tokens=50
170+
)
171+
172+
result1 = stats.get_stats()
173+
result2 = stats.get_stats()
174+
175+
# Modify the first result to ensure it doesn't affect the second result
176+
result1[0]["calls"] = 200
177+
178+
# Verify that the second result is unchanged
179+
assert result2 == [
180+
{
181+
"provider": "openai",
182+
"model": "gpt-4",
183+
"calls": 1,
184+
"tokens": {
185+
"input": 100,
186+
"output": 50,
187+
"total": 150,
188+
},
189+
}
190+
]
191+
192+
193+
def test_get_stats_returns_deep_copy(stats):
194+
stats.on_ai_call(
195+
provider="openai", model="gpt-4", input_tokens=100, output_tokens=50
196+
)
197+
198+
result = stats.get_stats()
199+
200+
# Modify the result deeply to ensure it doesn't affect the internal state
201+
result[0]["tokens"]["input"] = 200
202+
203+
# Verify that the internal state has not changed
204+
assert stats.get_stats()[0]["tokens"]["input"] == 100
205+
206+
207+
def test_get_stats_consistency_after_multiple_calls(stats):
208+
stats.on_ai_call(
209+
provider="openai", model="gpt-4", input_tokens=100, output_tokens=50
210+
)
211+
212+
stats.on_ai_call(
213+
provider="anthropic", model="claude-3", input_tokens=120, output_tokens=60
214+
)
215+
216+
result1 = stats.get_stats()
217+
result2 = stats.get_stats()
218+
219+
# Modify the first result to ensure it doesn't affect the second result
220+
result1[0]["calls"] = 300
221+
222+
# Verify that the second result is unchanged
223+
assert result2 == [
224+
{
225+
"provider": "openai",
226+
"model": "gpt-4",
227+
"calls": 1,
228+
"tokens": {
229+
"input": 100,
230+
"output": 50,
231+
"total": 150,
232+
},
233+
},
234+
{
235+
"provider": "anthropic",
236+
"model": "claude-3",
237+
"calls": 1,
238+
"tokens": {
239+
"input": 120,
240+
"output": 60,
241+
"total": 180,
242+
},
243+
},
244+
]

0 commit comments

Comments
 (0)