Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 76 additions & 24 deletions alphatrion/agents/claude.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@

from alphatrion.storage import runtime
from alphatrion.storage.sql_models import Status
from alphatrion.tracing.clickhouse_exporter import determine_semantic_kind
from alphatrion.tracing.span_processor import (
SEMANTIC_KIND_CHAT,
SEMANTIC_KIND_REASONING,
SEMANTIC_KIND_TOOL,
SEMANTIC_KIND_UNKNOWN,
)
from alphatrion.utils.pricing import calculate_cost


def handle_hook(hook_type: str):
Expand Down Expand Up @@ -548,15 +554,6 @@ def process_transcript_incremental(
# Turn is complete, create run and multiple LLM spans
user_message = current_user_msg.get("message", {})

# Calculate total tokens and duration from all messages
total_input_tokens = 0
total_output_tokens = 0
for msg in current_assistant_messages:
msg_usage = msg.get("message", {}).get("usage", {})
total_input_tokens += msg_usage.get("input_tokens", 0)
total_output_tokens += msg_usage.get("output_tokens", 0)
total_tokens = total_input_tokens + total_output_tokens

# Calculate duration from timestamps
# (first user message to last assistant message)
duration = calculate_duration(
Expand Down Expand Up @@ -598,11 +595,6 @@ def process_transcript_incremental(
user_id=session.user_id,
status=run_status,
duration=duration,
usage={
"input_tokens": total_input_tokens,
"output_tokens": total_output_tokens,
"total_tokens": total_tokens,
},
)

# Prepare user content (original user message only)
Expand Down Expand Up @@ -1009,8 +1001,12 @@ def create_clickhouse_spans_for_turn(
msg_content = []

msg_usage = message_data.get("usage", {})
msg_input_tokens = msg_usage.get("input_tokens", 0)
msg_output_tokens = msg_usage.get("output_tokens", 0)
input_tokens = msg_usage.get("input_tokens", 0)
output_tokens = msg_usage.get("output_tokens", 0)
cache_creation_input_tokens = msg_usage.get(
"cache_creation_input_tokens", 0
)
cache_read_input_tokens = msg_usage.get("cache_read_input_tokens", 0)

# Determine content type
tool_use_blocks = [
Expand Down Expand Up @@ -1083,13 +1079,17 @@ def create_clickhouse_spans_for_turn(
span_id = str(uuid.uuid4()).replace("-", "")[:16]

# Determine semantic kind
semantic_kind = determine_semantic_kind({}, msg_content)
semantic_kind = determine_semantic_kind(msg_content)

# Token assignment:
# Each message has actual token usage from Claude API
# Assign actual tokens to each span for accurate aggregation
input_tokens = msg_input_tokens
output_tokens = msg_output_tokens
# Calculate cost for this span
span_costs = calculate_cost(
provider="anthropic",
model=model,
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_creation_input_tokens=cache_creation_input_tokens,
cache_read_input_tokens=cache_read_input_tokens,
)

# Build span attributes
span_attributes = {
Expand All @@ -1099,7 +1099,25 @@ def create_clickhouse_spans_for_turn(
"gen_ai.response.model": model,
"gen_ai.usage.input_tokens": str(input_tokens),
"gen_ai.usage.output_tokens": str(output_tokens),
"llm.usage.total_tokens": str(input_tokens + output_tokens),
"gen_ai.usage.cache_creation_input_tokens": str(
cache_creation_input_tokens
),
"gen_ai.usage.cache_read_input_tokens": str(cache_read_input_tokens),
"llm.usage.total_tokens": str(
input_tokens
+ output_tokens
+ cache_creation_input_tokens
+ cache_read_input_tokens
),
"alphatrion.cost.total_tokens": str(span_costs["total_cost"]),
"alphatrion.cost.input_tokens": str(span_costs["input_cost"]),
"alphatrion.cost.output_tokens": str(span_costs["output_cost"]),
"alphatrion.cost.cache_creation_input_tokens": str(
span_costs["cache_creation_input_cost"]
),
"alphatrion.cost.cache_read_input_tokens": str(
span_costs["cache_read_input_cost"]
),
}

# Add prompt to first span only (where the user input is sent)
Expand Down Expand Up @@ -1222,3 +1240,37 @@ def create_clickhouse_spans_for_turn(
import logging

logging.error(f"Failed to create ClickHouse spans: {e}", exc_info=True)


def determine_semantic_kind(content_blocks: list) -> str:
    """Determine the semantic kind of a message from its content blocks.

    Precedence (first match wins):
    - contains a ``tool_use`` block -> ``SEMANTIC_KIND_TOOL``
    - contains a ``thinking`` block -> ``SEMANTIC_KIND_REASONING``
    - contains a ``text`` block     -> ``SEMANTIC_KIND_CHAT``
    - none of the above             -> ``SEMANTIC_KIND_UNKNOWN``

    Args:
        content_blocks: List of content blocks in the message. Entries that
            are not dicts are ignored.

    Returns:
        One of the ``SEMANTIC_KIND_*`` constants.
    """
    # Collect the block types in a single pass instead of rescanning the
    # list once per candidate kind.
    block_types = {b.get("type") for b in content_blocks if isinstance(b, dict)}

    # Priority order: tool use outranks reasoning, which outranks chat text.
    for block_type, kind in (
        ("tool_use", SEMANTIC_KIND_TOOL),
        ("thinking", SEMANTIC_KIND_REASONING),
        ("text", SEMANTIC_KIND_CHAT),
    ):
        if block_type in block_types:
            return kind

    return SEMANTIC_KIND_UNKNOWN
7 changes: 0 additions & 7 deletions alphatrion/experiment/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,16 +220,9 @@ def _start(
# to avoid confusion.
if exp_obj and exp_obj.status != Status.COMPLETED:
self._id = exp_obj.uuid
usage = exp_obj.usage

# reset to running status, also need to reset the tokens.
if usage and "total_tokens" in usage:
# delete the tokens in the usage - set to None instead of empty dict
usage = None
self._runtime._metadb.update_experiment(
experiment_id=self._id,
status=Status.RUNNING,
usage=usage,
)
elif exp_obj and exp_obj.status == Status.COMPLETED:
raise RuntimeError(
Expand Down
14 changes: 14 additions & 0 deletions alphatrion/runtime/runtime.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
# ruff: noqa: PLW0603
import logging
import os
import sys
import uuid

from alphatrion import envs
from alphatrion.storage import runtime as storage_runtime
from alphatrion.storage.sqlstore import SQLStore
from alphatrion.utils.pricing import load_pricing_config

logger = logging.getLogger(__name__)

__RUNTIME__ = None

Expand Down Expand Up @@ -57,6 +62,15 @@ def __init__(
team_id: uuid.UUID | None = None,
org_id: uuid.UUID | None = None,
):
# Load pricing config at startup - exit if it fails
try:
load_pricing_config()
logger.info("Successfully loaded model pricing configuration")
except Exception as e:
logger.error(f"Failed to load pricing configuration: {e}")
logger.error("Application cannot start without valid pricing configuration")
sys.exit(1)

storage_runtime.init()
self._metadb = storage_runtime.storage_runtime().metadb
self._tracestore = storage_runtime.storage_runtime().tracestore
Expand Down
20 changes: 19 additions & 1 deletion alphatrion/server/cmd/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,26 @@ async def receive():
return response


# Wrapper to convert context auth errors to HTTP exceptions
async def get_context_with_error_handling(request: Request):
    """Wrap get_context to convert auth errors to proper HTTP status codes.

    Args:
        request: The incoming HTTP request, forwarded to ``get_context``.

    Returns:
        Whatever ``get_context`` returns on success.

    Raises:
        HTTPException: 401 when the ``ValueError`` message looks like an
            authentication/authorization failure, 400 for other
            ``ValueError`` validation errors.
    """
    try:
        return await get_context(request)
    except ValueError as e:
        # Authentication/authorization errors from get_context are detected
        # by keyword sniffing on the message text.
        # NOTE(review): fragile — a dedicated AuthError exception type in
        # get_context would be more robust; confirm with its maintainer.
        error_msg = str(e).lower()
        if any(k in error_msg for k in ("authorization", "token", "missing")):
            # Chain the original error so tracebacks keep the root cause.
            raise HTTPException(status_code=401, detail=str(e)) from e
        # Other validation errors
        raise HTTPException(status_code=400, detail=str(e)) from e


# Create GraphQL router with context
graphql_app = GraphQLRouter(schema, context_getter=get_context)
graphql_app = GraphQLRouter(schema, context_getter=get_context_with_error_handling)

# Mount /graphql endpoint
app.include_router(graphql_app, prefix="/graphql")
Expand Down
5 changes: 4 additions & 1 deletion alphatrion/server/cmd/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,10 @@ def init_command(args):

if len(password) < character_length_at_least:
console.print(
Text(f"❌ Error: Password must be at least {character_length_at_least} characters", style="bold red")
Text(
f"❌ Error: Password must be at least {character_length_at_least} characters",
style="bold red",
)
)
return

Expand Down
Loading
Loading