Skip to content

Commit cdadaad

Browse files
chore: change return type of on_conversation_end of UsageTrackingCallback from dict to new `UsageSummary` type
1 parent d90a3ac commit cdadaad

3 files changed

Lines changed: 106 additions & 58 deletions

File tree

src/askui/models/shared/usage_tracking_callback.py

Lines changed: 56 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -5,18 +5,48 @@
55
from typing import TYPE_CHECKING
66

77
from opentelemetry import trace
8+
from pydantic import BaseModel
89
from typing_extensions import override
910

1011
from askui.models.shared.agent_message_param import UsageParam
1112
from askui.models.shared.conversation_callback import ConversationCallback
12-
from askui.reporting import NULL_REPORTER, Reporter
13+
from askui.reporting import NULL_REPORTER
1314

1415
if TYPE_CHECKING:
1516
from askui.models.shared.conversation import Conversation
17+
from askui.reporting import Reporter
1618
from askui.speaker.speaker import SpeakerResult
1719
from askui.utils.model_pricing import ModelPricing
1820

1921

22+
class UsageSummary(BaseModel):
23+
"""Accumulated token usage and optional cost breakdown for a conversation.
24+
25+
Args:
26+
input_tokens (int | None): Total input tokens sent to the API.
27+
output_tokens (int | None): Total output tokens generated.
28+
cache_creation_input_tokens (int | None): Tokens used for cache creation.
29+
cache_read_input_tokens (int | None): Tokens read from cache.
30+
input_cost (float | None): Computed input cost in `currency`.
31+
output_cost (float | None): Computed output cost in `currency`.
32+
total_cost (float | None): Sum of `input_cost` and `output_cost`.
33+
currency (str | None): ISO 4217 currency code (e.g. ``"USD"``).
34+
input_cost_per_million_tokens (float | None): Rate used to compute `input_cost`.
35+
output_cost_per_million_tokens (float|None): Rate used to compute `output_cost`.
36+
"""
37+
38+
input_tokens: int | None = None
39+
output_tokens: int | None = None
40+
cache_creation_input_tokens: int | None = None
41+
cache_read_input_tokens: int | None = None
42+
input_cost: float | None = None
43+
output_cost: float | None = None
44+
total_cost: float | None = None
45+
currency: str | None = None
46+
input_cost_per_million_tokens: float | None = None
47+
output_cost_per_million_tokens: float | None = None
48+
49+
2050
class UsageTrackingCallback(ConversationCallback):
2151
"""Tracks token usage per step and reports a summary at conversation end.
2252
@@ -51,27 +81,40 @@ def on_step_end(
5181

5282
@override
5383
def on_conversation_end(self, conversation: Conversation) -> None:
54-
usage_dict = self._accumulated_usage.model_dump()
84+
input_cost: float | None = None
85+
output_cost: float | None = None
86+
total_cost: float | None = None
87+
currency: str | None = None
88+
input_cost_per_million_tokens: float | None = None
89+
output_cost_per_million_tokens: float | None = None
5590
if self._pricing is not None:
5691
input_tokens = self._accumulated_usage.input_tokens or 0
5792
output_tokens = self._accumulated_usage.output_tokens or 0
5893
input_cost = (
59-
input_tokens * self._pricing.input_cost_per_million_tokens / 1e7
94+
input_tokens * self._pricing.input_cost_per_million_tokens / 1e6
6095
)
6196
output_cost = (
62-
output_tokens * self._pricing.output_cost_per_million_tokens / 1e7
97+
output_tokens * self._pricing.output_cost_per_million_tokens / 1e6
6398
)
64-
usage_dict["input_cost"] = input_cost
65-
usage_dict["output_cost"] = output_cost
66-
usage_dict["total_cost"] = input_cost + output_cost
67-
usage_dict["currency"] = self._pricing.currency
68-
usage_dict["input_cost_per_million_tokens"] = (
69-
self._pricing.input_cost_per_million_tokens
70-
)
71-
usage_dict["output_cost_per_million_tokens"] = (
99+
total_cost = input_cost + output_cost
100+
currency = self._pricing.currency
101+
input_cost_per_million_tokens = self._pricing.input_cost_per_million_tokens
102+
output_cost_per_million_tokens = (
72103
self._pricing.output_cost_per_million_tokens
73104
)
74-
self._reporter.add_usage_summary(usage_dict)
105+
summary = UsageSummary(
106+
input_tokens=self._accumulated_usage.input_tokens,
107+
output_tokens=self._accumulated_usage.output_tokens,
108+
cache_creation_input_tokens=self._accumulated_usage.cache_creation_input_tokens,
109+
cache_read_input_tokens=self._accumulated_usage.cache_read_input_tokens,
110+
input_cost=input_cost,
111+
output_cost=output_cost,
112+
total_cost=total_cost,
113+
currency=currency,
114+
input_cost_per_million_tokens=input_cost_per_million_tokens,
115+
output_cost_per_million_tokens=output_cost_per_million_tokens,
116+
)
117+
self._reporter.add_usage_summary(summary)
75118

76119
@property
77120
def accumulated_usage(self) -> UsageParam:

src/askui/reporting.py

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import base64
24
import io
35
import json
@@ -9,14 +11,18 @@
911
from importlib.metadata import distributions
1012
from io import BytesIO
1113
from pathlib import Path
12-
from typing import Any, Optional, Union
14+
from typing import TYPE_CHECKING, Any, Optional, Union
1315

1416
from jinja2 import Template
15-
from PIL import Image
1617
from typing_extensions import TypedDict, override
1718

1819
from askui.utils.annotated_image import AnnotatedImage
1920

21+
if TYPE_CHECKING:
22+
from PIL import Image
23+
24+
from askui.models.shared.usage_tracking_callback import UsageSummary
25+
2026

2127
def normalize_to_pil_images(
2228
image: Image.Image | list[Image.Image] | AnnotatedImage | None,
@@ -80,15 +86,14 @@ def add_message(
8086
raise NotImplementedError
8187

8288
@abstractmethod
83-
def add_usage_summary(self, usage: dict[str, int | None]) -> None:
89+
def add_usage_summary(self, usage: UsageSummary) -> None:
8490
"""Add usage statistics summary to the report.
8591
86-
Called at the end of an act() execution with accumulated token usage.
92+
Called at the end of an ``act()`` execution with accumulated token
93+
usage and optional cost breakdown.
8794
8895
Args:
89-
usage (dict[str, int | None]): Accumulated usage statistics containing:
90-
- input_tokens: Total input tokens sent to API
91-
- output_tokens: Total output tokens generated
96+
usage (UsageSummary): Accumulated usage statistics.
9297
"""
9398
raise NotImplementedError
9499

@@ -134,7 +139,7 @@ def add_message(
134139
pass
135140

136141
@override
137-
def add_usage_summary(self, usage: dict[str, int | None]) -> None:
142+
def add_usage_summary(self, usage: UsageSummary) -> None:
138143
pass
139144

140145
@override
@@ -177,7 +182,7 @@ def add_message(
177182
reporter.add_message(role, content, image)
178183

179184
@override
180-
def add_usage_summary(self, usage: dict[str, int | None]) -> None:
185+
def add_usage_summary(self, usage: UsageSummary) -> None:
181186
"""Add usage summary to all reporters."""
182187
for reporter in self._reporters:
183188
reporter.add_usage_summary(usage)
@@ -215,7 +220,7 @@ def __init__(self, report_dir: str = "reports") -> None:
215220
self.report_dir = Path(report_dir)
216221
self.messages: list[dict[str, Any]] = []
217222
self.system_info = self._collect_system_info()
218-
self.usage_summary: dict[str, int | None] | None = None
223+
self.usage_summary: UsageSummary | None = None
219224
self.cache_original_usage: dict[str, int | None] | None = None
220225
self._start_time: datetime | None = None
221226

@@ -264,7 +269,7 @@ def add_message(
264269
self.messages.append(message)
265270

266271
@override
267-
def add_usage_summary(self, usage: dict[str, int | None]) -> None:
272+
def add_usage_summary(self, usage: UsageSummary) -> None:
268273
"""Store usage summary for inclusion in the report."""
269274
self.usage_summary = usage
270275

@@ -790,14 +795,14 @@ def generate(self) -> None:
790795
</tr>
791796
{% endif %}
792797
{% if usage_summary is not none %}
793-
{% if usage_summary.get('input_tokens') is not none %}
798+
{% if usage_summary.input_tokens is not none %}
794799
<tr>
795800
<th>Input Tokens</th>
796801
<td>
797-
{{ "{:,}".format(usage_summary.get('input_tokens')) }}
802+
{{ "{:,}".format(usage_summary.input_tokens) }}
798803
{% if cache_original_usage and cache_original_usage.get('input_tokens') %}
799804
{% set original = cache_original_usage.get('input_tokens') %}
800-
{% set current = usage_summary.get('input_tokens') %}
805+
{% set current = usage_summary.input_tokens %}
801806
{% set saved = original - current %}
802807
{% if saved > 0 and original > 0 %}
803808
{% set savings_pct = (saved / original * 100) %}
@@ -807,14 +812,14 @@ def generate(self) -> None:
807812
</td>
808813
</tr>
809814
{% endif %}
810-
{% if usage_summary.get('output_tokens') is not none %}
815+
{% if usage_summary.output_tokens is not none %}
811816
<tr>
812817
<th>Output Tokens</th>
813818
<td>
814-
{{ "{:,}".format(usage_summary.get('output_tokens')) }}
819+
{{ "{:,}".format(usage_summary.output_tokens) }}
815820
{% if cache_original_usage and cache_original_usage.get('output_tokens') %}
816821
{% set original = cache_original_usage.get('output_tokens') %}
817-
{% set current = usage_summary.get('output_tokens') %}
822+
{% set current = usage_summary.output_tokens %}
818823
{% set saved = original - current %}
819824
{% if saved > 0 and original > 0 %}
820825
{% set savings_pct = (saved / original * 100) %}
@@ -824,14 +829,14 @@ def generate(self) -> None:
824829
</td>
825830
</tr>
826831
{% endif %}
827-
{% if usage_summary.get('total_cost') is not none %}
832+
{% if usage_summary.total_cost is not none %}
828833
<tr>
829834
<th>Estimated Cost <span style="font-weight:normal;color:var(--text-muted);">(actual cost may differ)</span></th>
830835
<td>
831-
{{ "%.2f"|format(usage_summary.get('total_cost')) }} {{ usage_summary.get('currency', 'USD') }}
836+
{{ "%.2f"|format(usage_summary.total_cost) }} {{ usage_summary.currency or 'USD' }}
832837
<span style="color: var(--text-muted); margin-left: 8px; font-size: 0.85em;">
833-
(Input: ${{ "%.2f"|format(usage_summary.get('input_cost_per_million_tokens', 0)) }}/1M tokens,
834-
Output: ${{ "%.2f"|format(usage_summary.get('output_cost_per_million_tokens', 0)) }}/1M tokens)
838+
(Input: ${{ "%.2f"|format(usage_summary.input_cost_per_million_tokens or 0) }}/1M tokens,
839+
Output: ${{ "%.2f"|format(usage_summary.output_cost_per_million_tokens or 0) }}/1M tokens)
835840
</span>
836841
</td>
837842
</tr>
@@ -992,7 +997,7 @@ def add_message(
992997
)
993998

994999
@override
995-
def add_usage_summary(self, usage: dict[str, int | None]) -> None:
1000+
def add_usage_summary(self, usage: UsageSummary) -> None:
9961001
"""No-op for AllureReporter - usage is not tracked."""
9971002

9981003
@override

tests/unit/model_providers/test_model_pricing.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
"""Unit tests for model pricing resolution and cost calculation."""
22

3-
from typing import Any
43
from unittest.mock import MagicMock
54

65
import pytest
76

87
from askui.models.shared.agent_message_param import UsageParam
9-
from askui.models.shared.usage_tracking_callback import UsageTrackingCallback
8+
from askui.models.shared.usage_tracking_callback import (
9+
UsageSummary,
10+
UsageTrackingCallback,
11+
)
1012
from askui.utils.model_pricing import ModelPricing
1113

1214

@@ -58,7 +60,7 @@ def test_override_costs_unknown_model(self) -> None:
5860
assert pricing.input_cost_per_million_tokens == 1.0
5961

6062

61-
def _get_usage_dict(reporter_mock: MagicMock) -> dict[str, Any]:
63+
def _get_usage_summary(reporter_mock: MagicMock) -> UsageSummary:
6264
return reporter_mock.add_usage_summary.call_args[0][0] # type: ignore[no-any-return]
6365

6466

@@ -82,13 +84,13 @@ def test_cost_included_when_pricing_set(self) -> None:
8284
)
8385
callback.on_conversation_end(MagicMock())
8486

85-
usage_dict = _get_usage_dict(reporter)
86-
assert usage_dict["total_cost"] == pytest.approx(4.5)
87-
assert usage_dict["input_cost"] == pytest.approx(3.0)
88-
assert usage_dict["output_cost"] == pytest.approx(1.5)
89-
assert usage_dict["currency"] == "USD"
90-
assert usage_dict["input_cost_per_million_tokens"] == 3.0
91-
assert usage_dict["output_cost_per_million_tokens"] == 15.0
87+
summary = _get_usage_summary(reporter)
88+
assert summary.total_cost == pytest.approx(4.5)
89+
assert summary.input_cost == pytest.approx(3.0)
90+
assert summary.output_cost == pytest.approx(1.5)
91+
assert summary.currency == "USD"
92+
assert summary.input_cost_per_million_tokens == 3.0
93+
assert summary.output_cost_per_million_tokens == 15.0
9294

9395
def test_no_cost_when_pricing_none(self) -> None:
9496
callback, reporter = self._make_callback(pricing=None)
@@ -98,9 +100,9 @@ def test_no_cost_when_pricing_none(self) -> None:
98100
)
99101
callback.on_conversation_end(MagicMock())
100102

101-
usage_dict = _get_usage_dict(reporter)
102-
assert "total_cost" not in usage_dict
103-
assert "currency" not in usage_dict
103+
summary = _get_usage_summary(reporter)
104+
assert summary.total_cost is None
105+
assert summary.currency is None
104106

105107
def test_zero_tokens_produce_zero_cost(self) -> None:
106108
pricing = ModelPricing(
@@ -114,8 +116,8 @@ def test_zero_tokens_produce_zero_cost(self) -> None:
114116
)
115117
callback.on_conversation_end(MagicMock())
116118

117-
usage_dict = _get_usage_dict(reporter)
118-
assert usage_dict["total_cost"] == 0.0
119+
summary = _get_usage_summary(reporter)
120+
assert summary.total_cost == 0.0
119121

120122
def test_none_tokens_treated_as_zero(self) -> None:
121123
pricing = ModelPricing(
@@ -126,8 +128,8 @@ def test_none_tokens_treated_as_zero(self) -> None:
126128
callback._accumulated_usage = UsageParam()
127129
callback.on_conversation_end(MagicMock())
128130

129-
usage_dict = _get_usage_dict(reporter)
130-
assert usage_dict["total_cost"] == 0.0
131+
summary = _get_usage_summary(reporter)
132+
assert summary.total_cost == 0.0
131133

132134
def test_cost_calculation_accuracy(self) -> None:
133135
pricing = ModelPricing(
@@ -141,11 +143,9 @@ def test_cost_calculation_accuracy(self) -> None:
141143
)
142144
callback.on_conversation_end(MagicMock())
143145

144-
usage_dict = _get_usage_dict(reporter)
146+
summary = _get_usage_summary(reporter)
145147
expected_input = 50_000 * 15.0 / 1_000_000
146148
expected_output = 10_000 * 75.0 / 1_000_000
147-
assert usage_dict["input_cost"] == pytest.approx(expected_input)
148-
assert usage_dict["output_cost"] == pytest.approx(expected_output)
149-
assert usage_dict["total_cost"] == pytest.approx(
150-
expected_input + expected_output
151-
)
149+
assert summary.input_cost == pytest.approx(expected_input)
150+
assert summary.output_cost == pytest.approx(expected_output)
151+
assert summary.total_cost == pytest.approx(expected_input + expected_output)

0 commit comments

Comments
 (0)