Skip to content

Commit 6d17375

Browse files
chore: refactor usage tracking to integrate via callback
1 parent 1650ff7 commit 6d17375

3 files changed

Lines changed: 82 additions & 45 deletions

File tree

src/askui/agent_base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
LocateSettings,
2424
)
2525
from askui.models.shared.tools import Tool, ToolCollection
26+
from askui.models.shared.usage_tracking_callback import UsageTrackingCallback
2627
from askui.prompts.act_prompts import CACHE_USE_PROMPT, create_default_prompt
2728
from askui.tools.agent_os import AgentOs
2829
from askui.tools.android.agent_os import AndroidAgentOs
@@ -73,13 +74,15 @@ def __init__(
7374

7475
# Create conversation with speakers and model providers
7576
speakers = Speakers()
77+
_callbacks = list(callbacks or [])
78+
_callbacks.append(UsageTrackingCallback(reporter=self._reporter))
7679
self._conversation = Conversation(
7780
speakers=speakers,
7881
vlm_provider=self._vlm_provider,
7982
image_qa_provider=self._image_qa_provider,
8083
detection_provider=self._detection_provider,
8184
reporter=self._reporter,
82-
callbacks=callbacks,
85+
callbacks=_callbacks,
8386
)
8487

8588
# Provider-based tools

src/askui/models/shared/conversation.py

Lines changed: 1 addition & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,7 @@
99
from askui.model_providers.detection_provider import DetectionProvider
1010
from askui.model_providers.image_qa_provider import ImageQAProvider
1111
from askui.model_providers.vlm_provider import VlmProvider
12-
from askui.models.shared.agent_message_param import (
13-
MessageParam,
14-
UsageParam,
15-
)
12+
from askui.models.shared.agent_message_param import MessageParam
1613
from askui.models.shared.settings import ActSettings
1714
from askui.models.shared.tools import ToolCollection
1815
from askui.models.shared.truncation_strategies import (
@@ -84,7 +81,6 @@ def __init__(
8481
# Speakers and current state
8582
self.speakers = speakers
8683
self.current_speaker = speakers[speakers.default_speaker]
87-
self.accumulated_usage = UsageParam()
8884

8985
# Model providers - accessible by speakers via conversation instance
9086
self.vlm_provider = vlm_provider
@@ -184,7 +180,6 @@ def _setup_control_loop(
184180
reporters: list[Reporter] | None = None,
185181
) -> None:
186182
# Reset state
187-
self.accumulated_usage = UsageParam()
188183
self._executed_from_cache = False
189184
self.speakers.reset_state()
190185

@@ -221,9 +216,6 @@ def _conclude_control_loop(self) -> None:
221216
if self.cache_manager is not None and not self._executed_from_cache:
222217
self.cache_manager.finish_recording(self.get_messages())
223218

224-
# Report final usage
225-
self._reporter.add_usage_summary(self.accumulated_usage.model_dump())
226-
227219
def _setup_speaker_handoff(self) -> None:
228220
"""Set up speaker handoff infrastructure.
229221
@@ -316,10 +308,6 @@ def _execute_step(self) -> bool:
316308
status_continue = self._handle_result_status(result)
317309
continue_loop = continue_loop or status_continue
318310

319-
# 5. Collect Statistics
320-
if result.usage:
321-
self._accumulate_usage(result.usage)
322-
323311
self._on_step_end(self._step_index, result)
324312
self._step_index += 1
325313

@@ -450,34 +438,3 @@ def get_truncation_strategy(self) -> TruncationStrategy | None:
450438
Current truncation strategy or None if not initialized
451439
"""
452440
return self._truncation_strategy
453-
454-
def _accumulate_usage(self, step_usage: UsageParam) -> None:
455-
"""Accumulate token usage statistics.
456-
457-
Args:
458-
step_usage: Usage from a single step
459-
"""
460-
self.accumulated_usage.input_tokens = (
461-
self.accumulated_usage.input_tokens or 0
462-
) + (step_usage.input_tokens or 0)
463-
self.accumulated_usage.output_tokens = (
464-
self.accumulated_usage.output_tokens or 0
465-
) + (step_usage.output_tokens or 0)
466-
self.accumulated_usage.cache_creation_input_tokens = (
467-
self.accumulated_usage.cache_creation_input_tokens or 0
468-
) + (step_usage.cache_creation_input_tokens or 0)
469-
self.accumulated_usage.cache_read_input_tokens = (
470-
self.accumulated_usage.cache_read_input_tokens or 0
471-
) + (step_usage.cache_read_input_tokens or 0)
472-
473-
current_span = trace.get_current_span()
474-
current_span.set_attributes(
475-
{
476-
"input_tokens": step_usage.input_tokens or 0,
477-
"output_tokens": step_usage.output_tokens or 0,
478-
"cache_creation_input_tokens": (
479-
step_usage.cache_creation_input_tokens or 0
480-
),
481-
"cache_read_input_tokens": step_usage.cache_read_input_tokens or 0,
482-
}
483-
)
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
"""Callback for tracking token usage and reporting usage summaries."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING
6+
7+
from opentelemetry import trace
8+
from typing_extensions import override
9+
10+
from askui.models.shared.agent_message_param import UsageParam
11+
from askui.models.shared.conversation_callback import ConversationCallback
12+
from askui.reporting import NULL_REPORTER, Reporter
13+
14+
if TYPE_CHECKING:
15+
from askui.models.shared.conversation import Conversation
16+
from askui.speaker.speaker import SpeakerResult
17+
18+
19+
class UsageTrackingCallback(ConversationCallback):
20+
"""Tracks token usage per step and reports a summary at conversation end.
21+
22+
Args:
23+
reporter: Reporter to write the final usage summary to.
24+
"""
25+
26+
def __init__(self, reporter: Reporter = NULL_REPORTER) -> None:
27+
self._reporter = reporter
28+
self._accumulated_usage = UsageParam()
29+
30+
@override
31+
def on_conversation_start(self, conversation: Conversation) -> None:
32+
self._accumulated_usage = UsageParam()
33+
34+
@override
35+
def on_step_end(
36+
self,
37+
conversation: Conversation,
38+
step_index: int,
39+
result: SpeakerResult,
40+
) -> None:
41+
if result.usage:
42+
self._accumulate(result.usage)
43+
44+
@override
45+
def on_conversation_end(self, conversation: Conversation) -> None:
46+
self._reporter.add_usage_summary(self._accumulated_usage.model_dump())
47+
48+
@property
49+
def accumulated_usage(self) -> UsageParam:
50+
"""Current accumulated usage statistics."""
51+
return self._accumulated_usage
52+
53+
def _accumulate(self, step_usage: UsageParam) -> None:
54+
self._accumulated_usage.input_tokens = (
55+
self._accumulated_usage.input_tokens or 0
56+
) + (step_usage.input_tokens or 0)
57+
self._accumulated_usage.output_tokens = (
58+
self._accumulated_usage.output_tokens or 0
59+
) + (step_usage.output_tokens or 0)
60+
self._accumulated_usage.cache_creation_input_tokens = (
61+
self._accumulated_usage.cache_creation_input_tokens or 0
62+
) + (step_usage.cache_creation_input_tokens or 0)
63+
self._accumulated_usage.cache_read_input_tokens = (
64+
self._accumulated_usage.cache_read_input_tokens or 0
65+
) + (step_usage.cache_read_input_tokens or 0)
66+
67+
current_span = trace.get_current_span()
68+
current_span.set_attributes(
69+
{
70+
"input_tokens": step_usage.input_tokens or 0,
71+
"output_tokens": step_usage.output_tokens or 0,
72+
"cache_creation_input_tokens": (
73+
step_usage.cache_creation_input_tokens or 0
74+
),
75+
"cache_read_input_tokens": (step_usage.cache_read_input_tokens or 0),
76+
}
77+
)

0 commit comments

Comments
 (0)