Skip to content

Commit 418e2d9

Browse files
authored
feat: add model id on token usage (#2507)
1 parent 8e0f7ce commit 418e2d9

8 files changed

Lines changed: 250 additions & 3 deletions

File tree

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""0015 token_usage model_id — record which model each call ran under.
2+
3+
Adds ``token_usage.model_id`` (canonical id: catalog name for built-ins,
4+
UUID for BYOM) so analytics can group spend by model. The partial index
5+
mirrors ``token_usage_request_id_idx`` — it excludes the NULL rows that
6+
pre-date the column.
7+
8+
Revision ID: 0015_token_usage_model_id
9+
Revises: 0014_device_token_hash_index
10+
"""
11+
12+
from typing import Sequence, Union
13+
14+
from alembic import op
15+
16+
17+
revision: str = "0015_token_usage_model_id"
18+
down_revision: Union[str, None] = "0014_device_token_hash_index"
19+
branch_labels: Union[str, Sequence[str], None] = None
20+
depends_on: Union[str, Sequence[str], None] = None
21+
22+
23+
def upgrade() -> None:
    """Add the nullable ``token_usage.model_id`` column and its partial index.

    Both statements carry ``IF NOT EXISTS`` so the upgrade is as re-runnable
    as ``downgrade()``, which already guards its statements with
    ``IF EXISTS``. The index covers only rows whose ``model_id`` is set.
    """
    op.execute("ALTER TABLE token_usage ADD COLUMN IF NOT EXISTS model_id TEXT;")
    op.execute(
        "CREATE INDEX IF NOT EXISTS token_usage_model_ts_idx "
        'ON token_usage (model_id, "timestamp" DESC) '
        "WHERE model_id IS NOT NULL;"
    )
30+
31+
32+
def downgrade() -> None:
    """Undo the upgrade: drop the partial index, then the ``model_id`` column."""
    # Order matters for readability only: index first, then the column it
    # is built on. Both statements tolerate the object already being gone.
    teardown_statements = (
        "DROP INDEX IF EXISTS token_usage_model_ts_idx;",
        "ALTER TABLE token_usage DROP COLUMN IF EXISTS model_id;",
    )
    for statement in teardown_statements:
        op.execute(statement)

application/api/user/scheduler_worker.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,7 @@ def execute_scheduled_run_body(run_id: str, celery_task_id: Optional[str]) -> Di
349349
agent_id=str(agent_id_raw) if agent_id_raw else None,
350350
source="schedule",
351351
request_id=str(run_id),
352+
model_id=outcome.get("model_id"),
352353
)
353354
except Exception:
354355
logger.exception(

application/llm/llm_creator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def create_llm(
111111

112112
# Forward model_user_id so backup/fallback resolves under the
113113
# owner's scope on shared-agent dispatch.
114-
return plugin.llm_class(
114+
llm = plugin.llm_class(
115115
api_key,
116116
user_api_key,
117117
decoded_token=decoded_token,
@@ -124,3 +124,7 @@ def create_llm(
124124
*args,
125125
**kwargs,
126126
)
127+
# llm.model_id is the upstream name (BYOM resolves it above); stamp
128+
# the canonical id (UUID for BYOM) separately for token_usage.
129+
llm._canonical_model_id = model_id
130+
return llm

application/storage/db/models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@
103103
# N rows) count as a single request via DISTINCT in the repository
104104
# query. NULL on side-channel sources by design.
105105
Column("request_id", Text),
106+
# Added in ``0015_token_usage_model_id``. Canonical model id (catalog
107+
# name for built-ins, UUID for BYOM); NULL on un-backfilled rows.
108+
Column("model_id", Text),
106109
)
107110

108111
user_logs_table = Table(

application/storage/db/repositories/token_usage.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def insert(
3333
generated_tokens: int = 0,
3434
source: str = "agent_stream",
3535
request_id: Optional[str] = None,
36+
model_id: Optional[str] = None,
3637
timestamp: Optional[datetime] = None,
3738
) -> None:
3839
# Attribution guard: the ``token_usage_attribution_chk`` CHECK
@@ -59,13 +60,13 @@ def insert(
5960
INSERT INTO token_usage (
6061
user_id, api_key, agent_id,
6162
prompt_tokens, generated_tokens,
62-
source, request_id, timestamp
63+
source, request_id, model_id, timestamp
6364
)
6465
VALUES (
6566
:user_id, :api_key,
6667
CAST(:agent_id AS uuid),
6768
:prompt_tokens, :generated_tokens,
68-
:source, :request_id, COALESCE(:timestamp, now())
69+
:source, :request_id, :model_id, COALESCE(:timestamp, now())
6970
)
7071
"""
7172
),
@@ -77,6 +78,7 @@ def insert(
7778
"generated_tokens": generated_tokens,
7879
"source": source,
7980
"request_id": request_id,
81+
"model_id": model_id,
8082
"timestamp": timestamp,
8183
},
8284
)

