88from alphatrion import envs
99from alphatrion .artifact import artifact
1010from alphatrion .storage import runtime
11- from alphatrion .storage .sql_models import Status
11+ from alphatrion .storage .sql_models import (
12+ FINISHED_STATUS ,
13+ Status ,
14+ )
1215
1316from .types import (
1417 AddUserToTeamInput ,
@@ -138,6 +141,7 @@ def list_experiments(
138141 duration = e .duration ,
139142 status = GraphQLStatusEnum [Status (e .status ).name ],
140143 kind = GraphQLExperimentTypeEnum [GraphQLExperimentType (e .kind ).name ],
144+ cost = e .cost ,
141145 created_at = e .created_at ,
142146 updated_at = e .updated_at ,
143147 )
@@ -160,6 +164,7 @@ def get_experiment(id: strawberry.ID) -> Experiment | None:
160164 duration = exp .duration ,
161165 status = GraphQLStatusEnum [Status (exp .status ).name ],
162166 kind = GraphQLExperimentTypeEnum [GraphQLExperimentType (exp .kind ).name ],
167+ cost = exp .cost ,
163168 created_at = exp .created_at ,
164169 updated_at = exp .updated_at ,
165170 )
@@ -175,7 +180,7 @@ def list_runs(
175180 ) -> list [Run ]:
176181 metadb = runtime .storage_runtime ().metadb
177182 runs = metadb .list_runs_by_exp_id (
178- exp_id = uuid .UUID (experiment_id ),
183+ experiment_id = uuid .UUID (experiment_id ),
179184 page = page ,
180185 page_size = page_size ,
181186 order_by = order_by ,
@@ -190,6 +195,7 @@ def list_runs(
190195 meta = r .meta ,
191196 status = GraphQLStatusEnum [Status (r .status ).name ],
192197 duration = r .duration ,
198+ cost = r .cost ,
193199 created_at = r .created_at ,
194200 )
195201 for r in runs
@@ -208,6 +214,7 @@ def get_run(id: strawberry.ID) -> Run | None:
208214 meta = run .meta ,
209215 status = GraphQLStatusEnum [Status (run .status ).name ],
210216 duration = run .duration ,
217+ cost = run .cost ,
211218 created_at = run .created_at ,
212219 )
213220 return None
@@ -311,6 +318,7 @@ def list_exps_by_timeframe(
311318 duration = e .duration ,
312319 status = GraphQLStatusEnum [Status (e .status ).name ],
313320 kind = GraphQLExperimentTypeEnum [GraphQLExperimentType (e .kind ).name ],
321+ cost = e .cost ,
314322 created_at = e .created_at ,
315323 updated_at = e .updated_at ,
316324 )
@@ -396,30 +404,22 @@ def aggregate_run_tokens(run_id: strawberry.ID) -> dict[str, int]:
396404 return {"total_tokens" : 0 , "input_tokens" : 0 , "output_tokens" : 0 }
397405
398406 try :
399- trace_store = runtime .storage_runtime ().tracestore
400- spans = trace_store .get_llm_spans_by_run_id (run_id )
401- # Don't close - it's a shared singleton connection
402-
403- total_tokens = 0
404- input_tokens = 0
405- output_tokens = 0
406-
407- for span in spans :
408- span_attrs = span .get ("SpanAttributes" , {})
409-
410- # Aggregate tokens from LLM spans
411- if "llm.usage.total_tokens" in span_attrs :
412- total_tokens += int (span_attrs ["llm.usage.total_tokens" ])
413- if "gen_ai.usage.input_tokens" in span_attrs :
414- input_tokens += int (span_attrs ["gen_ai.usage.input_tokens" ])
415- if "gen_ai.usage.output_tokens" in span_attrs :
416- output_tokens += int (span_attrs ["gen_ai.usage.output_tokens" ])
417-
418- return {
419- "total_tokens" : total_tokens ,
420- "input_tokens" : input_tokens ,
421- "output_tokens" : output_tokens ,
422- }
407+ run = runtime .storage_runtime ().metadb .get_run (run_id = run_id )
408+ if run .status in FINISHED_STATUS :
409+ if run .usage and "total_tokens" in run .usage :
410+ return {
411+ "total_tokens" : run .usage .get ("total_tokens" , 0 ),
412+ "input_tokens" : run .usage .get ("input_tokens" , 0 ),
413+ "output_tokens" : run .usage .get ("output_tokens" , 0 ),
414+ }
415+ else :
416+ usage = GraphQLResolvers .get_run_usage (run_id )
417+ runtime .storage_runtime ().metadb .update_run (
418+ run_id = run_id , usage = usage
419+ )
420+ return usage
421+ else :
422+ return GraphQLResolvers .get_run_usage (run_id )
423423 except Exception as e :
424424 import logging
425425
@@ -428,6 +428,33 @@ def aggregate_run_tokens(run_id: strawberry.ID) -> dict[str, int]:
428428 )
429429 return {"total_tokens" : 0 , "input_tokens" : 0 , "output_tokens" : 0 }
430430
431+ @staticmethod
432+ def get_run_usage (run_id : strawberry .ID ) -> dict [str , int ]:
433+ trace_store = runtime .storage_runtime ().tracestore
434+ spans = trace_store .get_llm_spans_by_run_id (run_id )
435+ # Don't close - it's a shared singleton connection
436+
437+ total_tokens = 0
438+ input_tokens = 0
439+ output_tokens = 0
440+
441+ for span in spans :
442+ span_attrs = span .get ("SpanAttributes" , {})
443+
444+ # Aggregate tokens from LLM spans
445+ if "llm.usage.total_tokens" in span_attrs :
446+ total_tokens += int (span_attrs ["llm.usage.total_tokens" ])
447+ if "gen_ai.usage.input_tokens" in span_attrs :
448+ input_tokens += int (span_attrs ["gen_ai.usage.input_tokens" ])
449+ if "gen_ai.usage.output_tokens" in span_attrs :
450+ output_tokens += int (span_attrs ["gen_ai.usage.output_tokens" ])
451+
452+ return {
453+ "total_tokens" : total_tokens ,
454+ "input_tokens" : input_tokens ,
455+ "output_tokens" : output_tokens ,
456+ }
457+
431458 @staticmethod
432459 def aggregate_experiment_tokens (experiment_id : strawberry .ID ) -> dict [str , int ]:
433460 """Aggregate token usage from all spans in an experiment."""
@@ -436,31 +463,24 @@ def aggregate_experiment_tokens(experiment_id: strawberry.ID) -> dict[str, int]:
436463 return {"total_tokens" : 0 , "input_tokens" : 0 , "output_tokens" : 0 }
437464
438465 try :
439- trace_store = runtime .storage_runtime ().tracestore
440- # Get all LLM spans for this experiment in a single query
441- spans = trace_store .get_llm_spans_by_exp_id (experiment_id )
442- # Don't close - it's a shared singleton connection
443-
444- total_tokens = 0
445- input_tokens = 0
446- output_tokens = 0
447-
448- for span in spans :
449- span_attrs = span .get ("SpanAttributes" , {})
450-
451- # Aggregate tokens from LLM spans
452- if "llm.usage.total_tokens" in span_attrs :
453- total_tokens += int (span_attrs ["llm.usage.total_tokens" ])
454- if "gen_ai.usage.input_tokens" in span_attrs :
455- input_tokens += int (span_attrs ["gen_ai.usage.input_tokens" ])
456- if "gen_ai.usage.output_tokens" in span_attrs :
457- output_tokens += int (span_attrs ["gen_ai.usage.output_tokens" ])
458-
459- return {
460- "total_tokens" : total_tokens ,
461- "input_tokens" : input_tokens ,
462- "output_tokens" : output_tokens ,
463- }
466+ exp = runtime .storage_runtime ().metadb .get_experiment (
467+ experiment_id = experiment_id
468+ )
469+ if exp .status in FINISHED_STATUS :
470+ if exp .usage and "total_tokens" in exp .usage :
471+ return {
472+ "total_tokens" : exp .usage .get ("total_tokens" , 0 ),
473+ "input_tokens" : exp .usage .get ("input_tokens" , 0 ),
474+ "output_tokens" : exp .usage .get ("output_tokens" , 0 ),
475+ }
476+ else :
477+ usage = GraphQLResolvers .get_experiment_usage (experiment_id )
478+ runtime .storage_runtime ().metadb .update_experiment (
479+ experiment_id = experiment_id , usage = usage
480+ )
481+ return usage
482+ else :
483+ return GraphQLResolvers .get_experiment_usage (experiment_id )
464484 except Exception as e :
465485 import logging
466486
@@ -469,6 +489,34 @@ def aggregate_experiment_tokens(experiment_id: strawberry.ID) -> dict[str, int]:
469489 )
470490 return {"total_tokens" : 0 , "input_tokens" : 0 , "output_tokens" : 0 }
471491
492+ @staticmethod
493+ def get_experiment_usage (experiment_id : strawberry .ID ):
494+ trace_store = runtime .storage_runtime ().tracestore
495+ # Get all LLM spans for this experiment in a single query
496+ spans = trace_store .get_llm_spans_by_exp_id (experiment_id )
497+ # Don't close - it's a shared singleton connection
498+
499+ total_tokens = 0
500+ input_tokens = 0
501+ output_tokens = 0
502+
503+ for span in spans :
504+ span_attrs = span .get ("SpanAttributes" , {})
505+
506+ # Aggregate tokens from LLM spans
507+ if "llm.usage.total_tokens" in span_attrs :
508+ total_tokens += int (span_attrs ["llm.usage.total_tokens" ])
509+ if "gen_ai.usage.input_tokens" in span_attrs :
510+ input_tokens += int (span_attrs ["gen_ai.usage.input_tokens" ])
511+ if "gen_ai.usage.output_tokens" in span_attrs :
512+ output_tokens += int (span_attrs ["gen_ai.usage.output_tokens" ])
513+
514+ return {
515+ "total_tokens" : total_tokens ,
516+ "input_tokens" : input_tokens ,
517+ "output_tokens" : output_tokens ,
518+ }
519+
472520 @staticmethod
473521 def list_spans (run_id : strawberry .ID ) -> list [Span ]:
474522 """List all spans for a specific run."""
0 commit comments