Skip to content

Commit e31160c

Browse files
authored
Persist the tokens for the experiment (#195)
Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent e73ca1b commit e31160c

11 files changed

Lines changed: 243 additions & 66 deletions

File tree

Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ test-integration: lint
5151
docker-compose -f ./docker-compose.yaml up -d; \
5252
trap "docker-compose -f ./docker-compose.yaml down" EXIT; \
5353
until docker exec postgres pg_isready -U alphatrion; do sleep 1; done; \
54+
until docker exec clickhouse clickhouse-client --query "SELECT 1"; do sleep 1; done; \
5455
until curl -sf http://localhost:11434/api/tags | grep "smollm:135m" > /dev/null; do sleep 1; done; \
5556
$(PYTEST) tests/integration --timeout=30; \
5657
'

alphatrion/experiment/base.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,10 +219,18 @@ def _start(
219219
# to avoid confusion.
220220
if exp_obj and exp_obj.status != Status.COMPLETED:
221221
self._id = exp_obj.uuid
222-
# reset to running status.
222+
usage = exp_obj.usage
223+
224+
# reset to running status, also need to reset the tokens.
225+
if usage and "total_tokens" in usage:
226+
# delete the tokens in the usage
227+
usage.delete("total_tokens")
228+
usage.delete("input_tokens")
229+
usage.delete("output_tokens")
223230
self._runtime._metadb.update_experiment(
224231
experiment_id=self._id,
225232
status=Status.RUNNING,
233+
usage=usage,
226234
)
227235
elif exp_obj and exp_obj.status == Status.COMPLETED:
228236
raise RuntimeError(
@@ -369,6 +377,8 @@ def is_done(self) -> bool:
369377
return self._context.cancelled()
370378

371379
def done(self):
380+
if self.is_done():
381+
return
372382
self._cancel()
373383

374384
def done_with_err(self):

alphatrion/server/graphql/resolvers.py

Lines changed: 99 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
from alphatrion import envs
99
from alphatrion.artifact import artifact
1010
from alphatrion.storage import runtime
11-
from alphatrion.storage.sql_models import Status
11+
from alphatrion.storage.sql_models import (
12+
FINISHED_STATUS,
13+
Status,
14+
)
1215

1316
from .types import (
1417
AddUserToTeamInput,
@@ -138,6 +141,7 @@ def list_experiments(
138141
duration=e.duration,
139142
status=GraphQLStatusEnum[Status(e.status).name],
140143
kind=GraphQLExperimentTypeEnum[GraphQLExperimentType(e.kind).name],
144+
cost=e.cost,
141145
created_at=e.created_at,
142146
updated_at=e.updated_at,
143147
)
@@ -160,6 +164,7 @@ def get_experiment(id: strawberry.ID) -> Experiment | None:
160164
duration=exp.duration,
161165
status=GraphQLStatusEnum[Status(exp.status).name],
162166
kind=GraphQLExperimentTypeEnum[GraphQLExperimentType(exp.kind).name],
167+
cost=exp.cost,
163168
created_at=exp.created_at,
164169
updated_at=exp.updated_at,
165170
)
@@ -175,7 +180,7 @@ def list_runs(
175180
) -> list[Run]:
176181
metadb = runtime.storage_runtime().metadb
177182
runs = metadb.list_runs_by_exp_id(
178-
exp_id=uuid.UUID(experiment_id),
183+
experiment_id=uuid.UUID(experiment_id),
179184
page=page,
180185
page_size=page_size,
181186
order_by=order_by,
@@ -190,6 +195,7 @@ def list_runs(
190195
meta=r.meta,
191196
status=GraphQLStatusEnum[Status(r.status).name],
192197
duration=r.duration,
198+
cost=r.cost,
193199
created_at=r.created_at,
194200
)
195201
for r in runs
@@ -208,6 +214,7 @@ def get_run(id: strawberry.ID) -> Run | None:
208214
meta=run.meta,
209215
status=GraphQLStatusEnum[Status(run.status).name],
210216
duration=run.duration,
217+
cost=run.cost,
211218
created_at=run.created_at,
212219
)
213220
return None
@@ -311,6 +318,7 @@ def list_exps_by_timeframe(
311318
duration=e.duration,
312319
status=GraphQLStatusEnum[Status(e.status).name],
313320
kind=GraphQLExperimentTypeEnum[GraphQLExperimentType(e.kind).name],
321+
cost=e.cost,
314322
created_at=e.created_at,
315323
updated_at=e.updated_at,
316324
)
@@ -396,30 +404,22 @@ def aggregate_run_tokens(run_id: strawberry.ID) -> dict[str, int]:
396404
return {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
397405

398406
try:
399-
trace_store = runtime.storage_runtime().tracestore
400-
spans = trace_store.get_llm_spans_by_run_id(run_id)
401-
# Don't close - it's a shared singleton connection
402-
403-
total_tokens = 0
404-
input_tokens = 0
405-
output_tokens = 0
406-
407-
for span in spans:
408-
span_attrs = span.get("SpanAttributes", {})
409-
410-
# Aggregate tokens from LLM spans
411-
if "llm.usage.total_tokens" in span_attrs:
412-
total_tokens += int(span_attrs["llm.usage.total_tokens"])
413-
if "gen_ai.usage.input_tokens" in span_attrs:
414-
input_tokens += int(span_attrs["gen_ai.usage.input_tokens"])
415-
if "gen_ai.usage.output_tokens" in span_attrs:
416-
output_tokens += int(span_attrs["gen_ai.usage.output_tokens"])
417-
418-
return {
419-
"total_tokens": total_tokens,
420-
"input_tokens": input_tokens,
421-
"output_tokens": output_tokens,
422-
}
407+
run = runtime.storage_runtime().metadb.get_run(run_id=run_id)
408+
if run.status in FINISHED_STATUS:
409+
if run.usage and "total_tokens" in run.usage:
410+
return {
411+
"total_tokens": run.usage.get("total_tokens", 0),
412+
"input_tokens": run.usage.get("input_tokens", 0),
413+
"output_tokens": run.usage.get("output_tokens", 0),
414+
}
415+
else:
416+
usage = GraphQLResolvers.get_run_usage(run_id)
417+
runtime.storage_runtime().metadb.update_run(
418+
run_id=run_id, usage=usage
419+
)
420+
return usage
421+
else:
422+
return GraphQLResolvers.get_run_usage(run_id)
423423
except Exception as e:
424424
import logging
425425

@@ -428,6 +428,33 @@ def aggregate_run_tokens(run_id: strawberry.ID) -> dict[str, int]:
428428
)
429429
return {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
430430

431+
@staticmethod
432+
def get_run_usage(run_id: strawberry.ID) -> dict[str, int]:
433+
trace_store = runtime.storage_runtime().tracestore
434+
spans = trace_store.get_llm_spans_by_run_id(run_id)
435+
# Don't close - it's a shared singleton connection
436+
437+
total_tokens = 0
438+
input_tokens = 0
439+
output_tokens = 0
440+
441+
for span in spans:
442+
span_attrs = span.get("SpanAttributes", {})
443+
444+
# Aggregate tokens from LLM spans
445+
if "llm.usage.total_tokens" in span_attrs:
446+
total_tokens += int(span_attrs["llm.usage.total_tokens"])
447+
if "gen_ai.usage.input_tokens" in span_attrs:
448+
input_tokens += int(span_attrs["gen_ai.usage.input_tokens"])
449+
if "gen_ai.usage.output_tokens" in span_attrs:
450+
output_tokens += int(span_attrs["gen_ai.usage.output_tokens"])
451+
452+
return {
453+
"total_tokens": total_tokens,
454+
"input_tokens": input_tokens,
455+
"output_tokens": output_tokens,
456+
}
457+
431458
@staticmethod
432459
def aggregate_experiment_tokens(experiment_id: strawberry.ID) -> dict[str, int]:
433460
"""Aggregate token usage from all spans in an experiment."""
@@ -436,31 +463,24 @@ def aggregate_experiment_tokens(experiment_id: strawberry.ID) -> dict[str, int]:
436463
return {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
437464

438465
try:
439-
trace_store = runtime.storage_runtime().tracestore
440-
# Get all LLM spans for this experiment in a single query
441-
spans = trace_store.get_llm_spans_by_exp_id(experiment_id)
442-
# Don't close - it's a shared singleton connection
443-
444-
total_tokens = 0
445-
input_tokens = 0
446-
output_tokens = 0
447-
448-
for span in spans:
449-
span_attrs = span.get("SpanAttributes", {})
450-
451-
# Aggregate tokens from LLM spans
452-
if "llm.usage.total_tokens" in span_attrs:
453-
total_tokens += int(span_attrs["llm.usage.total_tokens"])
454-
if "gen_ai.usage.input_tokens" in span_attrs:
455-
input_tokens += int(span_attrs["gen_ai.usage.input_tokens"])
456-
if "gen_ai.usage.output_tokens" in span_attrs:
457-
output_tokens += int(span_attrs["gen_ai.usage.output_tokens"])
458-
459-
return {
460-
"total_tokens": total_tokens,
461-
"input_tokens": input_tokens,
462-
"output_tokens": output_tokens,
463-
}
466+
exp = runtime.storage_runtime().metadb.get_experiment(
467+
experiment_id=experiment_id
468+
)
469+
if exp.status in FINISHED_STATUS:
470+
if exp.usage and "total_tokens" in exp.usage:
471+
return {
472+
"total_tokens": exp.usage.get("total_tokens", 0),
473+
"input_tokens": exp.usage.get("input_tokens", 0),
474+
"output_tokens": exp.usage.get("output_tokens", 0),
475+
}
476+
else:
477+
usage = GraphQLResolvers.get_experiment_usage(experiment_id)
478+
runtime.storage_runtime().metadb.update_experiment(
479+
experiment_id=experiment_id, usage=usage
480+
)
481+
return usage
482+
else:
483+
return GraphQLResolvers.get_experiment_usage(experiment_id)
464484
except Exception as e:
465485
import logging
466486

@@ -469,6 +489,34 @@ def aggregate_experiment_tokens(experiment_id: strawberry.ID) -> dict[str, int]:
469489
)
470490
return {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
471491

492+
@staticmethod
493+
def get_experiment_usage(experiment_id: strawberry.ID):
494+
trace_store = runtime.storage_runtime().tracestore
495+
# Get all LLM spans for this experiment in a single query
496+
spans = trace_store.get_llm_spans_by_exp_id(experiment_id)
497+
# Don't close - it's a shared singleton connection
498+
499+
total_tokens = 0
500+
input_tokens = 0
501+
output_tokens = 0
502+
503+
for span in spans:
504+
span_attrs = span.get("SpanAttributes", {})
505+
506+
# Aggregate tokens from LLM spans
507+
if "llm.usage.total_tokens" in span_attrs:
508+
total_tokens += int(span_attrs["llm.usage.total_tokens"])
509+
if "gen_ai.usage.input_tokens" in span_attrs:
510+
input_tokens += int(span_attrs["gen_ai.usage.input_tokens"])
511+
if "gen_ai.usage.output_tokens" in span_attrs:
512+
output_tokens += int(span_attrs["gen_ai.usage.output_tokens"])
513+
514+
return {
515+
"total_tokens": total_tokens,
516+
"input_tokens": input_tokens,
517+
"output_tokens": output_tokens,
518+
}
519+
472520
@staticmethod
473521
def list_spans(run_id: strawberry.ID) -> list[Span]:
474522
"""List all spans for a specific run."""

alphatrion/server/graphql/types.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,10 @@ class Experiment:
125125
params: JSON | None
126126
duration: float
127127
status: GraphQLStatusEnum
128+
cost: JSON | None
128129
created_at: datetime
129130
updated_at: datetime
130131

131-
_token_cache: strawberry.Private[dict[str, int] | None] = None
132-
133132
@strawberry.field
134133
def labels(self) -> list[Label]:
135134
from .resolvers import GraphQLResolvers
@@ -163,10 +162,9 @@ class Run:
163162
meta: JSON | None
164163
duration: float
165164
status: GraphQLStatusEnum
165+
cost: JSON | None
166166
created_at: datetime
167167

168-
_token_cache: strawberry.Private[dict[str, int] | None] = None
169-
170168
@strawberry.field
171169
def metrics(self) -> list["Metric"]:
172170
"""Get metrics for this run."""

alphatrion/storage/sql_models.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,17 @@ class Experiment(Base):
146146
0: UNKNOWN, 1: PENDING, 2: RUNNING, 9: COMPLETED, \
147147
10: CANCELLED, 11: FAILED",
148148
)
149+
usage = Column(
150+
MutableDict.as_mutable(JSON),
151+
nullable=True,
152+
comment="The usage information, e.g. for LLM calls: \
153+
{total_tokens: int, input_tokens: int, output_tokens: int}",
154+
)
155+
cost = Column(
156+
MutableDict.as_mutable(JSON),
157+
nullable=True,
158+
comment="Cost of the experiment in dollars",
159+
)
149160

150161
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
151162
updated_at = Column(
@@ -190,6 +201,17 @@ class Run(Base):
190201
0: UNKNOWN, 1: PENDING, 2: RUNNING, 9: COMPLETED, \
191202
10: CANCELLED, 11: FAILED",
192203
)
204+
usage = Column(
205+
MutableDict.as_mutable(JSON),
206+
nullable=True,
207+
comment="The usage information, e.g. for LLM calls: \
208+
{total_tokens: int, input_tokens: int, output_tokens: int}",
209+
)
210+
cost = Column(
211+
MutableDict.as_mutable(JSON),
212+
nullable=True,
213+
comment="Cost of the run in dollars",
214+
)
193215

194216
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
195217
updated_at = Column(

alphatrion/storage/sqlstore.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -658,7 +658,7 @@ def get_run(self, run_id: uuid.UUID) -> Run | None:
658658

659659
def list_runs_by_exp_id(
660660
self,
661-
exp_id: uuid.UUID,
661+
experiment_id: uuid.UUID,
662662
page: int = 0,
663663
page_size: int = 10,
664664
order_by: str = "created_at",
@@ -667,7 +667,7 @@ def list_runs_by_exp_id(
667667
session = self._session()
668668
runs = (
669669
session.query(Run)
670-
.filter(Run.experiment_id == exp_id, Run.is_del == 0)
670+
.filter(Run.experiment_id == experiment_id, Run.is_del == 0)
671671
.order_by(
672672
getattr(Run, order_by).desc() if order_desc else getattr(Run, order_by)
673673
)
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
"""add usage and cost fields
2+
3+
Revision ID: fd4984c761c2
4+
Revises: 0f417c7cf4d3
5+
Create Date: 2026-03-07 09:24:35.803615
6+
7+
"""
8+
from typing import Sequence, Union
9+
10+
from alembic import op
11+
import sqlalchemy as sa
12+
13+
14+
# revision identifiers, used by Alembic.
15+
revision: str = 'fd4984c761c2'
16+
down_revision: Union[str, Sequence[str], None] = '0f417c7cf4d3'
17+
branch_labels: Union[str, Sequence[str], None] = None
18+
depends_on: Union[str, Sequence[str], None] = None
19+
20+
21+
def upgrade() -> None:
22+
"""Upgrade schema."""
23+
# ### commands auto generated by Alembic - please adjust! ###
24+
op.add_column('experiments', sa.Column('usage', sa.JSON(), nullable=True, comment='The usage information, e.g. for LLM calls: {total_tokens: int, input_tokens: int, output_tokens: int}'))
25+
op.add_column('experiments', sa.Column('cost', sa.JSON(), nullable=True, comment='Cost of the experiment in dollars'))
26+
op.add_column('runs', sa.Column('usage', sa.JSON(), nullable=True, comment='The usage information, e.g. for LLM calls: {total_tokens: int, input_tokens: int, output_tokens: int}'))
27+
op.add_column('runs', sa.Column('cost', sa.JSON(), nullable=True, comment='Cost of the run in dollars'))
28+
# ### end Alembic commands ###
29+
30+
31+
def downgrade() -> None:
32+
"""Downgrade schema."""
33+
# ### commands auto generated by Alembic - please adjust! ###
34+
op.drop_column('runs', 'cost')
35+
op.drop_column('runs', 'usage')
36+
op.drop_column('experiments', 'cost')
37+
op.drop_column('experiments', 'usage')
38+
# ### end Alembic commands ###

tests/integration/server/test_graphql_mutation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,8 +328,8 @@ def test_add_user_to_team_with_invalid_user():
328328
assert "not found" in str(response.errors[0])
329329

330330

331-
def test_complete_workflow():
332-
"""Test complete workflow: create team, create user, add user to teams"""
331+
def test_user_workflow():
332+
"""Test user workflow: create team, create user, add user to teams"""
333333
runtime.init()
334334

335335
username = unique_username("alice")

0 commit comments

Comments (0)