Skip to content

Commit ecec03c

Browse files
committed
feat: model distribution (InftyAI#189)
* update helm chart Signed-off-by: kerthcet <kerthcet@gmail.com> * add artifacts Signed-off-by: kerthcet <kerthcet@gmail.com> * optimize Signed-off-by: kerthcet <kerthcet@gmail.com> * add artifacts page Signed-off-by: kerthcet <kerthcet@gmail.com> * add model distribution Signed-off-by: kerthcet <kerthcet@gmail.com> * fix error Signed-off-by: kerthcet <kerthcet@gmail.com> * fix lint Signed-off-by: kerthcet <kerthcet@gmail.com> --------- Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent eed14ff commit ecec03c

20 files changed

Lines changed: 1066 additions & 662 deletions

alphatrion/server/graphql/resolvers.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
GraphQLStatusEnum,
2525
Label,
2626
Metric,
27+
ModelDistribution,
2728
RemoveUserFromTeamInput,
2829
Run,
2930
Span,
@@ -263,12 +264,28 @@ def aggregate_team_tokens(team_id: strawberry.ID) -> dict[str, int]:
263264
return {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
264265

265266
trace_store = runtime.storage_runtime().tracestore
266-
result = trace_store.get_llm_spans_by_team_id(team_id=team_id)
267-
# get_llm_spans_by_team_id returns a list with one dict
267+
result = trace_store.get_llm_tokens_by_team_id(team_id=team_id)
268+
# get_llm_tokens_by_team_id returns a list with one dict
268269
if result and len(result) > 0:
269270
return result[0]
270271
return {"total_tokens": 0, "input_tokens": 0, "output_tokens": 0}
271272

273+
@staticmethod
def aggregate_model_distributions(
    team_id: strawberry.ID,
) -> list[ModelDistribution]:
    """Return per-model request counts for a team's LLM spans.

    When tracing is disabled via the ENABLE_TRACING env var, the trace
    store holds no data, so an empty list is returned immediately.
    """
    # Imported lazily, matching how the original block deferred this import.
    from alphatrion import envs

    tracing_enabled = os.getenv(envs.ENABLE_TRACING, "false").lower() == "true"
    if not tracing_enabled:
        return []

    store = runtime.storage_runtime().tracestore
    rows = store.get_model_distributions_by_team_id(team_id=team_id)
    return [
        ModelDistribution(model=row["model"], count=row["count"])
        for row in rows
    ]
288+
272289
@staticmethod
273290
def list_exps_by_timeframe(
274291
team_id: strawberry.ID,

alphatrion/server/graphql/types.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ class TokenStats:
1313
output_tokens: int
1414

1515

16+
@strawberry.type
class ModelDistribution:
    """GraphQL type: how many LLM requests a team made per model.

    Produced by the resolver layer from trace-store aggregation rows
    (one entry per distinct model name).
    """

    # Model identifier as recorded on the span (may be "unknown" when
    # the span carried no model attribute — see the trace-store query).
    model: str
    # Number of LLM spans attributed to this model.
    count: int
20+
21+
1622
@strawberry.type
1723
class Team:
1824
id: strawberry.ID
@@ -45,6 +51,12 @@ def aggregated_tokens(self) -> TokenStats:
4551
output_tokens=token_data["output_tokens"],
4652
)
4753

54+
@strawberry.field
def model_distributions(self) -> list["ModelDistribution"]:
    """Per-model LLM request counts for this team."""
    # Imported inside the field to avoid a circular import between the
    # GraphQL types module and the resolvers module.
    from .resolvers import GraphQLResolvers

    resolve = GraphQLResolvers.aggregate_model_distributions
    return resolve(team_id=self.id)
59+
4860
@strawberry.field
4961
def exps_by_timeframe(
5062
self, start_time: datetime, end_time: datetime

alphatrion/storage/tracestore.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def get_llm_spans_by_exp_id(self, exp_id: uuid.UUID) -> list[dict[str, Any]]:
331331
logger.error(f"Failed to get spans by exp_id: {e}")
332332
return []
333333

334-
def get_llm_spans_by_team_id(self, team_id: uuid.UUID) -> list[dict[str, Any]]:
334+
def get_llm_tokens_by_team_id(self, team_id: uuid.UUID) -> list[dict[str, Any]]:
335335
"""Get all LLM spans for a specific team_id.
336336
337337
Args:
@@ -365,6 +365,46 @@ def get_llm_spans_by_team_id(self, team_id: uuid.UUID) -> list[dict[str, Any]]:
365365
logger.error(f"Failed to get daily token usage: {e}")
366366
return []
367367

368+
def get_model_distributions_by_team_id(
    self, team_id: uuid.UUID
) -> list[dict[str, Any]]:
    """Get model distribution (count of requests per model) for a specific team.

    Counts LLM spans grouped by model name, preferring the request-model
    span attribute, falling back to the response-model attribute, and
    finally to the literal 'unknown'.

    Args:
        team_id: The team ID to filter by. Accepts a ``uuid.UUID`` or any
            value whose string form is a well-formed UUID.

    Returns:
        List of dicts with keys: model, count — ordered by count descending.
        Returns an empty list on validation or query failure.
    """
    with self._lock:
        try:
            # Security: team_id originates from a GraphQL request, so it is
            # untrusted input. Round-tripping through uuid.UUID() raises on
            # anything that is not a well-formed UUID, which blocks SQL
            # injection via the f-string interpolation below. The failure
            # is caught by the handler and reported as an empty result,
            # consistent with this store's other query methods.
            safe_team_id = uuid.UUID(str(team_id))
            query = f"""
                SELECT
                    coalesce(
                        SpanAttributes['gen_ai.request.model'],
                        SpanAttributes['gen_ai.response.model'],
                        'unknown'
                    ) as model,
                    COUNT(*) as count
                FROM {self.database}.otel_spans
                WHERE TeamId = '{safe_team_id}'
                AND SemanticKind = 'llm'
                GROUP BY model
                ORDER BY count DESC
            """

            result = self.client.query(query)
            # Coerce counts explicitly: the driver may hand back numeric
            # columns in a non-int form, and downstream GraphQL typing
            # expects plain ints.
            return [
                {
                    "model": row["model"],
                    "count": int(row["count"]),
                }
                for row in result.named_results()
            ]
        except Exception as e:
            logger.error(f"Failed to get model distributions: {e}")
            return []
407+
368408
def get_daily_token_usage(
369409
self, team_id: uuid.UUID, days: int = 30
370410
) -> list[dict[str, Any]]:

0 commit comments

Comments
 (0)