Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions application/alembic/versions/0015_token_usage_model_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""0015 token_usage model_id — record which model each call ran under.

Adds ``token_usage.model_id`` (canonical id: catalog name for built-ins,
UUID for BYOM) so analytics can group spend by model. The partial index
mirrors ``token_usage_request_id_idx`` — it excludes the NULL rows that
pre-date the column.

Revision ID: 0015_token_usage_model_id
Revises: 0014_device_token_hash_index
"""

from typing import Sequence, Union

from alembic import op


revision: str = "0015_token_usage_model_id"
Comment thread
dartpain marked this conversation as resolved.
Dismissed
down_revision: Union[str, None] = "0014_device_token_hash_index"
Comment thread
dartpain marked this conversation as resolved.
Dismissed
branch_labels: Union[str, Sequence[str], None] = None
Comment thread
dartpain marked this conversation as resolved.
Dismissed
depends_on: Union[str, Sequence[str], None] = None
Comment thread
dartpain marked this conversation as resolved.
Dismissed


def upgrade() -> None:
    """Add the nullable ``token_usage.model_id`` column and its partial index.

    The index covers ``(model_id, "timestamp" DESC)`` and is limited to rows
    where ``model_id IS NOT NULL``.
    """
    add_column_sql = "ALTER TABLE token_usage ADD COLUMN model_id TEXT;"
    create_index_sql = (
        'CREATE INDEX token_usage_model_ts_idx '
        'ON token_usage (model_id, "timestamp" DESC) '
        "WHERE model_id IS NOT NULL;"
    )
    # Order matters: the index references the column created first.
    for statement in (add_column_sql, create_index_sql):
        op.execute(statement)


def downgrade() -> None:
    """Drop the ``model_id`` index, then the column itself.

    Both statements use ``IF EXISTS`` so they do not fail when the object
    is already absent.
    """
    teardown_statements = (
        "DROP INDEX IF EXISTS token_usage_model_ts_idx;",
        "ALTER TABLE token_usage DROP COLUMN IF EXISTS model_id;",
    )
    for statement in teardown_statements:
        op.execute(statement)
1 change: 1 addition & 0 deletions application/api/user/scheduler_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ def execute_scheduled_run_body(run_id: str, celery_task_id: Optional[str]) -> Di
agent_id=str(agent_id_raw) if agent_id_raw else None,
source="schedule",
request_id=str(run_id),
model_id=outcome.get("model_id"),
)
except Exception:
logger.exception(
Expand Down
6 changes: 5 additions & 1 deletion application/llm/llm_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def create_llm(

# Forward model_user_id so backup/fallback resolves under the
# owner's scope on shared-agent dispatch.
return plugin.llm_class(
llm = plugin.llm_class(
api_key,
user_api_key,
decoded_token=decoded_token,
Expand All @@ -124,3 +124,7 @@ def create_llm(
*args,
**kwargs,
)
# llm.model_id is the upstream name (BYOM resolves it above); stamp
# the canonical id (UUID for BYOM) separately for token_usage.
llm._canonical_model_id = model_id
return llm
3 changes: 3 additions & 0 deletions application/storage/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@
# N rows) count as a single request via DISTINCT in the repository
# query. NULL on side-channel sources by design.
Column("request_id", Text),
# Added in ``0015_token_usage_model_id``. Canonical model id (catalog
# name for built-ins, UUID for BYOM); NULL on un-backfilled rows.
Column("model_id", Text),
)

user_logs_table = Table(
Expand Down
6 changes: 4 additions & 2 deletions application/storage/db/repositories/token_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def insert(
generated_tokens: int = 0,
source: str = "agent_stream",
request_id: Optional[str] = None,
model_id: Optional[str] = None,
timestamp: Optional[datetime] = None,
) -> None:
# Attribution guard: the ``token_usage_attribution_chk`` CHECK
Expand All @@ -59,13 +60,13 @@ def insert(
INSERT INTO token_usage (
user_id, api_key, agent_id,
prompt_tokens, generated_tokens,
source, request_id, timestamp
source, request_id, model_id, timestamp
)
VALUES (
:user_id, :api_key,
CAST(:agent_id AS uuid),
:prompt_tokens, :generated_tokens,
:source, :request_id, COALESCE(:timestamp, now())
:source, :request_id, :model_id, COALESCE(:timestamp, now())
)
"""
),
Expand All @@ -77,6 +78,7 @@ def insert(
"generated_tokens": generated_tokens,
"source": source,
"request_id": request_id,
"model_id": model_id,
"timestamp": timestamp,
},
)
Expand Down
1 change: 1 addition & 0 deletions application/usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def _persist_call_usage(llm, call_usage):
getattr(llm, "_token_usage_source", None) or "agent_stream"
),
request_id=getattr(llm, "_request_id", None),
model_id=getattr(llm, "_canonical_model_id", None),
)
except Exception:
logger.exception("token_usage persist failed")
Expand Down
174 changes: 174 additions & 0 deletions scripts/db/backfill_token_usage_model_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
"""Backfill ``token_usage.model_id`` for rows written before the column.

New rows get ``model_id`` stamped at write time (see
``application.llm.llm_creator`` / ``application.usage``). This script
fills the historical NULLs by deriving the model from data we already
trust, in priority order. A row is only ever filled by the
highest-priority tier that matches it; tiers run in one transaction so
each later tier sees only the rows still NULL.

Tiers (both touch only ``source='agent_stream'`` rows)
-----
1. ``request_id`` join (high confidence). The route stamps the same
``request_id`` on the token_usage row and the assistant message, so
``conversation_messages.model_id`` is authoritative for the call.
2. ``agent_id`` + nearest message (medium confidence). For primary rows
with no usable ``request_id`` (legacy), copy ``model_id`` from the
closest-in-time message of any conversation belonging to the same
agent, within ``--window-minutes`` (ties broken toward the later
message so re-runs are reproducible).

Side-channel rows (``fallback`` / ``compression`` / ``title`` /
``rag_condense`` / ``schedule``) are left NULL: they share the primary
turn's ``request_id`` or agent but often ran a *different* model (a
backup, a compression override), so copying the primary turn's model
onto them would mis-attribute spend. New rows already get the correct
per-call model stamped at write time, so this only concerns history.

Rows that match neither tier are left NULL on purpose — the partial
index ``token_usage_model_ts_idx`` excludes them, and a model we can't
tie to the specific call (e.g. the agent's configured default) would
poison the analytics it feeds.

Both ``model_id`` columns store the canonical id (catalog name for
built-ins, UUID for BYOM), so BYOM rows backfill to the UUID unchanged.

Usage::

# Dry-run (default): runs the fills in a rolled-back transaction and
# reports exactly how many rows each tier would touch.
python scripts/db/backfill_token_usage_model_id.py

# Commit the backfill.
python scripts/db/backfill_token_usage_model_id.py --apply

# Widen the tier-2 match window (default 5 minutes).
python scripts/db/backfill_token_usage_model_id.py --window-minutes 10 --apply

Exit codes:
0 — success (dry-run or apply)
1 — negative ``--window-minutes``
2 — argparse usage error (unknown option or non-integer value)
"""

