Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,5 @@ reports/
.askui_cache/*

bom.json

*playground*
4,213 changes: 1,933 additions & 2,280 deletions pdm.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ authors = [
]
dependencies = [
"askui-agent-os>=26.1.1",
"anthropic>=0.72.0",
"anthropic>=0.86.0",
"fastapi>=0.115.12",
"fastmcp>=2.3.0",
"gradio-client>=1.4.3",
"grpcio>=1.73.1",
"grpcio>=1.73.1,<1.80.0",
"httpx>=0.28.1",
"Jinja2>=3.1.4",
"openai>=1.61.1",
Expand Down
283 changes: 217 additions & 66 deletions src/askui/callbacks/usage_tracking_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from typing import TYPE_CHECKING

from opentelemetry import trace
from pydantic import BaseModel
from typing_extensions import override
from pydantic import BaseModel, Field
from typing_extensions import Self, override

from askui.callbacks.conversation_callback import ConversationCallback
from askui.reporting import NULL_REPORTER
Expand All @@ -18,6 +18,8 @@
from askui.speaker.speaker import SpeakerResult
from askui.utils.model_pricing import ModelPricing

_USD_CURRENCY = "USD"


class UsageSummary(BaseModel):
"""Accumulated token usage and optional cost breakdown for a conversation.
Expand All @@ -27,9 +29,13 @@ class UsageSummary(BaseModel):
output_tokens (int | None): Total output tokens generated.
cache_creation_input_tokens (int | None): Tokens used for cache creation.
cache_read_input_tokens (int | None): Tokens read from cache.
input_cost (float | None): Computed input cost in `currency`.
output_cost (float | None): Computed output cost in `currency`.
total_cost (float | None): Sum of `input_cost` and `output_cost`.
input_token_cost (float | None): Computed cost for input tokens in `currency`.
output_token_cost (float | None): Computed cost for output tokens in `currency`.
cache_write_token_cost (float | None): Computed cost for cache write tokens in
`currency`.
cache_read_token_cost (float | None): Computed cost for cache read tokens in
`currency`.
total_cost (float | None): Sum of all computed cost values.
currency (str | None): ISO 4217 currency code (e.g. ``"USD"``).
input_cost_per_million_tokens (float | None): Rate used to compute `input_cost`.
output_cost_per_million_tokens (float|None): Rate used to compute `output_cost`.
Expand All @@ -39,12 +45,138 @@ class UsageSummary(BaseModel):
output_tokens: int | None = None
cache_creation_input_tokens: int | None = None
cache_read_input_tokens: int | None = None
input_cost: float | None = None
output_cost: float | None = None
input_token_cost: float | None = None
output_token_cost: float | None = None
cache_write_token_cost: float | None = None
cache_read_token_cost: float | None = None
total_cost: float | None = None
currency: str | None = None
input_cost_per_million_tokens: float | None = None
output_cost_per_million_tokens: float | None = None
cache_write_cost_per_million_tokens: float | None = None
cache_read_cost_per_million_tokens: float | None = None
per_conversation_summaries: list[ConversationUsageSummary] | None = None

@classmethod
def create(cls, pricing: ModelPricing | None = None) -> "UsageSummary":
"""Create a summary configured with optional model pricing."""
if pricing is None:
return cls()
return cls(
input_cost_per_million_tokens=pricing.input_cost_per_million_tokens,
output_cost_per_million_tokens=pricing.output_cost_per_million_tokens,
cache_write_cost_per_million_tokens=(
pricing.cache_write_cost_per_million_tokens
),
cache_read_cost_per_million_tokens=(
pricing.cache_read_cost_per_million_tokens
),
)

@classmethod
def create_from(cls, summary: "UsageSummary") -> "UsageSummary":
"""Create a new summary that reuses pricing fields from `summary`."""
return cls(
input_cost_per_million_tokens=summary.input_cost_per_million_tokens,
output_cost_per_million_tokens=summary.output_cost_per_million_tokens,
cache_write_cost_per_million_tokens=(
summary.cache_write_cost_per_million_tokens
),
cache_read_cost_per_million_tokens=(
summary.cache_read_cost_per_million_tokens
),
)

def add_usage(self, usage: UsageParam) -> None:
"""Add token counts from `usage`."""
self.input_tokens = (self.input_tokens or 0) + (usage.input_tokens or 0)
self.output_tokens = (self.output_tokens or 0) + (usage.output_tokens or 0)
self.cache_creation_input_tokens = (self.cache_creation_input_tokens or 0) + (
usage.cache_creation_input_tokens or 0
)
self.cache_read_input_tokens = (self.cache_read_input_tokens or 0) + (
usage.cache_read_input_tokens or 0
)

def generate(self) -> Self:
"""Compute and populate cost fields from current token and pricing fields."""
if not self._has_pricing():
self._clear_cost_fields()
return self

input_tokens = self.input_tokens or 0
output_tokens = self.output_tokens or 0
cache_write_tokens = self.cache_creation_input_tokens or 0
cache_read_tokens = self.cache_read_input_tokens or 0

assert self.input_cost_per_million_tokens is not None
assert self.output_cost_per_million_tokens is not None
assert self.cache_write_cost_per_million_tokens is not None
assert self.cache_read_cost_per_million_tokens is not None

self.input_token_cost = self._calculate_cost(
input_tokens, self.input_cost_per_million_tokens
)
self.output_token_cost = self._calculate_cost(
output_tokens, self.output_cost_per_million_tokens
)
self.cache_write_token_cost = self._calculate_cost(
cache_write_tokens, self.cache_write_cost_per_million_tokens
)
self.cache_read_token_cost = self._calculate_cost(
cache_read_tokens, self.cache_read_cost_per_million_tokens
)
self.total_cost = (
(self.input_token_cost or 0.0)
+ (self.output_token_cost or 0.0)
+ (self.cache_write_token_cost or 0.0)
+ (self.cache_read_token_cost or 0.0)
)
self.currency = _USD_CURRENCY
return self

def token_attributes(self) -> dict[str, int]:
"""Return token fields for telemetry attributes."""
return {
"input_tokens": self.input_tokens or 0,
"output_tokens": self.output_tokens or 0,
"cache_creation_input_tokens": self.cache_creation_input_tokens or 0,
"cache_read_input_tokens": self.cache_read_input_tokens or 0,
}

def _has_pricing(self) -> bool:
return (
self.input_cost_per_million_tokens is not None
and self.output_cost_per_million_tokens is not None
and self.cache_write_cost_per_million_tokens is not None
and self.cache_read_cost_per_million_tokens is not None
)

def _clear_cost_fields(self) -> None:
self.input_token_cost = None
self.output_token_cost = None
self.cache_write_token_cost = None
self.cache_read_token_cost = None
self.total_cost = None
self.currency = None

@staticmethod
def _calculate_cost(tokens: int, rate_per_million_tokens: float) -> float:
return rate_per_million_tokens * tokens / 1e6


class StepUsageSummary(UsageSummary):
"""Usage summary for a single step."""

step_index: int


class ConversationUsageSummary(UsageSummary):
"""Usage summary for one conversation including per-step breakdown."""

conversation_index: int
conversation_id: str
step_summaries: list[StepUsageSummary] = Field(default_factory=list)


class UsageTrackingCallback(ConversationCallback):
Expand All @@ -62,12 +194,17 @@ def __init__(
pricing: ModelPricing | None = None,
) -> None:
self._reporter = reporter
self._pricing = pricing
self._summary = UsageSummary()
self._summary: UsageSummary = UsageSummary.create(pricing)
self._per_conversation_usage: UsageSummary = UsageSummary.create(pricing)
self._per_conversation_summaries: list[ConversationUsageSummary] = []
self._per_step_summaries: list[StepUsageSummary] = []
self._conversation_index: int = 0

@override
def on_conversation_start(self, conversation: Conversation) -> None:
self._summary = UsageSummary()
self._per_conversation_usage = UsageSummary.create_from(self._summary)
self._per_step_summaries = []
self._conversation_index += 1

@override
def on_step_end(
Expand All @@ -76,71 +213,85 @@ def on_step_end(
step_index: int,
result: SpeakerResult,
) -> None:
if result.usage:
self._accumulate(result.usage)
step_usage: UsageParam | None = result.usage
if step_usage is None:
return

step_summary = self._create_step_summary(
step_index=step_index, usage=step_usage
)
self._per_step_summaries.append(step_summary)
self._per_conversation_usage.add_usage(step_usage)
self._summary.add_usage(step_usage)

current_span = trace.get_current_span()
current_span.set_attributes(step_summary.token_attributes())

@override
def on_conversation_end(self, conversation: Conversation) -> None:
self._reporter.add_usage_summary(self._summary)
generated_steps: list[StepUsageSummary] = [
step_summary.generate() for step_summary in self._per_step_summaries
]
conversation_summary = self._create_conversation_summary(
conversation=conversation,
generated_step_summaries=generated_steps,
)
self._per_conversation_summaries.append(conversation_summary)
self._summary.per_conversation_summaries = list(
self._per_conversation_summaries
)
self._reporter.add_usage_summary(self._summary.generate().model_copy(deep=True))

@property
def accumulated_usage(self) -> UsageSummary:
"""Current accumulated usage statistics."""
return self._summary

def _accumulate(self, step_usage: UsageParam) -> None:
# Add step tokens to running totals (None counts as 0)
self._summary.input_tokens = (self._summary.input_tokens or 0) + (
step_usage.input_tokens or 0
)
self._summary.output_tokens = (self._summary.output_tokens or 0) + (
step_usage.output_tokens or 0
)
self._summary.cache_creation_input_tokens = (
self._summary.cache_creation_input_tokens or 0
) + (step_usage.cache_creation_input_tokens or 0)
self._summary.cache_read_input_tokens = (
self._summary.cache_read_input_tokens or 0
) + (step_usage.cache_read_input_tokens or 0)

# Record per-step token counts on the current OTel span
current_span = trace.get_current_span()
current_span.set_attributes(
{
"input_tokens": step_usage.input_tokens or 0,
"output_tokens": step_usage.output_tokens or 0,
"cache_creation_input_tokens": (
step_usage.cache_creation_input_tokens or 0
),
"cache_read_input_tokens": (step_usage.cache_read_input_tokens or 0),
}
def _create_step_summary(
self, step_index: int, usage: UsageParam
) -> StepUsageSummary:
return StepUsageSummary(
step_index=step_index,
input_tokens=usage.input_tokens or 0,
output_tokens=usage.output_tokens or 0,
cache_creation_input_tokens=usage.cache_creation_input_tokens or 0,
cache_read_input_tokens=usage.cache_read_input_tokens or 0,
input_cost_per_million_tokens=self._summary.input_cost_per_million_tokens,
output_cost_per_million_tokens=self._summary.output_cost_per_million_tokens,
cache_write_cost_per_million_tokens=(
self._summary.cache_write_cost_per_million_tokens
),
cache_read_cost_per_million_tokens=(
self._summary.cache_read_cost_per_million_tokens
),
)

# Update costs from updated totals if pricing values are set
if not (
self._pricing
and self._pricing.input_cost_per_million_tokens
and self._pricing.output_cost_per_million_tokens
):
return

input_cost = (
self._summary.input_tokens
* self._pricing.input_cost_per_million_tokens
/ 1e6
)
output_cost = (
self._summary.output_tokens
* self._pricing.output_cost_per_million_tokens
/ 1e6
)
self._summary.input_cost = input_cost
self._summary.output_cost = output_cost
self._summary.total_cost = input_cost + output_cost
self._summary.currency = self._pricing.currency
self._summary.input_cost_per_million_tokens = (
self._pricing.input_cost_per_million_tokens
)
self._summary.output_cost_per_million_tokens = (
self._pricing.output_cost_per_million_tokens
def _create_conversation_summary(
self,
conversation: Conversation,
generated_step_summaries: list[StepUsageSummary],
) -> ConversationUsageSummary:
conversation_summary = ConversationUsageSummary(
conversation_index=self._conversation_index,
conversation_id=conversation.conversation_id,
step_summaries=generated_step_summaries,
input_tokens=self._per_conversation_usage.input_tokens,
output_tokens=self._per_conversation_usage.output_tokens,
cache_creation_input_tokens=(
self._per_conversation_usage.cache_creation_input_tokens
),
cache_read_input_tokens=self._per_conversation_usage.cache_read_input_tokens,
input_cost_per_million_tokens=(
self._per_conversation_usage.input_cost_per_million_tokens
),
output_cost_per_million_tokens=(
self._per_conversation_usage.output_cost_per_million_tokens
),
cache_write_cost_per_million_tokens=(
self._per_conversation_usage.cache_write_cost_per_million_tokens
),
cache_read_cost_per_million_tokens=(
self._per_conversation_usage.cache_read_cost_per_million_tokens
),
)
return conversation_summary.generate()
12 changes: 10 additions & 2 deletions src/askui/model_providers/anthropic_vlm_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,14 @@ class AnthropicVlmProvider(VlmProvider):
client (Anthropic | None, optional): Pre-configured Anthropic client.
If provided, other connection parameters are ignored.
input_cost_per_million_tokens (float | None, optional): Override
cost in USD per 1M input tokens. Both cost params must be set
to override the built-in defaults.
cost in USD per 1M input tokens. All override pricing params must be set to
override the built-in defaults.
output_cost_per_million_tokens (float | None, optional): Override
cost in USD per 1M output tokens.
cache_write_cost_per_million_tokens (float | None, optional): Override
cost in USD per 1M cache write input tokens.
cache_read_cost_per_million_tokens (float | None, optional): Override
cost in USD per 1M cache read input tokens.

Example:
```python
Expand All @@ -68,6 +72,8 @@ def __init__(
client: Anthropic | None = None,
input_cost_per_million_tokens: float | None = None,
output_cost_per_million_tokens: float | None = None,
cache_write_cost_per_million_tokens: float | None = None,
cache_read_cost_per_million_tokens: float | None = None,
) -> None:
self._model_id_value = (
model_id or os.environ.get("VLM_PROVIDER_MODEL_ID") or _DEFAULT_MODEL_ID
Expand All @@ -84,6 +90,8 @@ def __init__(
self._model_id_value,
input_cost_per_million_tokens=input_cost_per_million_tokens,
output_cost_per_million_tokens=output_cost_per_million_tokens,
cache_write_cost_per_million_tokens=cache_write_cost_per_million_tokens,
cache_read_cost_per_million_tokens=cache_read_cost_per_million_tokens,
)

@property
Expand Down
Loading
Loading