application/usage.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ def _persist_call_usage(llm, call_usage):
134134
getattr(llm, "_token_usage_source", None) or "agent_stream"
135135
),
136136
request_id=getattr(llm, "_request_id", None),
137+
model_id=getattr(llm, "_canonical_model_id", None),
137138
)
138139
except Exception:
139140
logger.exception("token_usage persist failed")
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
"""Backfill ``token_usage.model_id`` for rows written before the column.
2+
3+
New rows get ``model_id`` stamped at write time (see
4+
``application.llm.llm_creator`` / ``application.usage``). This script
5+
fills the historical NULLs by deriving the model from data we already
6+
trust, in priority order. A row is only ever filled by the
7+
highest-priority tier that matches it; tiers run in one transaction so
8+
each later tier sees only the rows still NULL.
9+
10+
Tiers (both touch only ``source='agent_stream'`` rows)
11+
-----
12+
1. ``request_id`` join (high confidence). The route stamps the same
13+
``request_id`` on the token_usage row and the assistant message, so
14+
``conversation_messages.model_id`` is authoritative for the call.
15+
2. ``agent_id`` + nearest message (medium confidence). For primary rows
16+
with no usable ``request_id`` (legacy), copy ``model_id`` from the
17+
closest-in-time message of any conversation belonging to the same
18+
agent, within ``--window-minutes`` (ties broken toward the later
19+
message so re-runs are reproducible).
20+
21+
Side-channel rows (``fallback`` / ``compression`` / ``title`` /
22+
``rag_condense`` / ``schedule``) are left NULL: they share the primary
23+
turn's ``request_id`` or agent but often ran a *different* model (a
24+
backup, a compression override), so copying the primary turn's model
25+
onto them would mis-attribute spend. New rows already get the correct
26+
per-call model stamped at write time, so this only concerns history.
27+
28+
Rows that match neither tier are left NULL on purpose — the partial
29+
index ``token_usage_model_ts_idx`` excludes them, and a model we can't
30+
tie to the specific call (e.g. the agent's configured default) would
31+
poison the analytics it feeds.
32+
33+
Both ``model_id`` columns store the canonical id (catalog name for
34+
built-ins, UUID for BYOM), so BYOM rows backfill to the UUID unchanged.
35+
36+
Usage::
37+
38+
# Dry-run (default): runs the fills in a rolled-back transaction and
39+
# reports exactly how many rows each tier would touch.
40+
python scripts/db/backfill_token_usage_model_id.py
41+
42+
# Commit the backfill.
43+
python scripts/db/backfill_token_usage_model_id.py --apply
44+
45+
# Widen the tier-2 match window (default 5 minutes).
46+
python scripts/db/backfill_token_usage_model_id.py --window-minutes 10 --apply
47+
48+
Exit codes:
49+
0 — success (dry-run or apply)
50+
1 — negative --window-minutes (argparse's own usage errors exit with 2)
51+
"""
52+
53+
from __future__ import annotations
54+
55+
import argparse
56+
import sys
57+
from pathlib import Path
58+
59+
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
60+
61+
from sqlalchemy import text # noqa: E402
62+
63+
from application.storage.db.engine import get_engine # noqa: E402
64+
65+
66+
# Tier 1 -- request_id join, restricted to primary (agent_stream) rows.
# A token_usage row takes the model_id of the conversation message that
# carries the same request_id. Rows from other sources are excluded by the
# ``source`` predicate even when they share that request_id.
_TIER1 = text(
    """
    UPDATE token_usage tu
    SET model_id = cm.model_id
    FROM conversation_messages cm
    WHERE cm.request_id = tu.request_id
      AND cm.model_id IS NOT NULL
      AND tu.model_id IS NULL
      AND tu.request_id IS NOT NULL
      AND tu.source = 'agent_stream'
    """
)

# Tier 2 -- nearest message of the same agent, again agent_stream rows only.
# The scalar subquery picks the message closest in time inside the
# +/- :win minute window; on equal distance the later message wins because
# of the secondary ``cm.timestamp DESC`` sort key. The EXISTS clause repeats
# the same window predicate so rows without any candidate message are not
# touched (the subquery would otherwise assign NULL to them).
_TIER2 = text(
    """
    UPDATE token_usage tu
    SET model_id = (
        SELECT cm.model_id
        FROM conversation_messages cm
        JOIN conversations c ON c.id = cm.conversation_id
        WHERE c.agent_id = tu.agent_id
          AND cm.model_id IS NOT NULL
          AND cm.timestamp BETWEEN tu.timestamp - make_interval(mins => :win)
                               AND tu.timestamp + make_interval(mins => :win)
        ORDER BY abs(extract(epoch FROM (cm.timestamp - tu.timestamp))), cm.timestamp DESC
        LIMIT 1
    )
    WHERE tu.model_id IS NULL
      AND tu.agent_id IS NOT NULL
      AND tu.source = 'agent_stream'
      AND EXISTS (
        SELECT 1
        FROM conversation_messages cm
        JOIN conversations c ON c.id = cm.conversation_id
        WHERE c.agent_id = tu.agent_id
          AND cm.model_id IS NOT NULL
          AND cm.timestamp BETWEEN tu.timestamp - make_interval(mins => :win)
                               AND tu.timestamp + make_interval(mins => :win)
      )
    """
)

