Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def _validate_mlflow_experiment() -> bool:
watch_spaces_router,
watch_usage_router,
)
from backend.watch.services.system_tables import warm_cost_overview_cache


class OBOAuthMiddleware(BaseHTTPMiddleware):
Expand Down Expand Up @@ -195,6 +196,10 @@ async def startup():
except Exception:
logger.warning("iterations schema probe failed", exc_info=True)

# Pre-warm the Cost-tab overview cache in the background since it
# takes a long time ro run.
asyncio.create_task(asyncio.to_thread(warm_cost_overview_cache, 7))


@app.on_event("shutdown")
async def shutdown():
Expand Down
11 changes: 6 additions & 5 deletions backend/watch/routers/cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import asyncio
from collections import defaultdict
from typing import Optional

Expand All @@ -23,7 +24,7 @@


@router.get("/overview")
async def workspace_overview(days: int = Query(7, ge=1, le=365)) -> dict:
def workspace_overview(days: int = Query(7, ge=1, le=365)) -> dict:
"""Workspace-wide KPIs + daily query volume for the native cost-tab overview."""
days = validate_days(days, default=7)
summary = system_tables.workspace_summary(days=days)
Expand All @@ -45,7 +46,7 @@ async def workspace_overview(days: int = Query(7, ge=1, le=365)) -> dict:


@router.get("/spaces/{space_id}/cost")
async def get_space_cost(space_id: str, days: int = Query(7, ge=1, le=365)) -> dict:
def get_space_cost(space_id: str, days: int = Query(7, ge=1, le=365)) -> dict:
sid = validate_space_id(space_id)
days = validate_days(days, default=7)

Expand Down Expand Up @@ -99,7 +100,7 @@ async def get_space_cost(space_id: str, days: int = Query(7, ge=1, le=365)) -> d
@router.get("/cost/top")
async def top_spenders(days: int = Query(7, ge=1, le=365), limit: int = Query(10, ge=1, le=200)) -> list[dict]:
days = validate_days(days, default=7)
rows = system_tables.top_spenders(days=days, limit=limit)
rows = await asyncio.to_thread(system_tables.top_spenders, days=days, limit=limit)

# Genie space titles aren't in the system tables; resolve them from the
# space cache (same source SpacesList uses). Best-effort: a missing cache
Expand Down Expand Up @@ -128,7 +129,7 @@ async def top_spenders(days: int = Query(7, ge=1, le=365), limit: int = Query(10


@router.get("/spaces/{space_id}/cost/top-queries")
async def top_expensive_queries(
def top_expensive_queries(
space_id: str,
days: int = Query(7, ge=1, le=365),
limit: int = Query(20, ge=1, le=100),
Expand All @@ -139,7 +140,7 @@ async def top_expensive_queries(


@router.get("/spaces/{space_id}/cost/conversations")
async def cost_per_conversation(
def cost_per_conversation(
space_id: str,
days: int = Query(7, ge=1, le=365),
limit: int = Query(50, ge=1, le=500),
Expand Down
28 changes: 22 additions & 6 deletions backend/watch/services/system_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

# ─── In-process TTL cache ─────────────────────────────────────────────────
_CACHE_TTL_SECONDS = 300
_LONG_CACHE_TTL_SECONDS = 1800
_CACHE_MAX = 256
_CACHE_LOCK = threading.Lock()
_CACHE: dict[str, tuple[float, list[dict[str, Any]]]] = {}
Expand Down Expand Up @@ -70,13 +71,13 @@ def _cache_key(sql: str, parameters: list[StatementParameterListItem]) -> str:
return f"{hash(sql)}|{json.dumps(bag)}"


def _cache_get(key: str) -> list[dict[str, Any]] | None:
def _cache_get(key: str, ttl_seconds: int = _CACHE_TTL_SECONDS) -> list[dict[str, Any]] | None:
with _CACHE_LOCK:
entry = _CACHE.get(key)
if not entry:
return None
ts, rows = entry
if time.monotonic() - ts > _CACHE_TTL_SECONDS:
if time.monotonic() - ts > ttl_seconds:
_CACHE.pop(key, None)
return None
return rows
Expand All @@ -90,6 +91,20 @@ def _cache_put(key: str, rows: list[dict[str, Any]]) -> None:
_CACHE[key] = (time.monotonic(), rows)


def warm_cost_overview_cache(days: int = 7) -> None:
"""Pre-run the Cost-tab overview queries so the cache is hot before users
hit the page. Call from a background task."""
for fn, kwargs in (
(workspace_summary, {"days": days}),
(daily_volume_all_spaces, {"days": days}),
(top_spenders, {"days": days, "limit": 10}),
):
try:
fn(**kwargs)
except Exception:
logger.warning("warmup of %s failed", fn.__name__, exc_info=True)


def _warehouse_id() -> str:
wh = os.environ.get("SQL_WAREHOUSE_ID", "").strip()
if not wh:
Expand All @@ -107,10 +122,11 @@ def _run(
poll_total_seconds: int = 180,
poll_interval_seconds: float = 2.0,
track_health: bool = True,
ttl_seconds: int = _CACHE_TTL_SECONDS,
) -> list[dict[str, Any]]:
global _SYSTEM_TABLES_ACCESSIBLE
key = _cache_key(sql, parameters)
cached = _cache_get(key)
cached = _cache_get(key, ttl_seconds=ttl_seconds)
if cached is not None:
return cached

Expand Down Expand Up @@ -363,7 +379,7 @@ def _workspace_names(workspace_ids: set[str]) -> dict[str, str]:
def top_spenders(days: int = 7, limit: int = 10) -> list[dict[str, Any]]:
params = [_p("days", days, "INT"), _p("limit", limit, "INT")]
sql = _TOP_SPENDERS_SQL.format(ws=_ws_clause(params))
rows = _run(sql, params)
rows = _run(sql, params, ttl_seconds=_LONG_CACHE_TTL_SECONDS)
workspace_ids = {r.get("workspace_id") for r in rows if r.get("workspace_id")}
names = _workspace_names(workspace_ids)
for r in rows:
Expand Down Expand Up @@ -579,7 +595,7 @@ def workspace_summary(days: int = 7) -> dict[str, Any]:
# only a single ws_id param is appended.
params = [_p("days", days, "INT")]
sql = _WORKSPACE_SUMMARY_SQL.format(ws=_ws_clause(params))
rows = _run(sql, params)
rows = _run(sql, params, ttl_seconds=_LONG_CACHE_TTL_SECONDS)
return rows[0] if rows else {}


Expand Down Expand Up @@ -610,7 +626,7 @@ def workspace_summary(days: int = 7) -> dict[str, Any]:
def daily_volume_all_spaces(days: int = 30) -> list[dict[str, Any]]:
params = [_p("days", days, "INT")]
sql = _DAILY_VOLUME_ALL_SQL.format(ws=_ws_clause(params))
return _run(sql, params)
return _run(sql, params, ttl_seconds=_LONG_CACHE_TTL_SECONDS)


_TOP_QUERIES_SQL = """
Expand Down