from __future__ import annotations

import argparse
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[2]))

from sqlalchemy import text # noqa: E402

from application.storage.db.engine import get_engine # noqa: E402


# Tier 1: same request -> same model, primary (agent_stream) rows only.
# conversation_messages.model_id is authoritative for that turn; fallback
# / compression rows share the request_id but ran a different model.
# Only rows whose model_id is still NULL and whose request_id is set are
# updated, and only from messages that themselves carry a model_id.
# NOTE(review): with UPDATE ... FROM, if several conversation_messages rows
# with a non-NULL model_id share one request_id, PostgreSQL uses just one of
# the joined rows and which one is not predictable — confirm request_id is
# unique per message, or that such rows always agree on model_id.
_TIER1 = text(
"""
UPDATE token_usage tu
SET model_id = cm.model_id
FROM conversation_messages cm
WHERE cm.request_id = tu.request_id
AND cm.model_id IS NOT NULL
AND tu.model_id IS NULL
AND tu.request_id IS NOT NULL
AND tu.source = 'agent_stream'
"""
)

# Tier 2: nearest message of the same agent within the window, primary
# (agent_stream) rows only. The EXISTS mirror skips rows with no match
# (else the subquery would set NULL); the ORDER BY tiebreak (later message
# wins) keeps the pick reproducible across re-runs.
# ``:win`` is the window size in minutes (bound by the caller); the BETWEEN
# bounds are inclusive, so a message exactly ``:win`` minutes away still
# matches. The EXISTS predicate must stay in sync with the subquery's WHERE
# clause — they are intentionally identical apart from ORDER BY / LIMIT.
_TIER2 = text(
"""
UPDATE token_usage tu
SET model_id = (
SELECT cm.model_id
FROM conversation_messages cm
JOIN conversations c ON c.id = cm.conversation_id
WHERE c.agent_id = tu.agent_id
AND cm.model_id IS NOT NULL
AND cm.timestamp BETWEEN tu.timestamp - make_interval(mins => :win)
AND tu.timestamp + make_interval(mins => :win)
ORDER BY abs(extract(epoch FROM (cm.timestamp - tu.timestamp))), cm.timestamp DESC
LIMIT 1
)
WHERE tu.model_id IS NULL
AND tu.agent_id IS NOT NULL
AND tu.source = 'agent_stream'
AND EXISTS (
SELECT 1
FROM conversation_messages cm
JOIN conversations c ON c.id = cm.conversation_id
WHERE c.agent_id = tu.agent_id
AND cm.model_id IS NOT NULL
AND cm.timestamp BETWEEN tu.timestamp - make_interval(mins => :win)
AND tu.timestamp + make_interval(mins => :win)
)
"""
)