# Number of token_usage rows that still have no model_id (any source).
_COUNT_NULL = text("SELECT count(*) FROM token_usage WHERE model_id IS NULL")
116+
117+
118+
def main(argv: list[str] | None = None) -> int:
    """Run the two-tier ``token_usage.model_id`` backfill.

    Both tiers execute inside one transaction. It is committed only when
    ``--apply`` is given; otherwise it is rolled back after the per-tier row
    counts have been printed (dry run).

    Args:
        argv: Command-line arguments without the program name. ``None``
            (the default) lets argparse read ``sys.argv[1:]``, so running
            the file as a script behaves exactly as before; tests and other
            callers can pass an explicit list instead.

    Returns:
        ``0`` on success (dry run or apply), ``1`` when ``--window-minutes``
        is negative. Usage errors detected by argparse itself raise
        ``SystemExit`` before this function returns.
    """
    parser = argparse.ArgumentParser(
        description="Backfill token_usage.model_id from existing data.",
    )
    parser.add_argument(
        "--apply",
        action="store_true",
        help="Commit the backfill. Default is a rolled-back dry-run.",
    )
    parser.add_argument(
        "--window-minutes",
        type=int,
        default=5,
        metavar="N",
        help="Tier-2 nearest-message match window, in minutes (default 5).",
    )
    args = parser.parse_args(argv)

    # Validate before opening a connection so bad input never touches the DB.
    if args.window_minutes < 0:
        print("--window-minutes must be >= 0", file=sys.stderr)
        return 1

    engine = get_engine()
    with engine.connect() as conn:
        trans = conn.begin()
        try:
            # A one-shot maintenance UPDATE can run well past the engine's
            # 30s per-statement guardrail; lift it for this transaction.
            conn.execute(text("SET LOCAL statement_timeout = 0"))

            before = conn.execute(_COUNT_NULL).scalar_one()

            # Tier order matters: tier 2 only considers rows tier 1 left NULL.
            t1 = conn.execute(_TIER1).rowcount or 0
            t2 = conn.execute(_TIER2, {"win": args.window_minutes}).rowcount or 0

            after = conn.execute(_COUNT_NULL).scalar_one()

            print(f"NULL model_id rows before:    {before}")
            print(f"  tier 1 (request_id):          {t1}")
            print(f"  tier 2 (agent + nearest msg): {t2}")
            print(f"NULL model_id rows remaining: {after}")

            if args.apply:
                trans.commit()
                print("\nCommitted.")
            else:
                trans.rollback()
                print("\nDry run — rolled back. Re-run with --apply to commit.")
        except Exception:
            # Never leave a half-applied backfill behind; surface the error.
            trans.rollback()
            raise

    return 0
171+
172+
173+
if __name__ == "__main__":
174+
sys.exit(main())

tests/storage/db/repositories/test_token_usage.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from datetime import datetime, timedelta, timezone
66

77
import pytest
8+
from sqlalchemy import text
89

910
from application.storage.db.repositories.token_usage import TokenUsageRepository
1011

@@ -35,6 +36,33 @@ def test_insert_with_api_key(self, pg_conn):
3536
assert total == 30
3637

3738

39+
class TestModelId:
    """Round-trip checks for the ``model_id`` argument of ``insert``."""

    def _latest_model_id(self, conn, user_id):
        """Return ``model_id`` of the highest-``id`` row for ``user_id``."""
        query = text(
            "SELECT model_id FROM token_usage "
            "WHERE user_id = :u ORDER BY id DESC LIMIT 1"
        )
        result = conn.execute(query, {"u": user_id})
        return result.scalar_one()

    def test_persists_model_id(self, pg_conn):
        # A built-in (catalog-name) id is stored as given.
        _repo(pg_conn).insert(
            user_id="u-model",
            prompt_tokens=1,
            generated_tokens=1,
            model_id="gpt-4o",
        )
        stored = self._latest_model_id(pg_conn, "u-model")
        assert stored == "gpt-4o"

    def test_persists_byom_uuid_model_id(self, pg_conn):
        # A UUID-shaped id is stored as given too.
        uuid_id = "11111111-1111-1111-1111-111111111111"
        _repo(pg_conn).insert(
            user_id="u-byom",
            prompt_tokens=1,
            generated_tokens=1,
            model_id=uuid_id,
        )
        stored = self._latest_model_id(pg_conn, "u-byom")
        assert stored == uuid_id

    def test_model_id_defaults_to_null(self, pg_conn):
        # Omitting model_id leaves the column NULL.
        _repo(pg_conn).insert(
            user_id="u-nomodel",
            prompt_tokens=1,
            generated_tokens=1,
        )
        stored = self._latest_model_id(pg_conn, "u-nomodel")
        assert stored is None
64+
65+
3866
class TestSumTokensInRange:
3967
def test_sums_correctly(self, pg_conn):
4068
repo = _repo(pg_conn)

0 commit comments

Comments
 (0)