diff --git a/backend/main.py b/backend/main.py index a86e23b76..8df5250d2 100644 --- a/backend/main.py +++ b/backend/main.py @@ -80,6 +80,7 @@ def _validate_mlflow_experiment() -> bool: watch_spaces_router, watch_usage_router, ) +from backend.watch.services.system_tables import warm_cost_overview_cache class OBOAuthMiddleware(BaseHTTPMiddleware): @@ -195,6 +196,10 @@ async def startup(): except Exception: logger.warning("iterations schema probe failed", exc_info=True) + # Pre-warm the Cost-tab overview cache in the background since it + # takes a long time ro run. + asyncio.create_task(asyncio.to_thread(warm_cost_overview_cache, 7)) + @app.on_event("shutdown") async def shutdown(): diff --git a/backend/watch/routers/cost.py b/backend/watch/routers/cost.py index 70ba72585..4f509d823 100644 --- a/backend/watch/routers/cost.py +++ b/backend/watch/routers/cost.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio from collections import defaultdict from typing import Optional @@ -23,7 +24,7 @@ @router.get("/overview") -async def workspace_overview(days: int = Query(7, ge=1, le=365)) -> dict: +def workspace_overview(days: int = Query(7, ge=1, le=365)) -> dict: """Workspace-wide KPIs + daily query volume for the native cost-tab overview.""" days = validate_days(days, default=7) summary = system_tables.workspace_summary(days=days) @@ -45,7 +46,7 @@ async def workspace_overview(days: int = Query(7, ge=1, le=365)) -> dict: @router.get("/spaces/{space_id}/cost") -async def get_space_cost(space_id: str, days: int = Query(7, ge=1, le=365)) -> dict: +def get_space_cost(space_id: str, days: int = Query(7, ge=1, le=365)) -> dict: sid = validate_space_id(space_id) days = validate_days(days, default=7) @@ -99,7 +100,7 @@ async def get_space_cost(space_id: str, days: int = Query(7, ge=1, le=365)) -> d @router.get("/cost/top") async def top_spenders(days: int = Query(7, ge=1, le=365), limit: int = Query(10, ge=1, le=200)) -> list[dict]: days = validate_days(days, default=7) - rows = system_tables.top_spenders(days=days, limit=limit) + rows = await asyncio.to_thread(system_tables.top_spenders, days=days, limit=limit) # Genie space titles aren't in the system tables; resolve them from the # space cache (same source SpacesList uses). Best-effort: a missing cache @@ -128,7 +129,7 @@ async def top_spenders(days: int = Query(7, ge=1, le=365), limit: int = Query(10 @router.get("/spaces/{space_id}/cost/top-queries") -async def top_expensive_queries( +def top_expensive_queries( space_id: str, days: int = Query(7, ge=1, le=365), limit: int = Query(20, ge=1, le=100), @@ -139,7 +140,7 @@ async def top_expensive_queries( @router.get("/spaces/{space_id}/cost/conversations") -async def cost_per_conversation( +def cost_per_conversation( space_id: str, days: int = Query(7, ge=1, le=365), limit: int = Query(50, ge=1, le=500), diff --git a/backend/watch/services/system_tables.py b/backend/watch/services/system_tables.py index d37ffa581..4582340bc 100644 --- a/backend/watch/services/system_tables.py +++ b/backend/watch/services/system_tables.py @@ -32,6 +32,7 @@ # ─── In-process TTL cache ───────────────────────────────────────────────── _CACHE_TTL_SECONDS = 300 +_LONG_CACHE_TTL_SECONDS = 1800 _CACHE_MAX = 256 _CACHE_LOCK = threading.Lock() _CACHE: dict[str, tuple[float, list[dict[str, Any]]]] = {} @@ -70,13 +71,13 @@ def _cache_key(sql: str, parameters: list[StatementParameterListItem]) -> str: return f"{hash(sql)}|{json.dumps(bag)}" -def _cache_get(key: str) -> list[dict[str, Any]] | None: +def _cache_get(key: str, ttl_seconds: int = _CACHE_TTL_SECONDS) -> list[dict[str, Any]] | None: with _CACHE_LOCK: entry = _CACHE.get(key) if not entry: return None ts, rows = entry - if time.monotonic() - ts > _CACHE_TTL_SECONDS: + if time.monotonic() - ts > ttl_seconds: _CACHE.pop(key, None) return None return rows @@ -90,6 +91,20 @@ def _cache_put(key: str, rows: list[dict[str, Any]]) -> None: _CACHE[key] = (time.monotonic(), rows) +def warm_cost_overview_cache(days: int = 7) -> None: + """Pre-run the Cost-tab overview queries so the cache is hot before users + hit the page. Call from a background task.""" + for fn, kwargs in ( + (workspace_summary, {"days": days}), + (daily_volume_all_spaces, {"days": days}), + (top_spenders, {"days": days, "limit": 10}), + ): + try: + fn(**kwargs) + except Exception: + logger.warning("warmup of %s failed", fn.__name__, exc_info=True) + + def _warehouse_id() -> str: wh = os.environ.get("SQL_WAREHOUSE_ID", "").strip() if not wh: @@ -107,10 +122,11 @@ def _run( poll_total_seconds: int = 180, poll_interval_seconds: float = 2.0, track_health: bool = True, + ttl_seconds: int = _CACHE_TTL_SECONDS, ) -> list[dict[str, Any]]: global _SYSTEM_TABLES_ACCESSIBLE key = _cache_key(sql, parameters) - cached = _cache_get(key) + cached = _cache_get(key, ttl_seconds=ttl_seconds) if cached is not None: return cached @@ -363,7 +379,7 @@ def _workspace_names(workspace_ids: set[str]) -> dict[str, str]: def top_spenders(days: int = 7, limit: int = 10) -> list[dict[str, Any]]: params = [_p("days", days, "INT"), _p("limit", limit, "INT")] sql = _TOP_SPENDERS_SQL.format(ws=_ws_clause(params)) - rows = _run(sql, params) + rows = _run(sql, params, ttl_seconds=_LONG_CACHE_TTL_SECONDS) workspace_ids = {r.get("workspace_id") for r in rows if r.get("workspace_id")} names = _workspace_names(workspace_ids) for r in rows: @@ -579,7 +595,7 @@ def workspace_summary(days: int = 7) -> dict[str, Any]: # only a single ws_id param is appended. params = [_p("days", days, "INT")] sql = _WORKSPACE_SUMMARY_SQL.format(ws=_ws_clause(params)) - rows = _run(sql, params) + rows = _run(sql, params, ttl_seconds=_LONG_CACHE_TTL_SECONDS) return rows[0] if rows else {} @@ -610,7 +626,7 @@ def workspace_summary(days: int = 7) -> dict[str, Any]: def daily_volume_all_spaces(days: int = 30) -> list[dict[str, Any]]: params = [_p("days", days, "INT")] sql = _DAILY_VOLUME_ALL_SQL.format(ws=_ws_clause(params)) - return _run(sql, params) + return _run(sql, params, ttl_seconds=_LONG_CACHE_TTL_SECONDS) _TOP_QUERIES_SQL = """