@@ -220,13 +220,23 @@ def get_run(id: strawberry.ID) -> Run | None:
220220 metadb = runtime .storage_runtime ().metadb
221221 run = metadb .get_run (run_id = uuid .UUID (id ))
222222 if run :
223+ meta = run .meta or {}
224+
225+ # Aggregate and cache tokens for completed runs.
226+ # It could be slow for the first time.
227+ if Status (run .status ) == Status .COMPLETED and "total_tokens" not in meta :
228+ token_data = GraphQLResolvers .aggregate_run_tokens (run_id = id )
229+ if token_data ["total_tokens" ] > 0 :
230+ meta .update (token_data )
231+ metadb .update_run (run_id = uuid .UUID (id ), meta = meta )
232+
223233 return Run (
224234 id = run .uuid ,
225235 team_id = run .team_id ,
226236 user_id = run .user_id ,
227237 project_id = run .project_id ,
228238 experiment_id = run .experiment_id ,
229- meta = run . meta ,
239+ meta = meta ,
230240 status = GraphQLStatusEnum [Status (run .status ).name ],
231241 created_at = run .created_at ,
232242 )
@@ -250,6 +260,24 @@ def list_exp_metrics(experiment_id: strawberry.ID) -> list[Metric]:
250260 for m in metrics
251261 ]
252262
263+ @staticmethod
264+ def list_run_metrics (run_id : strawberry .ID ) -> list [Metric ]:
265+ metadb = runtime .storage_runtime ().metadb
266+ metrics = metadb .list_metrics_by_run_id (run_id = run_id )
267+ return [
268+ Metric (
269+ id = m .uuid ,
270+ key = m .key ,
271+ value = m .value ,
272+ team_id = m .team_id ,
273+ project_id = m .project_id ,
274+ experiment_id = m .experiment_id ,
275+ run_id = m .run_id ,
276+ created_at = m .created_at ,
277+ )
278+ for m in metrics
279+ ]
280+
253281 @staticmethod
254282 def total_projects (team_id : strawberry .ID ) -> int :
255283 metadb = runtime .storage_runtime ().metadb
@@ -373,8 +401,48 @@ async def get_artifact_content(
373401 raise RuntimeError (f"Failed to get artifact content: { e } " ) from e
374402
375403 @staticmethod
376- def list_traces (run_id : strawberry .ID ) -> list [Span ]:
377- """List all traces/spans for a specific run."""
404+ def aggregate_run_tokens (run_id : strawberry .ID ) -> dict [str , int ]:
405+ """Aggregate token usage from all traces for a run."""
406+ from alphatrion import envs
407+
408+ # Check if tracing is enabled
409+ if os .getenv (envs .ENABLE_TRACING , "false" ).lower () != "true" :
410+ return {"total_tokens" : 0 , "input_tokens" : 0 , "output_tokens" : 0 }
411+
412+ try :
413+ trace_store = runtime .storage_runtime ().tracestore
414+ spans = trace_store .get_spans_by_run_id (uuid .UUID (run_id ))
415+ trace_store .close ()
416+
417+ total_tokens = 0
418+ input_tokens = 0
419+ output_tokens = 0
420+
421+ for span in spans :
422+ span_attrs = span .get ("SpanAttributes" , {})
423+
424+ # Aggregate tokens from LLM spans
425+ if "llm.usage.total_tokens" in span_attrs :
426+ total_tokens += int (span_attrs ["llm.usage.total_tokens" ])
427+ if "gen_ai.usage.input_tokens" in span_attrs :
428+ input_tokens += int (span_attrs ["gen_ai.usage.input_tokens" ])
429+ if "gen_ai.usage.output_tokens" in span_attrs :
430+ output_tokens += int (span_attrs ["gen_ai.usage.output_tokens" ])
431+
432+ return {
433+ "total_tokens" : total_tokens ,
434+ "input_tokens" : input_tokens ,
435+ "output_tokens" : output_tokens ,
436+ }
437+ except Exception as e :
438+ import logging
439+
440+ logging .error (f"Failed to aggregate tokens for run { run_id } : { e } " )
441+ return {"total_tokens" : 0 , "input_tokens" : 0 , "output_tokens" : 0 }
442+
443+ @staticmethod
444+ def list_spans (run_id : strawberry .ID ) -> list [Span ]:
445+ """List all spans for a specific run."""
378446 from alphatrion import envs
379447
380448 # Check if tracing is enabled
@@ -385,12 +453,12 @@ def list_traces(run_id: strawberry.ID) -> list[Span]:
385453 trace_store = runtime .storage_runtime ().tracestore
386454
387455 # Get traces from ClickHouse
388- traces = trace_store .get_traces_by_run_id (uuid .UUID (run_id ))
456+ raw_spans = trace_store .get_spans_by_run_id (uuid .UUID (run_id ))
389457 trace_store .close ()
390458
391459 # Convert to GraphQL Span objects
392460 spans = []
393- for t in traces :
461+ for t in raw_spans :
394462 # Convert events
395463 events = []
396464 if t .get ("Events" ):
0 commit comments