# Counts every token_usage row with a NULL model_id, regardless of source,
# so the figure also includes rows that neither tier is allowed to touch.
_COUNT_NULL = text("SELECT count(*) FROM token_usage WHERE model_id IS NULL")


def main(argv: list[str] | None = None) -> int:
    """Run the ``token_usage.model_id`` backfill; dry-run unless ``--apply``.

    Both tiers execute inside a single transaction. Without ``--apply`` the
    transaction is rolled back after the per-tier row counts are printed, so
    the reported numbers are what a real run would touch.

    Args:
        argv: Command-line arguments without the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, so the script
            entry point behaves as before; passing a list lets the function
            be invoked programmatically.

    Returns:
        ``0`` on success (dry-run or apply), ``1`` when ``--window-minutes``
        is negative.

    Raises:
        SystemExit: raised by argparse (status 2) on unparseable arguments.
        Exception: any database error is re-raised after the transaction has
            been rolled back.
    """
    parser = argparse.ArgumentParser(
        description="Backfill token_usage.model_id from existing data.",
    )
    parser.add_argument(
        "--apply",
        action="store_true",
        help="Commit the backfill. Default is a rolled-back dry-run.",
    )
    parser.add_argument(
        "--window-minutes",
        type=int,
        default=5,
        metavar="N",
        help="Tier-2 nearest-message match window, in minutes (default 5).",
    )
    args = parser.parse_args(argv)

    # Validate before touching the database so bad input never opens a
    # connection.
    if args.window_minutes < 0:
        print("--window-minutes must be >= 0", file=sys.stderr)
        return 1

    engine = get_engine()
    with engine.connect() as conn:
        trans = conn.begin()
        try:
            # A one-shot maintenance UPDATE can run well past the engine's
            # 30s per-statement guardrail; lift it for this transaction.
            conn.execute(text("SET LOCAL statement_timeout = 0"))

            before = conn.execute(_COUNT_NULL).scalar_one()

            # Tier order is significant: tier 2 only sees rows tier 1 left
            # NULL because both run in the same transaction.
            t1 = conn.execute(_TIER1).rowcount or 0
            t2 = conn.execute(_TIER2, {"win": args.window_minutes}).rowcount or 0

            after = conn.execute(_COUNT_NULL).scalar_one()

            print(f"NULL model_id rows before:        {before}")
            print(f"  tier 1 (request_id):            {t1}")
            print(f"  tier 2 (agent + nearest msg):   {t2}")
            print(f"NULL model_id rows remaining:     {after}")

            if args.apply:
                trans.commit()
                print("\nCommitted.")
            else:
                trans.rollback()
                print("\nDry run — rolled back. Re-run with --apply to commit.")
        except Exception:
            trans.rollback()
            raise

    return 0


if __name__ == "__main__":
sys.exit(main())
28 changes: 28 additions & 0 deletions tests/storage/db/repositories/test_token_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from datetime import datetime, timedelta, timezone

import pytest
from sqlalchemy import text

from application.storage.db.repositories.token_usage import TokenUsageRepository

Expand Down Expand Up @@ -35,6 +36,33 @@ def test_insert_with_api_key(self, pg_conn):
assert total == 30


class TestModelId:
    """Checks that ``insert`` stores the ``model_id`` value it is given."""

    def _latest_model_id(self, conn, user_id):
        """Return ``model_id`` of the highest-``id`` row for *user_id*."""
        query = text(
            "SELECT model_id FROM token_usage "
            "WHERE user_id = :u ORDER BY id DESC LIMIT 1"
        )
        result = conn.execute(query, {"u": user_id})
        return result.scalar_one()

    def test_persists_model_id(self, pg_conn):
        """A plain model name is read back unchanged."""
        owner = "u-model"
        _repo(pg_conn).insert(
            user_id=owner,
            prompt_tokens=1,
            generated_tokens=1,
            model_id="gpt-4o",
        )
        stored = self._latest_model_id(pg_conn, owner)
        assert stored == "gpt-4o"

    def test_persists_byom_uuid_model_id(self, pg_conn):
        """A UUID-shaped model id is read back unchanged."""
        owner = "u-byom"
        uuid_id = "11111111-1111-1111-1111-111111111111"
        _repo(pg_conn).insert(
            user_id=owner,
            prompt_tokens=1,
            generated_tokens=1,
            model_id=uuid_id,
        )
        stored = self._latest_model_id(pg_conn, owner)
        assert stored == uuid_id

    def test_model_id_defaults_to_null(self, pg_conn):
        """Omitting ``model_id`` leaves the column NULL."""
        owner = "u-nomodel"
        _repo(pg_conn).insert(
            user_id=owner,
            prompt_tokens=1,
            generated_tokens=1,
        )
        stored = self._latest_model_id(pg_conn, owner)
        assert stored is None


class TestSumTokensInRange:
def test_sums_correctly(self, pg_conn):
repo = _repo(pg_conn)
Expand Down
Loading