Skip to content

Commit 177960b

Browse files
committed
Fix bug with import_list and add unit tests to test
1 parent 18aa06f commit 177960b

2 files changed

Lines changed: 191 additions & 3 deletions

File tree

aikido_zen/storage/ai_statistics.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,9 @@ def import_list(self, ai_stats_list):
4141
new_entry["provider"], new_entry["model"]
4242
)
4343
existing_entry["calls"] += new_entry["calls"]
44-
existing_entry["tokens"]["input"] = new_entry["tokens"]["input"]
45-
existing_entry["tokens"]["output"] = new_entry["tokens"]["output"]
46-
existing_entry["tokens"]["total"] = new_entry["tokens"]["total"]
44+
existing_entry["tokens"]["input"] += new_entry["tokens"]["input"]
45+
existing_entry["tokens"]["output"] += new_entry["tokens"]["output"]
46+
existing_entry["tokens"]["total"] += new_entry["tokens"]["total"]
4747

4848
def clear(self):
4949
self.calls.clear()

aikido_zen/storage/ai_statistics_test.py

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,3 +242,191 @@ def test_get_stats_consistency_after_multiple_calls(stats):
242242
},
243243
},
244244
]
245+
246+
247+
def test_import_list_with_empty_stats(stats):
248+
# Test importing a list into empty statistics
249+
stats.import_list(
250+
[
251+
{
252+
"provider": "openai",
253+
"model": "gpt-4",
254+
"calls": 1,
255+
"tokens": {
256+
"input": 100,
257+
"output": 50,
258+
"total": 150,
259+
},
260+
},
261+
{
262+
"provider": "anthropic",
263+
"model": "claude-3",
264+
"calls": 1,
265+
"tokens": {
266+
"input": 120,
267+
"output": 60,
268+
"total": 180,
269+
},
270+
},
271+
]
272+
)
273+
274+
result = stats.get_stats()
275+
assert len(result) == 2
276+
assert result[0] == {
277+
"provider": "openai",
278+
"model": "gpt-4",
279+
"calls": 1,
280+
"tokens": {
281+
"input": 100,
282+
"output": 50,
283+
"total": 150,
284+
},
285+
}
286+
assert result[1] == {
287+
"provider": "anthropic",
288+
"model": "claude-3",
289+
"calls": 1,
290+
"tokens": {
291+
"input": 120,
292+
"output": 60,
293+
"total": 180,
294+
},
295+
}
296+
297+
298+
def test_import_list_with_existing_stats(stats):
299+
# Add some initial statistics
300+
stats.on_ai_call(
301+
provider="openai", model="gpt-4", input_tokens=100, output_tokens=50
302+
)
303+
304+
# Import a list that includes an existing provider:model combination
305+
stats.import_list(
306+
[
307+
{
308+
"provider": "openai",
309+
"model": "gpt-4",
310+
"calls": 1,
311+
"tokens": {
312+
"input": 200,
313+
"output": 100,
314+
"total": 300,
315+
},
316+
},
317+
{
318+
"provider": "anthropic",
319+
"model": "claude-3",
320+
"calls": 1,
321+
"tokens": {
322+
"input": 120,
323+
"output": 60,
324+
"total": 180,
325+
},
326+
},
327+
]
328+
)
329+
330+
result = stats.get_stats()
331+
assert len(result) == 2
332+
assert result[0] == {
333+
"provider": "openai",
334+
"model": "gpt-4",
335+
"calls": 2, # Initial call + imported call
336+
"tokens": {
337+
"input": 300, # 100 + 200
338+
"output": 150, # 50 + 100
339+
"total": 450, # 150 + 300
340+
},
341+
}
342+
assert result[1] == {
343+
"provider": "anthropic",
344+
"model": "claude-3",
345+
"calls": 1,
346+
"tokens": {
347+
"input": 120,
348+
"output": 60,
349+
"total": 180,
350+
},
351+
}
352+
353+
354+
def test_import_list_with_overlapping_and_new_entries(stats):
355+
# Add some initial statistics
356+
stats.on_ai_call(
357+
provider="openai", model="gpt-4", input_tokens=100, output_tokens=50
358+
)
359+
stats.on_ai_call(
360+
provider="anthropic", model="claude-3", input_tokens=120, output_tokens=60
361+
)
362+
363+
# Import a list that includes both existing and new provider:model combinations
364+
stats.import_list(
365+
[
366+
{
367+
"provider": "openai",
368+
"model": "gpt-4",
369+
"calls": 1,
370+
"tokens": {
371+
"input": 200,
372+
"output": 100,
373+
"total": 300,
374+
},
375+
},
376+
{
377+
"provider": "anthropic",
378+
"model": "claude-3",
379+
"calls": 1,
380+
"tokens": {
381+
"input": 120,
382+
"output": 60,
383+
"total": 180,
384+
},
385+
},
386+
{
387+
"provider": "mistral",
388+
"model": "mistral-7b",
389+
"calls": 1,
390+
"tokens": {
391+
"input": 150,
392+
"output": 75,
393+
"total": 225,
394+
},
395+
},
396+
]
397+
)
398+
399+
result = stats.get_stats()
400+
assert len(result) == 3
401+
result.sort(key=lambda x: f"{x['provider']}:{x['model']}")
402+
403+
assert result[0] == {
404+
"provider": "anthropic",
405+
"model": "claude-3",
406+
"calls": 2, # Initial call + imported call
407+
"tokens": {
408+
"input": 240, # 120 + 120
409+
"output": 120, # 60 + 60
410+
"total": 360, # 180 + 180
411+
},
412+
}
413+
assert result[1] == {
414+
"provider": "mistral",
415+
"model": "mistral-7b",
416+
"calls": 1,
417+
"tokens": {
418+
"input": 150,
419+
"output": 75,
420+
"total": 225,
421+
},
422+
}
423+
assert result[2] == {
424+
"provider": "openai",
425+
"model": "gpt-4",
426+
"calls": 2, # Initial call + imported call
427+
"tokens": {
428+
"input": 300, # 100 + 200
429+
"output": 150, # 50 + 100
430+
"total": 450, # 150 + 300
431+
},
432+
}

0 commit comments

Comments
 (0)