Skip to content

Commit 37ee558

Browse files
committed
Add support for cost tracking
Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent 8e0217a commit 37ee558

24 files changed

Lines changed: 1912 additions & 692 deletions

alphatrion/agents/claude.py

Lines changed: 76 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,13 @@
1212

1313
from alphatrion.storage import runtime
1414
from alphatrion.storage.sql_models import Status
15-
from alphatrion.tracing.clickhouse_exporter import determine_semantic_kind
15+
from alphatrion.tracing.span_processor import (
16+
SEMANTIC_KIND_CHAT,
17+
SEMANTIC_KIND_REASONING,
18+
SEMANTIC_KIND_TOOL,
19+
SEMANTIC_KIND_UNKNOWN,
20+
)
21+
from alphatrion.utils.pricing import calculate_cost
1622

1723

1824
def handle_hook(hook_type: str):
@@ -548,15 +554,6 @@ def process_transcript_incremental(
548554
# Turn is complete, create run and multiple LLM spans
549555
user_message = current_user_msg.get("message", {})
550556

551-
# Calculate total tokens and duration from all messages
552-
total_input_tokens = 0
553-
total_output_tokens = 0
554-
for msg in current_assistant_messages:
555-
msg_usage = msg.get("message", {}).get("usage", {})
556-
total_input_tokens += msg_usage.get("input_tokens", 0)
557-
total_output_tokens += msg_usage.get("output_tokens", 0)
558-
total_tokens = total_input_tokens + total_output_tokens
559-
560557
# Calculate duration from timestamps
561558
# (first user message to last assistant message)
562559
duration = calculate_duration(
@@ -598,11 +595,6 @@ def process_transcript_incremental(
598595
user_id=session.user_id,
599596
status=run_status,
600597
duration=duration,
601-
usage={
602-
"input_tokens": total_input_tokens,
603-
"output_tokens": total_output_tokens,
604-
"total_tokens": total_tokens,
605-
},
606598
)
607599

608600
# Prepare user content (original user message only)
@@ -1009,8 +1001,12 @@ def create_clickhouse_spans_for_turn(
10091001
msg_content = []
10101002

10111003
msg_usage = message_data.get("usage", {})
1012-
msg_input_tokens = msg_usage.get("input_tokens", 0)
1013-
msg_output_tokens = msg_usage.get("output_tokens", 0)
1004+
input_tokens = msg_usage.get("input_tokens", 0)
1005+
output_tokens = msg_usage.get("output_tokens", 0)
1006+
cache_creation_input_tokens = msg_usage.get(
1007+
"cache_creation_input_tokens", 0
1008+
)
1009+
cache_read_input_tokens = msg_usage.get("cache_read_input_tokens", 0)
10141010

10151011
# Determine content type
10161012
tool_use_blocks = [
@@ -1083,13 +1079,17 @@ def create_clickhouse_spans_for_turn(
10831079
span_id = str(uuid.uuid4()).replace("-", "")[:16]
10841080

10851081
# Determine semantic kind
1086-
semantic_kind = determine_semantic_kind({}, msg_content)
1082+
semantic_kind = determine_semantic_kind(msg_content)
10871083

1088-
# Token assignment:
1089-
# Each message has actual token usage from Claude API
1090-
# Assign actual tokens to each span for accurate aggregation
1091-
input_tokens = msg_input_tokens
1092-
output_tokens = msg_output_tokens
1084+
# Calculate cost for this span
1085+
span_costs = calculate_cost(
1086+
provider="anthropic",
1087+
model=model,
1088+
input_tokens=input_tokens,
1089+
output_tokens=output_tokens,
1090+
cache_creation_input_tokens=cache_creation_input_tokens,
1091+
cache_read_input_tokens=cache_read_input_tokens,
1092+
)
10931093

10941094
# Build span attributes
10951095
span_attributes = {
@@ -1099,7 +1099,25 @@ def create_clickhouse_spans_for_turn(
10991099
"gen_ai.response.model": model,
11001100
"gen_ai.usage.input_tokens": str(input_tokens),
11011101
"gen_ai.usage.output_tokens": str(output_tokens),
1102-
"llm.usage.total_tokens": str(input_tokens + output_tokens),
1102+
"gen_ai.usage.cache_creation_input_tokens": str(
1103+
cache_creation_input_tokens
1104+
),
1105+
"gen_ai.usage.cache_read_input_tokens": str(cache_read_input_tokens),
1106+
"llm.usage.total_tokens": str(
1107+
input_tokens
1108+
+ output_tokens
1109+
+ cache_creation_input_tokens
1110+
+ cache_read_input_tokens
1111+
),
1112+
"alphatrion.cost.total_tokens": str(span_costs["total_cost"]),
1113+
"alphatrion.cost.input_tokens": str(span_costs["input_cost"]),
1114+
"alphatrion.cost.output_tokens": str(span_costs["output_cost"]),
1115+
"alphatrion.cost.cache_creation_input_tokens": str(
1116+
span_costs["cache_creation_input_cost"]
1117+
),
1118+
"alphatrion.cost.cache_read_input_tokens": str(
1119+
span_costs["cache_read_input_cost"]
1120+
),
11031121
}
11041122

11051123
# Add prompt to first span only (where the user input is sent)
@@ -1222,3 +1240,37 @@ def create_clickhouse_spans_for_turn(
12221240
import logging
12231241

12241242
logging.error(f"Failed to create ClickHouse spans: {e}", exc_info=True)
1243+
1244+
1245+
def determine_semantic_kind(content_blocks: list) -> str:
1246+
"""Determine semantic kind of a message based on content blocks.
1247+
1248+
- If contains tool_use block → "tool"
1249+
- Else if contains thinking block → "thinking"
1250+
- Else → "text"
1251+
1252+
Args:
1253+
content_blocks: List of content blocks in the message
1254+
1255+
Returns:
1256+
Semantic kind string
1257+
"""
1258+
has_tool_use = any(
1259+
isinstance(b, dict) and b.get("type") == "tool_use" for b in content_blocks
1260+
)
1261+
if has_tool_use:
1262+
return SEMANTIC_KIND_TOOL
1263+
1264+
has_thinking = any(
1265+
isinstance(b, dict) and b.get("type") == "thinking" for b in content_blocks
1266+
)
1267+
if has_thinking:
1268+
return SEMANTIC_KIND_REASONING
1269+
1270+
has_text = any(
1271+
isinstance(b, dict) and b.get("type") == "text" for b in content_blocks
1272+
)
1273+
if has_text:
1274+
return SEMANTIC_KIND_CHAT
1275+
1276+
return SEMANTIC_KIND_UNKNOWN

alphatrion/experiment/base.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -220,16 +220,9 @@ def _start(
220220
# to avoid confusion.
221221
if exp_obj and exp_obj.status != Status.COMPLETED:
222222
self._id = exp_obj.uuid
223-
usage = exp_obj.usage
224-
225-
# reset to running status, also need to reset the tokens.
226-
if usage and "total_tokens" in usage:
227-
# delete the tokens in the usage - set to None instead of empty dict
228-
usage = None
229223
self._runtime._metadb.update_experiment(
230224
experiment_id=self._id,
231225
status=Status.RUNNING,
232-
usage=usage,
233226
)
234227
elif exp_obj and exp_obj.status == Status.COMPLETED:
235228
raise RuntimeError(

alphatrion/runtime/runtime.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
# ruff: noqa: PLW0603
2+
import logging
23
import os
4+
import sys
35
import uuid
46

57
from alphatrion import envs
68
from alphatrion.storage import runtime as storage_runtime
79
from alphatrion.storage.sqlstore import SQLStore
10+
from alphatrion.utils.pricing import load_pricing_config
11+
12+
logger = logging.getLogger(__name__)
813

914
__RUNTIME__ = None
1015

@@ -57,6 +62,15 @@ def __init__(
5762
team_id: uuid.UUID | None = None,
5863
org_id: uuid.UUID | None = None,
5964
):
65+
# Load pricing config at startup - exit if it fails
66+
try:
67+
load_pricing_config()
68+
logger.info("Successfully loaded model pricing configuration")
69+
except Exception as e:
70+
logger.error(f"Failed to load pricing configuration: {e}")
71+
logger.error("Application cannot start without valid pricing configuration")
72+
sys.exit(1)
73+
6074
storage_runtime.init()
6175
self._metadb = storage_runtime.storage_runtime().metadb
6276
self._tracestore = storage_runtime.storage_runtime().tracestore

alphatrion/server/cmd/app.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,26 @@ async def receive():
109109
return response
110110

111111

112+
# Wrapper to convert context auth errors to HTTP exceptions
113+
async def get_context_with_error_handling(request: Request):
114+
"""Wrap get_context to convert auth errors to proper HTTP status codes."""
115+
try:
116+
return await get_context(request)
117+
except ValueError as e:
118+
# Authentication/authorization errors from get_context
119+
error_msg = str(e).lower()
120+
if (
121+
"authorization" in error_msg
122+
or "token" in error_msg
123+
or "missing" in error_msg
124+
):
125+
raise HTTPException(status_code=401, detail=str(e))
126+
# Other validation errors
127+
raise HTTPException(status_code=400, detail=str(e))
128+
129+
112130
# Create GraphQL router with context
113-
graphql_app = GraphQLRouter(schema, context_getter=get_context)
131+
graphql_app = GraphQLRouter(schema, context_getter=get_context_with_error_handling)
114132

115133
# Mount /graphql endpoint
116134
app.include_router(graphql_app, prefix="/graphql")

alphatrion/server/cmd/main.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,10 @@ def init_command(args):
204204

205205
if len(password) < character_length_at_least:
206206
console.print(
207-
Text(f"❌ Error: Password must be at least {character_length_at_least} characters", style="bold red")
207+
Text(
208+
f"❌ Error: Password must be at least {character_length_at_least} characters",
209+
style="bold red",
210+
)
208211
)
209212
return
210213

0 commit comments

Comments
 (0)