1+ from __future__ import annotations
2+
13import base64
24import io
35import json
911from importlib .metadata import distributions
1012from io import BytesIO
1113from pathlib import Path
12- from typing import Any , Optional , Union
14+ from typing import TYPE_CHECKING , Any , Optional , Union
1315
1416from jinja2 import Template
15- from PIL import Image
1617from typing_extensions import TypedDict , override
1718
1819from askui .utils .annotated_image import AnnotatedImage
1920
21+ if TYPE_CHECKING :
22+ from PIL import Image
23+
24+ from askui .models .shared .usage_tracking_callback import UsageSummary
25+
2026
2127def normalize_to_pil_images (
2228 image : Image .Image | list [Image .Image ] | AnnotatedImage | None ,
@@ -80,15 +86,14 @@ def add_message(
8086 raise NotImplementedError
8187
8288 @abstractmethod
83- def add_usage_summary (self , usage : dict [ str , int | None ] ) -> None :
89+ def add_usage_summary (self , usage : UsageSummary ) -> None :
8490 """Add usage statistics summary to the report.
8591
86- Called at the end of an act() execution with accumulated token usage.
92+ Called at the end of an ``act()`` execution with accumulated token
93+ usage and optional cost breakdown.
8794
8895 Args:
89- usage (dict[str, int | None]): Accumulated usage statistics containing:
90- - input_tokens: Total input tokens sent to API
91- - output_tokens: Total output tokens generated
96+ usage (UsageSummary): Accumulated usage statistics.
9297 """
9398 raise NotImplementedError
9499
@@ -134,7 +139,7 @@ def add_message(
134139 pass
135140
136141 @override
137- def add_usage_summary (self , usage : dict [ str , int | None ] ) -> None :
142+ def add_usage_summary (self , usage : UsageSummary ) -> None :
138143 pass
139144
140145 @override
@@ -177,7 +182,7 @@ def add_message(
177182 reporter .add_message (role , content , image )
178183
179184 @override
180- def add_usage_summary (self , usage : dict [ str , int | None ] ) -> None :
185+ def add_usage_summary (self , usage : UsageSummary ) -> None :
181186 """Add usage summary to all reporters."""
182187 for reporter in self ._reporters :
183188 reporter .add_usage_summary (usage )
@@ -215,7 +220,7 @@ def __init__(self, report_dir: str = "reports") -> None:
215220 self .report_dir = Path (report_dir )
216221 self .messages : list [dict [str , Any ]] = []
217222 self .system_info = self ._collect_system_info ()
218- self .usage_summary : dict [ str , int | None ] | None = None
223+ self .usage_summary : UsageSummary | None = None
219224 self .cache_original_usage : dict [str , int | None ] | None = None
220225 self ._start_time : datetime | None = None
221226
@@ -264,7 +269,7 @@ def add_message(
264269 self .messages .append (message )
265270
266271 @override
267- def add_usage_summary (self , usage : dict [ str , int | None ] ) -> None :
272+ def add_usage_summary (self , usage : UsageSummary ) -> None :
268273 """Store usage summary for inclusion in the report."""
269274 self .usage_summary = usage
270275
@@ -790,14 +795,14 @@ def generate(self) -> None:
790795 </tr>
791796 {% endif %}
792797 {% if usage_summary is not none %}
793- {% if usage_summary.get(' input_tokens') is not none %}
798+ {% if usage_summary.input_tokens is not none %}
794799 <tr>
795800 <th>Input Tokens</th>
796801 <td>
797- {{ "{:,}".format(usage_summary.get(' input_tokens') ) }}
802+ {{ "{:,}".format(usage_summary.input_tokens) }}
798803 {% if cache_original_usage and cache_original_usage.get('input_tokens') %}
799804 {% set original = cache_original_usage.get('input_tokens') %}
800- {% set current = usage_summary.get(' input_tokens') %}
805+ {% set current = usage_summary.input_tokens %}
801806 {% set saved = original - current %}
802807 {% if saved > 0 and original > 0 %}
803808 {% set savings_pct = (saved / original * 100) %}
@@ -807,14 +812,14 @@ def generate(self) -> None:
807812 </td>
808813 </tr>
809814 {% endif %}
810- {% if usage_summary.get(' output_tokens') is not none %}
815+ {% if usage_summary.output_tokens is not none %}
811816 <tr>
812817 <th>Output Tokens</th>
813818 <td>
814- {{ "{:,}".format(usage_summary.get(' output_tokens') ) }}
819+ {{ "{:,}".format(usage_summary.output_tokens) }}
815820 {% if cache_original_usage and cache_original_usage.get('output_tokens') %}
816821 {% set original = cache_original_usage.get('output_tokens') %}
817- {% set current = usage_summary.get(' output_tokens') %}
822+ {% set current = usage_summary.output_tokens %}
818823 {% set saved = original - current %}
819824 {% if saved > 0 and original > 0 %}
820825 {% set savings_pct = (saved / original * 100) %}
@@ -824,14 +829,14 @@ def generate(self) -> None:
824829 </td>
825830 </tr>
826831 {% endif %}
827- {% if usage_summary.get(' total_cost') is not none %}
832+ {% if usage_summary.total_cost is not none %}
828833 <tr>
829834 <th>Estimated Cost <span style="font-weight:normal;color:var(--text-muted);">(actual cost may differ)</span></th>
830835 <td>
831- {{ "%.2f"|format(usage_summary.get(' total_cost')) }} {{ usage_summary.get(' currency', 'USD') }}
836+ {{ "%.2f"|format(usage_summary.total_cost) }} {{ usage_summary.currency or 'USD' }}
832837 <span style="color: var(--text-muted); margin-left: 8px; font-size: 0.85em;">
833- (Input: ${{ "%.2f"|format(usage_summary.get(' input_cost_per_million_tokens', 0) ) }}/1M tokens,
834- Output: ${{ "%.2f"|format(usage_summary.get(' output_cost_per_million_tokens', 0) ) }}/1M tokens)
838+ (Input: ${{ "%.2f"|format(usage_summary.input_cost_per_million_tokens or 0 ) }}/1M tokens,
839+ Output: ${{ "%.2f"|format(usage_summary.output_cost_per_million_tokens or 0 ) }}/1M tokens)
835840 </span>
836841 </td>
837842 </tr>
@@ -992,7 +997,7 @@ def add_message(
992997 )
993998
994999 @override
995- def add_usage_summary (self , usage : dict [ str , int | None ] ) -> None :
1000+ def add_usage_summary (self , usage : UsageSummary ) -> None :
9961001 """No-op for AllureReporter - usage is not tracked."""
9971002
9981003 @override
0 commit comments