diff --git a/api/main.py b/api/main.py index 39006f9..635c806 100644 --- a/api/main.py +++ b/api/main.py @@ -24,6 +24,7 @@ from api.entity_routes import router as entity_router from api.exceptions import AppError from api.middleware import ContentTypeValidationMiddleware, LoggingMiddleware, RequestIDMiddleware +from api.policy_routes import router as policy_router from api.replay_routes import router as replay_router from api.search_routes import router as search_router from api.session_routes import router as session_router @@ -133,6 +134,7 @@ async def global_exception_handler(request: Request, exc: Exception) -> JSONResp app.include_router(cost_router) app.include_router(search_router) app.include_router(entity_router) + app.include_router(policy_router) app.include_router(system_router) app.include_router(ui_router) diff --git a/api/policy_routes.py b/api/policy_routes.py new file mode 100644 index 0000000..c3a5e5e --- /dev/null +++ b/api/policy_routes.py @@ -0,0 +1,205 @@ +"""Alert policy API routes for configurable alert thresholds.""" + +from __future__ import annotations + +from typing import Any + +from fastapi import APIRouter, Depends, Query +from sqlalchemy.ext.asyncio import AsyncSession + +from api.dependencies import get_db_session, get_tenant_id +from api.exceptions import NotFoundError +from api.schemas import AlertPolicyCreate, AlertPolicyListResponse, AlertPolicySchema, AlertPolicyUpdate +from storage import AlertPolicyRepository + +router = APIRouter(tags=["alert-policies"]) + + +async def get_policy_repository( + session: AsyncSession = Depends(get_db_session), + tenant_id: str = Depends(get_tenant_id), +) -> AlertPolicyRepository: + """Get an alert policy repository scoped to the current tenant.""" + return AlertPolicyRepository(session, tenant_id=tenant_id) + + +@router.get("/api/alert-policies", response_model=AlertPolicyListResponse) +async def list_policies( + agent_name: str | None = Query(default=None), + limit: int = Query(default=100, 
ge=1, le=1000), + repo: AlertPolicyRepository = Depends(get_policy_repository), +) -> AlertPolicyListResponse: + """List all alert policies, optionally filtered by agent_name. + + Args: + agent_name: Optional agent name filter. If provided, returns both + agent-specific and global policies for this agent. + limit: Maximum number of policies to return + repo: AlertPolicyRepository instance + + Returns: + List of alert policies + """ + policies = await repo.list_policies(agent_name=agent_name, limit=limit) + + return AlertPolicyListResponse( + policies=[ + AlertPolicySchema( + id=policy.id, + agent_name=policy.agent_name, + alert_type=policy.alert_type, + threshold_value=policy.threshold_value, + severity_threshold=policy.severity_threshold, + enabled=policy.enabled, + created_at=policy.created_at, + updated_at=policy.updated_at, + ) + for policy in policies + ], + total=len(policies), + ) + + +@router.post("/api/alert-policies", response_model=AlertPolicySchema) +async def create_policy( + data: AlertPolicyCreate, + repo: AlertPolicyRepository = Depends(get_policy_repository), +) -> AlertPolicySchema: + """Create a new alert policy. 
+ + Args: + data: Policy creation data + repo: AlertPolicyRepository instance + + Returns: + Created alert policy + """ + policy = await repo.create_policy( + agent_name=data.agent_name, + alert_type=data.alert_type, + threshold_value=data.threshold_value, + severity_threshold=data.severity_threshold, + enabled=data.enabled, + ) + # Commit to persist the policy + await repo.session.commit() + await repo.session.refresh(policy) + + return AlertPolicySchema( + id=policy.id, + agent_name=policy.agent_name, + alert_type=policy.alert_type, + threshold_value=policy.threshold_value, + severity_threshold=policy.severity_threshold, + enabled=policy.enabled, + created_at=policy.created_at, + updated_at=policy.updated_at, + ) + + +@router.get("/api/alert-policies/{policy_id}", response_model=AlertPolicySchema) +async def get_policy( + policy_id: str, + repo: AlertPolicyRepository = Depends(get_policy_repository), +) -> AlertPolicySchema: + """Get a single alert policy by ID. + + Args: + policy_id: Unique identifier of the policy + repo: AlertPolicyRepository instance + + Returns: + Alert policy details + + Raises: + NotFoundError: if policy not found + """ + policy = await repo.get_policy(policy_id) + if not policy: + raise NotFoundError(f"Policy {policy_id} not found") + + return AlertPolicySchema( + id=policy.id, + agent_name=policy.agent_name, + alert_type=policy.alert_type, + threshold_value=policy.threshold_value, + severity_threshold=policy.severity_threshold, + enabled=policy.enabled, + created_at=policy.created_at, + updated_at=policy.updated_at, + ) + + +@router.put("/api/alert-policies/{policy_id}", response_model=AlertPolicySchema) +async def update_policy( + policy_id: str, + data: AlertPolicyUpdate, + repo: AlertPolicyRepository = Depends(get_policy_repository), +) -> AlertPolicySchema: + """Update an existing alert policy. 
+ + Args: + policy_id: Unique identifier of the policy to update + data: Policy update data + repo: AlertPolicyRepository instance + + Returns: + Updated alert policy + + Raises: + NotFoundError: if policy not found + """ + policy = await repo.update_policy( + policy_id=policy_id, + agent_name=data.agent_name, + alert_type=data.alert_type, + threshold_value=data.threshold_value, + severity_threshold=data.severity_threshold, + enabled=data.enabled, + ) + + if not policy: + raise NotFoundError(f"Policy {policy_id} not found") + + # Commit to persist changes + await repo.session.commit() + await repo.session.refresh(policy) + + return AlertPolicySchema( + id=policy.id, + agent_name=policy.agent_name, + alert_type=policy.alert_type, + threshold_value=policy.threshold_value, + severity_threshold=policy.severity_threshold, + enabled=policy.enabled, + created_at=policy.created_at, + updated_at=policy.updated_at, + ) + + +@router.delete("/api/alert-policies/{policy_id}") +async def delete_policy( + policy_id: str, + repo: AlertPolicyRepository = Depends(get_policy_repository), +) -> dict[str, Any]: + """Delete an alert policy by ID. 
+ + Args: + policy_id: Unique identifier of the policy to delete + repo: AlertPolicyRepository instance + + Returns: + Deletion confirmation + + Raises: + NotFoundError: if policy not found + """ + deleted = await repo.delete_policy(policy_id) + + if not deleted: + raise NotFoundError(f"Policy {policy_id} not found") + + # Commit to persist deletion + await repo.session.commit() + + return {"deleted": True, "policy_id": policy_id} diff --git a/api/schemas.py b/api/schemas.py index 910d3f2..f503b97 100644 --- a/api/schemas.py +++ b/api/schemas.py @@ -300,6 +300,11 @@ class AnomalyAlertSchema(BaseModel): detection_source: str detection_config: dict[str, Any] created_at: datetime + status: str | None = None + acknowledged_at: datetime | None = None + resolved_at: datetime | None = None + dismissed_at: datetime | None = None + resolution_note: str | None = None class AnomalyAlertListResponse(BaseModel): @@ -310,6 +315,77 @@ class AnomalyAlertListResponse(BaseModel): total: int +# ------------------------------------------------------------------ +# Alert Lifecycle Schemas +# ------------------------------------------------------------------ + + +class AlertStatusUpdate(BaseModel): + """Request schema for updating a single alert's status.""" + + status: str = Field(min_length=1, max_length=32) + note: str | None = Field(default=None, max_length=2000) + + +class AlertBulkUpdate(BaseModel): + """Request schema for bulk updating alert statuses.""" + + alert_ids: list[str] = Field(min_length=1) + status: str = Field(min_length=1, max_length=32) + + +class AlertFilters(BaseModel): + """Query parameters for filtering alerts.""" + + agent_name: str | None = None + severity: float | None = Field(default=None, ge=0.0, le=1.0) + alert_type: str | None = None + status: str | None = None + from_date: datetime | None = None + to_date: datetime | None = None + limit: int = Field(default=50, ge=1, le=500) + + +class AlertSeverityCount(BaseModel): + """Count of alerts by severity 
level.""" + + critical: int + high: int + medium: int + low: int + + +class AlertSummarySchema(BaseModel): + """Alert summary statistics.""" + + by_status: dict[str, int] + by_type: dict[str, int] + by_severity: AlertSeverityCount + total: int + + +class AlertTrendingPointSchema(BaseModel): + """Single data point for alert trending.""" + + date: str + count: int + + +class AlertTrendingSchema(BaseModel): + """Alert volume over time.""" + + trending: list[AlertTrendingPointSchema] + days: int + + +class AlertListFilteredResponse(BaseModel): + """Response schema for filtered alert listing.""" + + alerts: list[AnomalyAlertSchema] + total: int + filters: AlertFilters + + class FixNoteRequest(BaseModel): """Request schema for adding/updating a fix note.""" @@ -405,3 +481,48 @@ class SimilarFailuresResponse(BaseModel): failure_event_id: str similar_failures: list[SimilarFailureSchema] total: int + + +# ------------------------------------------------------------------ +# Alert policy schemas +# ------------------------------------------------------------------ + + +class AlertPolicyCreate(BaseModel): + """Request schema for creating an alert policy.""" + + agent_name: str | None = Field(default=None, max_length=255) + alert_type: str = Field(min_length=1, max_length=64) + threshold_value: float = Field(ge=0.0) + severity_threshold: str | None = Field(default=None, max_length=16) + enabled: bool = Field(default=True) + + +class AlertPolicyUpdate(BaseModel): + """Request schema for updating an alert policy.""" + + agent_name: str | None = Field(default=None, max_length=255) + alert_type: str | None = Field(default=None, min_length=1, max_length=64) + threshold_value: float | None = Field(default=None, ge=0.0) + severity_threshold: str | None = Field(default=None, max_length=16) + enabled: bool | None = None + + +class AlertPolicySchema(BaseModel): + """Response schema for alert policies.""" + + id: str + agent_name: str | None + alert_type: str + threshold_value: float + 
severity_threshold: str | None + enabled: bool + created_at: datetime + updated_at: datetime + + +class AlertPolicyListResponse(BaseModel): + """Response schema for listing alert policies.""" + + policies: list[AlertPolicySchema] + total: int diff --git a/api/trace_routes.py b/api/trace_routes.py index 4305a2c..7d66782 100644 --- a/api/trace_routes.py +++ b/api/trace_routes.py @@ -19,6 +19,14 @@ from api.exceptions import NotFoundError from api.schemas import ( AgentBaselineSchema, + AlertBulkUpdate, + AlertFilters, + AlertListFilteredResponse, + AlertSeverityCount, + AlertStatusUpdate, + AlertSummarySchema, + AlertTrendingPointSchema, + AlertTrendingSchema, AnalysisResponse, AnomalyAlertListResponse, AnomalyAlertSchema, @@ -315,6 +323,11 @@ async def get_session_alerts( detection_source=alert.detection_source, detection_config=alert.detection_config or {}, created_at=alert.created_at, + status=alert.status, + acknowledged_at=alert.acknowledged_at, + resolved_at=alert.resolved_at, + dismissed_at=alert.dismissed_at, + resolution_note=alert.resolution_note, ) for alert in alerts ], @@ -322,6 +335,157 @@ async def get_session_alerts( ) +# ------------------------------------------------------------------ +# Alert Lifecycle Management Endpoints (must come before /api/alerts/{alert_id}) +# ------------------------------------------------------------------ + + +@router.get("/api/alerts/summary", response_model=AlertSummarySchema) +async def get_alert_summary( + repo: TraceRepository = Depends(get_repository), +) -> AlertSummarySchema: + """Get alert summary statistics grouped by severity, type, and status. 
+ + Args: + repo: TraceRepository instance + + Returns: + AlertSummarySchema with counts by severity, type, and status + """ + summary = await repo._alert_repo.get_alert_lifecycle_summary() + + return AlertSummarySchema( + by_status=summary["by_status"], + by_type=summary["by_type"], + by_severity=AlertSeverityCount(**summary["by_severity"]), + total=summary["total"], + ) + + +@router.get("/api/alerts/trending", response_model=AlertTrendingSchema) +async def get_alert_trending( + days: int = Query(default=7, ge=1, le=90), + repo: TraceRepository = Depends(get_repository), +) -> AlertTrendingSchema: + """Get alert volume trend grouped by day. + + Args: + days: Number of days to look back (default 7, max 90) + repo: TraceRepository instance + + Returns: + AlertTrendingSchema with list of daily counts + """ + trending = await repo._alert_repo.get_alert_trending(days=days) + + return AlertTrendingSchema( + trending=[AlertTrendingPointSchema(**point) for point in trending], + days=days, + ) + + +@router.post("/api/alerts/bulk-status") +async def bulk_update_alert_status( + update: AlertBulkUpdate, + repo: TraceRepository = Depends(get_repository), +) -> dict[str, Any]: + """Bulk update status for multiple alerts. 
+ + Args: + update: Bulk update request with alert_ids and status + repo: TraceRepository instance + + Returns: + Dictionary with updated count + + Raises: + ValueError: if status is invalid + """ + updated_count = await repo._alert_repo.bulk_update_status(update.alert_ids, update.status) + + try: + await repo.commit() + except Exception: + await repo.rollback() + raise + + return { + "updated": updated_count, + "status": update.status, + } + + +@router.get("/api/alerts", response_model=AlertListFilteredResponse) +async def list_alerts_filtered( + agent_name: str | None = Query(default=None), + severity: float | None = Query(default=None, ge=0.0, le=1.0), + alert_type: str | None = Query(default=None), + status: str | None = Query(default=None), + from_date: datetime | None = Query(default=None), + to_date: datetime | None = Query(default=None), + limit: int = Query(default=50, ge=1, le=500), + repo: TraceRepository = Depends(get_repository), +) -> AlertListFilteredResponse: + """List alerts with rich filtering options. 
+ + Args: + agent_name: Optional agent name to filter by + severity: Optional minimum severity to filter by + alert_type: Optional alert type to filter by + status: Optional status to filter by + from_date: Optional start date for created_at filter + to_date: Optional end date for created_at filter + limit: Maximum number of alerts to return + repo: TraceRepository instance + + Returns: + AlertListFilteredResponse with filtered alerts + """ + alerts = await repo._alert_repo.list_alerts_filtered( + agent_name=agent_name, + severity=severity, + alert_type=alert_type, + status=status, + from_date=from_date, + to_date=to_date, + limit=limit, + ) + + filters = AlertFilters( + agent_name=agent_name, + severity=severity, + alert_type=alert_type, + status=status, + from_date=from_date, + to_date=to_date, + limit=limit, + ) + + return AlertListFilteredResponse( + alerts=[ + AnomalyAlertSchema( + id=alert.id, + session_id=alert.session_id, + alert_type=alert.alert_type, + severity=alert.severity, + signal=alert.signal, + event_ids=alert.event_ids or [], + detection_source=alert.detection_source, + detection_config=alert.detection_config or {}, + created_at=alert.created_at, + status=alert.status, + acknowledged_at=alert.acknowledged_at, + resolved_at=alert.resolved_at, + dismissed_at=alert.dismissed_at, + resolution_note=alert.resolution_note, + ) + for alert in alerts + ], + total=len(alerts), + filters=filters, + ) + + @router.get("/api/alerts/{alert_id}", response_model=AnomalyAlertSchema) async def get_alert( alert_id: str, @@ -353,4 +517,57 @@ async def get_alert( detection_source=alert.detection_source, detection_config=alert.detection_config or {}, created_at=alert.created_at, + status=alert.status, + acknowledged_at=alert.acknowledged_at, + resolved_at=alert.resolved_at, + dismissed_at=alert.dismissed_at, + resolution_note=alert.resolution_note, + ) + + +@router.put("/api/alerts/{alert_id}/status", response_model=AnomalyAlertSchema) +async def update_alert_status( + 
alert_id: str, + update: AlertStatusUpdate, + repo: TraceRepository = Depends(get_repository), +) -> AnomalyAlertSchema: + """Update the status of a single alert. + + Args: + alert_id: Unique identifier of the alert + update: Status update request with status and optional note + repo: TraceRepository instance + + Returns: + Updated AnomalyAlertSchema + + Raises: + NotFoundError: if alert not found + ValueError: if status is invalid + """ + alert = await repo._alert_repo.update_alert_status(alert_id, update.status, update.note) + if not alert: + raise NotFoundError(f"Alert {alert_id} not found") + + try: + await repo.commit() + except Exception: + await repo.rollback() + raise + + return AnomalyAlertSchema( + id=alert.id, + session_id=alert.session_id, + alert_type=alert.alert_type, + severity=alert.severity, + signal=alert.signal, + event_ids=alert.event_ids or [], + detection_source=alert.detection_source, + detection_config=alert.detection_config or {}, + created_at=alert.created_at, + status=alert.status, + acknowledged_at=alert.acknowledged_at, + resolved_at=alert.resolved_at, + dismissed_at=alert.dismissed_at, + resolution_note=alert.resolution_note, ) diff --git a/collector/alerts/base.py b/collector/alerts/base.py index 0f99f1d..7d1b684 100644 --- a/collector/alerts/base.py +++ b/collector/alerts/base.py @@ -2,6 +2,7 @@ from __future__ import annotations +import asyncio from abc import ABC, abstractmethod from typing import Any @@ -11,6 +12,67 @@ class AlertDeriver(ABC): """Protocol for deriving alerts from events.""" + def __init__(self, policy_getter: Any | None = None): + """Initialize the alerter with an optional policy getter. + + Args: + policy_getter: Optional callable that retrieves alert policies. 
+ May be sync or async with signature: + (alert_type: str, agent_name: str | None) -> dict | None + """ + self.policy_getter = policy_getter + + def get_threshold( + self, + alert_type: str, + agent_name: str | None = None, + default_threshold: float = 0.0, + ) -> float: + """Get the threshold value for an alert type from policy or use default. + + Supports both sync and async policy getters. For async getters, returns + the default immediately — use ``get_threshold_async`` in async contexts. + + Args: + alert_type: Type of alert to get threshold for + agent_name: Optional agent name for agent-specific policies + default_threshold: Default threshold if no policy found + + Returns: + Threshold value to use + """ + if self.policy_getter: + policy = self.policy_getter(alert_type, agent_name) + if asyncio.iscoroutine(policy): + return default_threshold + if policy and policy.get("enabled", True): + return policy.get("threshold_value", default_threshold) + return default_threshold + + async def get_threshold_async( + self, + alert_type: str, + agent_name: str | None = None, + default_threshold: float = 0.0, + ) -> float: + """Async version of get_threshold that awaits async policy getters. + + Args: + alert_type: Type of alert to get threshold for + agent_name: Optional agent name for agent-specific policies + default_threshold: Default threshold if no policy found + + Returns: + Threshold value to use + """ + if self.policy_getter: + policy = self.policy_getter(alert_type, agent_name) + if asyncio.iscoroutine(policy): + policy = await policy + if policy and policy.get("enabled", True): + return policy.get("threshold_value", default_threshold) + return default_threshold + @abstractmethod def derive(self, events: list[TraceEvent]) -> list[dict[str, Any]]: """Derive alerts from a list of events. 
diff --git a/docs/assets/gifs/demo-failure-clustering.gif b/docs/assets/gifs/demo-failure-clustering.gif new file mode 100644 index 0000000..d5ff9b5 Binary files /dev/null and b/docs/assets/gifs/demo-failure-clustering.gif differ diff --git a/docs/assets/gifs/screenshots/capture_search.py b/docs/assets/gifs/screenshots/capture_search.py index 0fda6a8..8ec7aab 100644 --- a/docs/assets/gifs/screenshots/capture_search.py +++ b/docs/assets/gifs/screenshots/capture_search.py @@ -1,10 +1,12 @@ import asyncio -from playwright.async_api import async_playwright from pathlib import Path +from playwright.async_api import async_playwright + + async def capture_screenshots(): - screenshots_dir = Path("/home/nistrator/Documents/github/amplifier/ai_working/agent_debugger/docs/assets/gifs/screenshots") + screenshots_dir = Path(__file__).resolve().parent async with async_playwright() as p: browser = await p.chromium.launch(headless=False) @@ -20,7 +22,8 @@ async def capture_screenshots(): await page.screenshot(path=str(screenshots_dir / "search_01_initial.png")) # Find and click on search box - search_input = page.locator("input[placeholder*='search' i], input[aria-label*='search' i], .search input, #search").first + selector = "input[placeholder*='search' i], input[aria-label*='search' i], .search input, #search" + search_input = page.locator(selector).first if await search_input.count() > 0: await search_input.click() await asyncio.sleep(0.5) diff --git a/frontend/src/App.css b/frontend/src/App.css index c5286e9..28d9a4e 100644 --- a/frontend/src/App.css +++ b/frontend/src/App.css @@ -6180,3 +6180,544 @@ pre { scroll-behavior: auto !important; } } + +/* ═══════════════════════════════════════════════════════════ + ALERT DASHBOARD PANEL STYLES + ═══════════════════════════════════════════════════════════ */ + +.alert-dashboard { + padding: 1.4rem; +} + +.alert-dashboard .panel-head { + display: flex; + justify-content: space-between; + align-items: baseline; + padding: 1.3rem 
1.5rem 1rem; + border-bottom: 1px solid color-mix(in oklch, var(--panel-border), transparent 40%); + margin-bottom: 1rem; + background: linear-gradient(180deg, color-mix(in oklch, var(--panel), var(--bg-strong) 8%) 0%, transparent 100%); + border-radius: 1.3rem 1.3rem 0 0; +} + +.alert-dashboard .panel-head .eyebrow { + margin: 0 0 0.3rem; + font-size: 0.68rem; + letter-spacing: 0.18em; + text-transform: uppercase; + color: var(--muted); + font-weight: 600; +} + +.alert-dashboard .panel-head h2 { + margin: 0; + font-size: 1.2rem; + font-weight: 700; + letter-spacing: -0.025em; +} + +/* Summary Cards */ +.alert-summary-cards { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(140px, 1fr)); + gap: 1rem; + margin-bottom: 1.5rem; +} + +.alert-card { + display: grid; + gap: 0.35rem; + padding: 1.25rem; + border-radius: 1.2rem; + background: color-mix(in oklch, var(--panel), var(--accent) 4%); + border: 1px solid color-mix(in oklch, var(--panel-border), var(--accent) 22%); + text-align: center; + transition: all var(--duration-normal) var(--ease-material); + box-shadow: 0 2px 8px color-mix(in oklch, var(--muted), transparent 85%); +} + +.alert-card:hover { + transform: translateY(-2px); + box-shadow: 0 4px 16px color-mix(in oklch, var(--muted), transparent 70%); +} + +.alert-card .metric-label { + font-size: 0.7rem; + text-transform: uppercase; + letter-spacing: 0.1em; + color: var(--muted); +} + +.alert-card strong { + font-size: 1.5rem; + font-weight: 700; + color: var(--text); +} + +.alert-card--critical { + background: color-mix(in oklch, var(--panel), var(--danger) 6%); + border-color: color-mix(in oklch, var(--panel-border), var(--danger) 25%); +} + +.alert-card--critical strong { + color: var(--danger); +} + +.alert-card--warning { + background: color-mix(in oklch, var(--panel), var(--warning) 6%); + border-color: color-mix(in oklch, var(--panel-border), var(--warning) 25%); +} + +.alert-card--warning strong { + color: var(--warning); +} + 
+.alert-card--info { + background: color-mix(in oklch, var(--panel), var(--accent) 6%); + border-color: color-mix(in oklch, var(--panel-border), var(--accent) 25%); +} + +.alert-card--info strong { + color: var(--accent); +} + +/* Filter Bar */ +.alert-filter-bar { + display: flex; + flex-wrap: wrap; + gap: 0.75rem; + align-items: center; + padding: 1rem; + border-radius: 1rem; + background: color-mix(in oklch, var(--panel), var(--bg-strong) 25%); + margin-bottom: 1rem; +} + +.alert-filter-bar .filter-select { + padding: 0.6rem 0.9rem; + border-radius: 0.8rem; + border: 1px solid var(--panel-border); + background: color-mix(in oklch, var(--panel), white 50%); + color: var(--text); + font-size: 0.85rem; + cursor: pointer; + transition: all var(--duration-normal) var(--ease-material); +} + +.alert-filter-bar .filter-select:focus { + outline: none; + border-color: var(--accent); + box-shadow: 0 0 0 3px color-mix(in oklch, var(--accent), transparent 80%); +} + +.alert-filter-bar .clear-filters-btn { + padding: 0.6rem 1rem; + border-radius: 0.8rem; + background: color-mix(in oklch, var(--muted), white 85%); + border: 1px solid var(--panel-border); + color: var(--text); + font-size: 0.8rem; + font-weight: 500; + cursor: pointer; + transition: all var(--duration-normal) var(--ease-material); +} + +.alert-filter-bar .clear-filters-btn:hover { + background: color-mix(in oklch, var(--muted), white 75%); + transform: translateY(-1px); +} + +.alert-filter-bar .bulk-acknowledge-btn { + padding: 0.6rem 1rem; + border-radius: 0.8rem; + background: color-mix(in oklch, var(--accent), white 85%); + border: 1px solid color-mix(in oklch, var(--accent), white 50%); + color: var(--accent-deep); + font-size: 0.8rem; + font-weight: 600; + cursor: pointer; + transition: all var(--duration-normal) var(--ease-material); +} + +.alert-filter-bar .bulk-acknowledge-btn:hover { + background: color-mix(in oklch, var(--accent), white 75%); + transform: translateY(-1px); +} + +/* Alert List */ 
+.alert-list { + display: grid; + gap: 0.75rem; + margin-bottom: 1.5rem; +} + +.alert-row { + position: relative; + padding: 1rem 1.1rem; + padding-left: 1.3rem; + border-radius: 1rem; + border: 1px solid var(--panel-border); + background: color-mix(in oklch, var(--panel), var(--bg-strong) 34%); + text-align: left; + cursor: pointer; + transition: all var(--duration-normal) var(--ease-material); + border-left: 4px solid transparent; +} + +.alert-row:hover { + transform: translateX(2px); + box-shadow: 0 4px 12px color-mix(in oklch, var(--muted), transparent 70%); +} + +.alert-row::before { + content: ''; + position: absolute; + left: 0.75rem; + top: 50%; + transform: translateY(-50%); + width: 6px; + height: 6px; + border-radius: 50%; + background: var(--muted); +} + +.alert-row--active { + border-left-color: var(--danger); + background: color-mix(in oklch, var(--panel), var(--danger) 8%); +} + +.alert-row--active::before { + background: var(--danger); +} + +.alert-row--acknowledged { + border-left-color: var(--warning); + background: color-mix(in oklch, var(--panel), var(--warning) 6%); +} + +.alert-row--acknowledged::before { + background: var(--warning); +} + +.alert-row--resolved { + border-left-color: var(--olive); + background: color-mix(in oklch, var(--panel), var(--olive) 6%); + opacity: 0.85; +} + +.alert-row--resolved::before { + background: var(--olive); +} + +.alert-row--dismissed { + border-left-color: var(--muted); + background: color-mix(in oklch, var(--panel), var(--muted) 6%); + opacity: 0.7; +} + +.alert-row--dismissed::before { + background: var(--muted); +} + +.alert-row-header { + display: flex; + justify-content: space-between; + align-items: center; + gap: 0.75rem; + margin-bottom: 0.5rem; +} + +.alert-row-meta { + display: flex; + align-items: center; + gap: 0.5rem; +} + +.alert-severity-dot { + width: 8px; + height: 8px; + border-radius: 50%; + flex-shrink: 0; + box-shadow: 0 0 0 2px color-mix(in oklch, currentColor, white 70%); +} + 
+.alert-type { + font-size: 0.8rem; + font-weight: 600; + color: var(--text); + text-transform: uppercase; + letter-spacing: 0.05em; +} + +.alert-severity { + padding: 0.2rem 0.5rem; + border-radius: 4px; + background: color-mix(in oklch, var(--panel), var(--bg-strong) 50%); + font-size: 0.7rem; + font-weight: 600; + text-transform: uppercase; + letter-spacing: 0.05em; + color: var(--muted); +} + +.alert-status { + padding: 0.2rem 0.5rem; + border-radius: 4px; + background: color-mix(in oklch, var(--panel), var(--bg-strong) 50%); + font-size: 0.7rem; + font-weight: 600; + text-transform: uppercase; + letter-spacing: 0.05em; + color: var(--muted); +} + +.alert-time { + font-size: 0.75rem; + color: var(--muted); +} + +.alert-signal { + margin: 0.5rem 0; + font-size: 0.9rem; + color: var(--text); + line-height: 1.5; +} + +/* Alert Details */ +.alert-details { + margin-top: 1rem; + padding-top: 1rem; + border-top: 1px solid var(--panel-border); + display: grid; + gap: 0.75rem; +} + +.alert-detail-row { + display: flex; + justify-content: space-between; + align-items: center; + gap: 0.5rem; + font-size: 0.85rem; +} + +.alert-detail-label { + color: var(--muted); + font-size: 0.75rem; + font-weight: 600; + text-transform: uppercase; + letter-spacing: 0.05em; +} + +.alert-detail-value { + color: var(--text); + font-weight: 500; + text-align: right; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} + +.alert-resolution { + display: grid; + gap: 0.5rem; + padding: 0.75rem; + border-radius: 0.75rem; + background: color-mix(in oklch, var(--panel), var(--bg-strong) 30%); + margin-top: 0.5rem; +} + +.alert-resolution label { + font-size: 0.75rem; + font-weight: 600; + color: var(--muted); + text-transform: uppercase; + letter-spacing: 0.05em; +} + +.alert-resolution .resolution-input { + width: 100%; + padding: 0.6rem 0.8rem; + border-radius: 0.6rem; + border: 1px solid var(--panel-border); + background: var(--panel); + font-size: 0.85rem; + resize: 
vertical; + transition: border-color 0.2s ease; +} + +.alert-resolution .resolution-input:focus { + outline: none; + border-color: var(--accent); + box-shadow: 0 0 0 3px color-mix(in oklch, var(--accent), transparent 80%); +} + +.alert-actions { + display: flex; + gap: 0.5rem; + flex-wrap: wrap; +} + +.alert-action-btn { + padding: 0.5rem 0.9rem; + border-radius: 0.6rem; + font-size: 0.8rem; + font-weight: 600; + cursor: pointer; + transition: all 160ms ease; + border: 1px solid transparent; +} + +.alert-action-btn:disabled { + opacity: 0.5; + cursor: not-allowed; +} + +.alert-action-btn--acknowledge { + background: color-mix(in oklch, var(--accent), white 85%); + border-color: color-mix(in oklch, var(--accent), white 50%); + color: var(--accent-deep); +} + +.alert-action-btn--acknowledge:hover:not(:disabled) { + background: color-mix(in oklch, var(--accent), white 75%); + transform: translateY(-1px); +} + +.alert-action-btn--resolve { + background: color-mix(in oklch, var(--olive), white 85%); + border-color: color-mix(in oklch, var(--olive), white 50%); + color: var(--olive); +} + +.alert-action-btn--resolve:hover:not(:disabled) { + background: color-mix(in oklch, var(--olive), white 75%); + transform: translateY(-1px); +} + +.alert-action-btn--dismiss { + background: color-mix(in oklch, var(--muted), white 85%); + border-color: color-mix(in oklch, var(--panel-border), var(--muted) 30%); + color: var(--muted); +} + +.alert-action-btn--dismiss:hover { + background: color-mix(in oklch, var(--muted), white 75%); + transform: translateY(-1px); +} + +/* Trending Section */ +.trending-section { + margin-top: 1.5rem; + padding: 1rem; + border-radius: 1rem; + background: color-mix(in oklch, var(--panel), var(--bg-strong) 20%); +} + +.trending-section h3 { + margin: 0 0 1rem; + font-size: 0.9rem; + text-transform: uppercase; + letter-spacing: 0.1em; + color: var(--muted); +} + +.trending-bars { + display: flex; + gap: 4px; + align-items: flex-end; + height: 120px; + 
padding: 0.5rem 0; +} + +.trending-bar-container { + flex: 1; + display: flex; + flex-direction: column; + align-items: center; + gap: 0.5rem; + height: 100%; +} + +.trending-bar { + width: 100%; + background: linear-gradient(180deg, var(--accent) 0%, color-mix(in oklch, var(--accent), var(--bg) 30%) 100%); + border-radius: 4px 4px 0 0; + transition: height 300ms ease; + min-height: 4px; + position: relative; +} + +.trending-bar:hover { + background: linear-gradient(180deg, var(--olive) 0%, color-mix(in oklch, var(--olive), var(--bg) 30%) 100%); +} + +.trending-label { + font-size: 0.65rem; + color: var(--muted); + text-align: center; + transform: rotate(-45deg); + transform-origin: center; + white-space: nowrap; + margin-top: 0.25rem; +} + +/* Loading and Empty States */ +.alert-dashboard .loading-state { + display: flex; + justify-content: center; + align-items: center; + padding: 3rem 2rem; + color: var(--muted); + font-size: 0.95rem; +} + +.alert-dashboard .empty-state { + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + text-align: center; + padding: 3rem 2rem; + gap: 0.75rem; + min-height: 200px; +} + +/* Responsive Design */ +@media (max-width: 768px) { + .alert-summary-cards { + grid-template-columns: repeat(2, 1fr); + } + + .alert-filter-bar { + flex-direction: column; + align-items: stretch; + } + + .alert-filter-bar .filter-select, + .alert-filter-bar .clear-filters-btn, + .alert-filter-bar .bulk-acknowledge-btn { + width: 100%; + } + + .alert-actions { + flex-direction: column; + } + + .alert-action-btn { + width: 100%; + } +} + +@media (max-width: 480px) { + .alert-summary-cards { + grid-template-columns: 1fr; + } + + .alert-dashboard { + padding: 1rem; + } + + .alert-dashboard .panel-head { + padding: 1rem; + } +} diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts index f38cc77..85f1ea9 100644 --- a/frontend/src/api/client.ts +++ b/frontend/src/api/client.ts @@ -1,10 +1,15 @@ import type { 
AgentBaseline, + AlertPolicy, + AlertStatus, + AlertSummary, + AlertTrendingPoint, AnalyticsResponse, CostSummary, DriftResponse, FixNoteResponse, LiveSummary, + ManagedAlert, ReplayResponse, SearchResponse, Session, @@ -371,3 +376,95 @@ export async function getSimilarFailures(params: { `${API_BASE}/sessions/${params.sessionId}/similar-failures?${search.toString()}` ) } + +// Alert Dashboard API functions +export async function fetchAlerts(filters?: Record) { + const search = new URLSearchParams() + if (filters) { + Object.entries(filters).forEach(([key, value]) => { + if (value) search.set(key, value) + }) + } + const queryString = search.toString() ? `?${search.toString()}` : '' + return fetchJSON<{ alerts: ManagedAlert[]; total: number }>( + `${API_BASE}/alerts${queryString}` + ) +} + +export async function updateAlertStatus( + alertId: string, + status: AlertStatus, + note?: string +) { + const response = await fetch(`${API_BASE}/alerts/${alertId}/status`, { + method: 'PUT', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ status, note }), + }) + if (!response.ok) { + throw new Error(`API error: ${response.status} ${response.statusText}`) + } + return response.json() as Promise +} + +export async function bulkUpdateAlertStatus(alertIds: string[], status: AlertStatus) { + const response = await fetch(`${API_BASE}/alerts/bulk-status`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ alert_ids: alertIds, status }), + }) + if (!response.ok) { + throw new Error(`API error: ${response.status} ${response.statusText}`) + } + return response.json() as Promise<{ updated: number; status: AlertStatus }> +} + +export async function fetchAlertSummary() { + return fetchJSON(`${API_BASE}/alerts/summary`) +} + +export async function fetchAlertTrending(days: number = 7) { + const data = await fetchJSON<{ trending: AlertTrendingPoint[]; days: number }>( + `${API_BASE}/alerts/trending?days=${days}` + ) + 
return data.trending +} + +export async function fetchAlertPolicies(agentName?: string) { + const params = agentName ? `?agent_name=${encodeURIComponent(agentName)}` : '' + return fetchJSON(`${API_BASE}/alert-policies${params}`) +} + +export async function createAlertPolicy(policy: Partial) { + const response = await fetch(`${API_BASE}/alert-policies`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(policy), + }) + if (!response.ok) { + throw new Error(`API error: ${response.status} ${response.statusText}`) + } + return response.json() as Promise +} + +export async function updateAlertPolicy(id: string, policy: Partial) { + const response = await fetch(`${API_BASE}/alert-policies/${id}`, { + method: 'PUT', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(policy), + }) + if (!response.ok) { + throw new Error(`API error: ${response.status} ${response.statusText}`) + } + return response.json() as Promise +} + +export async function deleteAlertPolicy(id: string) { + const response = await fetch(`${API_BASE}/alert-policies/${id}`, { + method: 'DELETE', + }) + if (!response.ok) { + throw new Error(`API error: ${response.status} ${response.statusText}`) + } + return response.json() as Promise<{ deleted: boolean }> +} diff --git a/frontend/src/components/AlertDashboardPanel.tsx b/frontend/src/components/AlertDashboardPanel.tsx new file mode 100644 index 0000000..d2fe139 --- /dev/null +++ b/frontend/src/components/AlertDashboardPanel.tsx @@ -0,0 +1,316 @@ +import { useState } from 'react' +import { useAlerts } from '../hooks/useAlerts' +import { useAlertSummary } from '../hooks/useAlertSummary' +import type { AlertStatus } from '../types' +import { severityLabel } from '../types' + +interface AlertDashboardPanelProps { + agentName: string | null +} + +export function AlertDashboardPanel({ agentName }: AlertDashboardPanelProps) { + const { alerts, loading, error, filters, setFilter, clearAllFilters, 
updateStatus, bulkUpdate } = + useAlerts(agentName ? { agent_name: agentName } : {}) + const { summary, trending, loading: summaryLoading } = useAlertSummary(7) + const [expandedAlertId, setExpandedAlertId] = useState(null) + const [resolutionNote, setResolutionNote] = useState('') + const [resolvingAlertId, setResolvingAlertId] = useState(null) + + const handleStatusChange = async (alertId: string, status: AlertStatus, note?: string) => { + try { + await updateStatus(alertId, status, note) + if (status === 'resolved' || status === 'dismissed') { + setExpandedAlertId(null) + setResolutionNote('') + } + } catch (err) { + console.error('Failed to update alert status:', err) + } + } + + const handleBulkAcknowledge = async () => { + const activeAlerts = alerts.filter((a) => a.status === 'active').map((a) => a.id) + if (activeAlerts.length === 0) return + try { + await bulkUpdate(activeAlerts, 'acknowledged') + } catch (err) { + console.error('Failed to bulk acknowledge:', err) + } + } + + const handleResolve = async (alertId: string) => { + setResolvingAlertId(alertId) + try { + await handleStatusChange(alertId, 'resolved', resolutionNote || undefined) + } finally { + setResolvingAlertId(null) + } + } + + const getSeverityColor = (severity: number): string => { + const label = severityLabel(severity) + switch (label) { + case 'critical': + return 'var(--danger)' + case 'high': + return 'oklch(0.58 0.22 25)' + case 'medium': + return 'var(--warning)' + case 'low': + return 'var(--olive)' + default: + return 'var(--muted)' + } + } + + const getStatusVariant = (status: AlertStatus): string => { + switch (status) { + case 'active': + return 'alert-row--active' + case 'acknowledged': + return 'alert-row--acknowledged' + case 'resolved': + return 'alert-row--resolved' + case 'dismissed': + return 'alert-row--dismissed' + default: + return '' + } + } + + if (summaryLoading && !summary) { + return ( +
+
+

Alerts

+

Alert Dashboard

+
+
Loading alert data...
+
+ ) + } + + return ( +
+
+

Alert Management

+

+ Alert Dashboard + {summary && summary.total > 0 && {summary.total}} +

+
+ + {error &&
{error}
} + + {/* Summary Cards */} + {summary && ( +
+
+ Total Alerts + {summary.total} +
+
+ Critical + {summary.by_severity.critical || 0} +
+
+ Warning + {summary.by_severity.high || 0} +
+
+ Active + {summary.by_status.active || 0} +
+
+ )} + + {/* Filter Bar */} +
+ + + + {(filters.severity || filters.status || filters.alert_type) && ( + + )} + {alerts.some((a) => a.status === 'active') && ( + + )} +
+ + {/* Alert List */} +
+ {loading ? ( +
Loading alerts...
+ ) : alerts.length === 0 ? ( +
+
🔔
+

No alerts

+

No alerts match the current filters.

+ Alerts will appear here when behavior patterns need attention +
+ ) : ( + alerts.map((alert) => ( +
setExpandedAlertId(expandedAlertId === alert.id ? null : alert.id)} + onKeyDown={(e) => { + if (e.key === 'Enter' || e.key === ' ') { + e.preventDefault() + setExpandedAlertId(expandedAlertId === alert.id ? null : alert.id) + } + }} + > +
+
+ + {alert.alert_type} + {severityLabel(alert.severity)} + {alert.status} +
+ + {new Date(alert.created_at).toLocaleString()} + +
+

{alert.signal}

+ + {expandedAlertId === alert.id && ( +
e.stopPropagation()}> +
+ Session ID: + {alert.session_id} +
+
+ Detection Source: + {alert.detection_source} +
+
+ Events: + {alert.event_ids.length} linked +
+ {alert.resolution_note && ( +
+ Resolution Note: + {alert.resolution_note} +
+ )} + + {/* Resolution note input for active/acknowledged alerts */} + {(alert.status === 'active' || alert.status === 'acknowledged') && ( +
+ + setResolutionNote(e.target.value)} + placeholder="Add resolution note..." + className="resolution-input" + /> +
+ {alert.status === 'active' && ( + + )} + + +
+
+ )} +
+ )} +
+ )) + )} +
+ + {/* Trending Section */} + {trending && trending.length > 0 && ( +
+

Alert Volume (Last 7 Days)

+
+ {trending.map((point) => { + const maxCount = Math.max(...trending.map((p) => p.count)) + const heightPercent = maxCount > 0 ? (point.count / maxCount) * 100 : 0 + return ( +
+
+ {new Date(point.date).toLocaleDateString(undefined, { month: 'short', day: 'numeric' })} +
+ ) + })} +
+
+ )} +
+ ) +} diff --git a/frontend/src/components/AnalyticsTab.tsx b/frontend/src/components/AnalyticsTab.tsx index 66baecc..a0f5b59 100644 --- a/frontend/src/components/AnalyticsTab.tsx +++ b/frontend/src/components/AnalyticsTab.tsx @@ -1,5 +1,6 @@ import CostSummary from './CostSummary' import { AnalyticsPanel } from './AnalyticsPanel' +import { AlertDashboardPanel } from './AlertDashboardPanel' import './AnalyticsTab.css' export function AnalyticsTab() { @@ -7,6 +8,7 @@ export function AnalyticsTab() {
+
) } diff --git a/frontend/src/hooks/useAlertSummary.ts b/frontend/src/hooks/useAlertSummary.ts new file mode 100644 index 0000000..35f7d24 --- /dev/null +++ b/frontend/src/hooks/useAlertSummary.ts @@ -0,0 +1,51 @@ +import { useEffect, useState } from 'react' +import { fetchAlertSummary, fetchAlertTrending } from '../api/client' +import type { AlertSummary, AlertTrendingPoint } from '../types' + +interface UseAlertSummaryReturn { + summary: AlertSummary | null + trending: AlertTrendingPoint[] + loading: boolean + error: string | null + refresh: () => Promise +} + +export function useAlertSummary(days: number = 7): UseAlertSummaryReturn { + const [summary, setSummary] = useState(null) + const [trending, setTrending] = useState([]) + const [loading, setLoading] = useState(false) + const [error, setError] = useState(null) + + const loadSummary = async () => { + setLoading(true) + setError(null) + try { + const [summaryData, trendingData] = await Promise.all([ + fetchAlertSummary(), + fetchAlertTrending(days), + ]) + setSummary(summaryData) + setTrending(trendingData) + } catch (err) { + setError(err instanceof Error ? 
err.message : 'Failed to load alert summary') + } finally { + setLoading(false) + } + } + + const refresh = async () => { + await loadSummary() + } + + useEffect(() => { + void loadSummary() + }, [days]) + + return { + summary, + trending, + loading, + error, + refresh, + } +} diff --git a/frontend/src/hooks/useAlerts.ts b/frontend/src/hooks/useAlerts.ts new file mode 100644 index 0000000..c443412 --- /dev/null +++ b/frontend/src/hooks/useAlerts.ts @@ -0,0 +1,98 @@ +import { useEffect, useState } from 'react' +import { fetchAlerts, updateAlertStatus, bulkUpdateAlertStatus } from '../api/client' +import type { AlertStatus, ManagedAlert } from '../types' + +interface UseAlertsReturn { + alerts: ManagedAlert[] + loading: boolean + error: string | null + filters: Record + setFilter: (key: string, value: string) => void + clearFilter: (key: string) => void + clearAllFilters: () => void + updateStatus: (alertId: string, status: AlertStatus, note?: string) => Promise + bulkUpdate: (alertIds: string[], status: AlertStatus) => Promise + refresh: () => Promise +} + +const DEFAULT_FILTERS: Record = {} + +export function useAlerts(initialFilters: Record = DEFAULT_FILTERS): UseAlertsReturn { + const [alerts, setAlerts] = useState([]) + const [loading, setLoading] = useState(false) + const [error, setError] = useState(null) + const [filters, setFilters] = useState>(initialFilters) + + const loadAlerts = async () => { + setLoading(true) + setError(null) + try { + const response = await fetchAlerts(filters) + setAlerts(response.alerts) + } catch (err) { + setError(err instanceof Error ? 
err.message : 'Failed to load alerts') + setAlerts([]) + } finally { + setLoading(false) + } + } + + const setFilter = (key: string, value: string) => { + setFilters((prev) => ({ ...prev, [key]: value })) + } + + const clearFilter = (key: string) => { + setFilters((prev) => { + const next = { ...prev } + delete next[key] + return next + }) + } + + const clearAllFilters = () => { + setFilters(initialFilters) + } + + const updateStatus = async (alertId: string, status: AlertStatus, note?: string) => { + try { + const updated = await updateAlertStatus(alertId, status, note) + setAlerts((prev) => prev.map((alert) => (alert.id === alertId ? updated : alert))) + } catch (err) { + setError(err instanceof Error ? err.message : 'Failed to update alert status') + throw err + } + } + + const bulkUpdate = async (alertIds: string[], status: AlertStatus) => { + try { + await bulkUpdateAlertStatus(alertIds, status) + // Refresh to get updated state from server + await loadAlerts() + } catch (err) { + setError(err instanceof Error ? 
err.message : 'Failed to bulk update alerts') + throw err + } + } + + const refresh = async () => { + await loadAlerts() + } + + // Auto-refresh on filter change + useEffect(() => { + void loadAlerts() + }, [filters]) + + return { + alerts, + loading, + error, + filters, + setFilter, + clearFilter, + clearAllFilters, + updateStatus, + bulkUpdate, + refresh, + } +} diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 625d483..a504f28 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -462,3 +462,55 @@ export interface SimilarFailuresResponse { similar_failures: SimilarFailure[] total: number } + +// Alert Dashboard types +export type AlertStatus = 'active' | 'acknowledged' | 'resolved' | 'dismissed' + +export interface AlertPolicy { + id: string + agent_name: string | null + alert_type: string + threshold_value: number + severity_threshold: string | null + enabled: boolean + created_at: string + updated_at: string +} + +export interface AlertSummary { + by_severity: Record + by_status: Record + by_type: Record + total: number +} + +export interface AlertTrendingPoint { + date: string + count: number +} + +// Extended alert with lifecycle management +export interface ManagedAlert { + id: string + session_id: string + alert_type: string + severity: number + signal: string + status: AlertStatus + event_ids: string[] + detection_source: string + detection_config: Record + resolution_note: string | null + acknowledged_at: string | null + resolved_at: string | null + dismissed_at: string | null + created_at: string +} + +/** Map numeric severity (0.0–1.0) to a display label. 
*/ +export function severityLabel(severity: number): RiskLevel { + if (severity >= 0.8) return 'critical' + if (severity >= 0.5) return 'high' + if (severity >= 0.3) return 'medium' + return 'low' +} diff --git a/storage/__init__.py b/storage/__init__.py index 8a6dcfe..6792561 100644 --- a/storage/__init__.py +++ b/storage/__init__.py @@ -5,6 +5,7 @@ """ from .models import ( + AlertPolicyModel, AnomalyAlertModel, Base, CheckpointModel, @@ -12,7 +13,13 @@ FailureClusterModel, SessionModel, ) -from .repositories import AnomalyAlertRepository, CheckpointRepository, EventRepository, SessionRepository +from .repositories import ( + AlertPolicyRepository, + AnomalyAlertRepository, + CheckpointRepository, + EventRepository, + SessionRepository, +) from .repository import AnomalyAlertCreate, TraceRepository from .search import SessionSearchService @@ -25,6 +32,7 @@ "EventRepository", "CheckpointRepository", "AnomalyAlertRepository", + "AlertPolicyRepository", # Search service "SessionSearchService", # Models @@ -34,4 +42,5 @@ "CheckpointModel", "AnomalyAlertModel", "FailureClusterModel", + "AlertPolicyModel", ] diff --git a/storage/cache.py b/storage/cache.py new file mode 100644 index 0000000..b0be066 --- /dev/null +++ b/storage/cache.py @@ -0,0 +1,96 @@ +"""Simple in-memory query cache with TTL support.""" + +from __future__ import annotations + +import threading +import time +from typing import Any + + +class QueryCache: + """Thread-safe in-memory cache with time-based expiration. + + Simple caching utility for query results that don't change frequently. + Uses a dictionary-based storage with per-entry TTL support. + """ + + def __init__(self) -> None: + """Initialize the cache with storage and lock.""" + self._cache: dict[str, tuple[Any, float]] = {} + self._lock = threading.Lock() + + def get(self, key: str) -> Any | None: + """Retrieve a value from the cache if it exists and hasn't expired. 
+ + Args: + key: Cache key to look up + + Returns: + Cached value if found and not expired, None otherwise + """ + with self._lock: + entry = self._cache.get(key) + if entry is None: + return None + + value, expiry = entry + if time.time() > expiry: + # Expired, remove from cache + del self._cache[key] + return None + + return value + + def set(self, key: str, value: Any, ttl_seconds: int = 60) -> None: + """Store a value in the cache with a TTL. + + Args: + key: Cache key to store under + value: Value to cache + ttl_seconds: Time-to-live in seconds (default: 60) + """ + expiry = time.time() + ttl_seconds + with self._lock: + self._cache[key] = (value, expiry) + + def invalidate(self, key: str, *, prefix: bool = False) -> int: + """Remove cache entries by exact key or by key prefix. + + Args: + key: Exact cache key to invalidate, or prefix when ``prefix=True`` + prefix: When True, remove all entries whose keys start with ``key`` + + Returns: + Number of entries removed + """ + with self._lock: + if not prefix: + return 1 if self._cache.pop(key, None) is not None else 0 + + keys_to_remove = [k for k in self._cache if k.startswith(key)] + for k in keys_to_remove: + del self._cache[k] + return len(keys_to_remove) + + def clear(self) -> None: + """Clear all entries from the cache.""" + with self._lock: + self._cache.clear() + + def cleanup_expired(self) -> int: + """Remove all expired entries from the cache. 
+ + Returns: + Number of entries removed + """ + now = time.time() + with self._lock: + expired_keys = [k for k, (_, expiry) in self._cache.items() if now > expiry] + for key in expired_keys: + del self._cache[key] + return len(expired_keys) + + def size(self) -> int: + """Return the current number of entries in the cache.""" + with self._lock: + return len(self._cache) diff --git a/storage/migrations/versions/006_add_alert_lifecycle.py b/storage/migrations/versions/006_add_alert_lifecycle.py new file mode 100644 index 0000000..d655a31 --- /dev/null +++ b/storage/migrations/versions/006_add_alert_lifecycle.py @@ -0,0 +1,100 @@ +"""Add alert lifecycle fields to anomaly_alerts table. + +Revision ID: 006_add_alert_lifecycle +Revises: 005_add_patterns +Create Date: 2026-04-04 + +""" + +from collections.abc import Sequence +from typing import Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "006_add_alert_lifecycle" +down_revision: Union[str, None] = "005_add_patterns" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + conn = op.get_bind() + inspector = sa.inspect(conn) + + # Check if anomaly_alerts table exists + if "anomaly_alerts" in inspector.get_table_names(): + # Get existing columns to check if they already exist + existing_columns = {col["name"] for col in inspector.get_columns("anomaly_alerts")} + + # Add status column if it doesn't exist + if "status" not in existing_columns: + op.add_column( + "anomaly_alerts", + sa.Column("status", sa.String(32), nullable=False, server_default="active", index=True), + ) + + # Add acknowledged_at column if it doesn't exist + if "acknowledged_at" not in existing_columns: + op.add_column( + "anomaly_alerts", + sa.Column("acknowledged_at", sa.DateTime(), nullable=True), + ) + + # Add resolved_at column if it doesn't exist + if "resolved_at" not in existing_columns: + 
op.add_column( + "anomaly_alerts", + sa.Column("resolved_at", sa.DateTime(), nullable=True), + ) + + # Add dismissed_at column if it doesn't exist + if "dismissed_at" not in existing_columns: + op.add_column( + "anomaly_alerts", + sa.Column("dismissed_at", sa.DateTime(), nullable=True), + ) + + # Add resolution_note column if it doesn't exist + if "resolution_note" not in existing_columns: + op.add_column( + "anomaly_alerts", + sa.Column("resolution_note", sa.Text(), nullable=True), + ) + + # Create composite index for status if it doesn't exist + existing_indexes = {idx["name"] for idx in inspector.get_indexes("anomaly_alerts")} + if "ix_anomaly_alerts_tenant_id_status" not in existing_indexes: + op.create_index( + "ix_anomaly_alerts_tenant_id_status", + "anomaly_alerts", + ["tenant_id", "status"], + ) + + +def downgrade() -> None: + conn = op.get_bind() + inspector = sa.inspect(conn) + + # Check if anomaly_alerts table exists before modifying + if "anomaly_alerts" in inspector.get_table_names(): + # Drop index if it exists + existing_indexes = {idx["name"] for idx in inspector.get_indexes("anomaly_alerts")} + if "ix_anomaly_alerts_tenant_id_status" in existing_indexes: + op.drop_index("ix_anomaly_alerts_tenant_id_status", table_name="anomaly_alerts") + + # Get existing columns + existing_columns = {col["name"] for col in inspector.get_columns("anomaly_alerts")} + + # Drop columns if they exist + if "resolution_note" in existing_columns: + op.drop_column("anomaly_alerts", "resolution_note") + if "dismissed_at" in existing_columns: + op.drop_column("anomaly_alerts", "dismissed_at") + if "resolved_at" in existing_columns: + op.drop_column("anomaly_alerts", "resolved_at") + if "acknowledged_at" in existing_columns: + op.drop_column("anomaly_alerts", "acknowledged_at") + if "status" in existing_columns: + op.drop_column("anomaly_alerts", "status") diff --git a/storage/migrations/versions/007_add_alert_policies.py b/storage/migrations/versions/007_add_alert_policies.py 
new file mode 100644 index 0000000..708d3e8 --- /dev/null +++ b/storage/migrations/versions/007_add_alert_policies.py @@ -0,0 +1,68 @@ +"""Add alert_policies table for configurable alert thresholds. + +Revision ID: 007_add_alert_policies +Revises: 006_add_alert_lifecycle +Create Date: 2026-04-04 + +""" + +from collections.abc import Sequence +from typing import Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "007_add_alert_policies" +down_revision: Union[str, None] = "006_add_alert_lifecycle" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + conn = op.get_bind() + inspector = sa.inspect(conn) + + # Check if alert_policies table already exists + if "alert_policies" not in inspector.get_table_names(): + op.create_table( + "alert_policies", + sa.Column("id", sa.String(36), primary_key=True), + sa.Column("tenant_id", sa.String(64), nullable=False, index=True, server_default="local"), + sa.Column("agent_name", sa.String(255), nullable=True, index=True), + sa.Column("alert_type", sa.String(64), nullable=False, index=True), + sa.Column("threshold_value", sa.Float(), nullable=False), + sa.Column("severity_threshold", sa.String(16), nullable=True), + sa.Column("enabled", sa.Boolean(), nullable=False, server_default="true"), + sa.Column( + "created_at", + sa.DateTime(), + nullable=False, + server_default=sa.text("CURRENT_TIMESTAMP"), + ), + sa.Column( + "updated_at", + sa.DateTime(), + nullable=False, + server_default=sa.text("CURRENT_TIMESTAMP"), + ), + ) + + # Create composite indexes for common query patterns + op.create_index("ix_alert_policies_tenant_agent", "alert_policies", ["tenant_id", "agent_name"]) + op.create_index("ix_alert_policies_tenant_type", "alert_policies", ["tenant_id", "alert_type"]) + op.create_index( + "ix_alert_policies_tenant_agent_type", "alert_policies", ["tenant_id", "agent_name", 
"alert_type"] + ) + + +def downgrade() -> None: + conn = op.get_bind() + inspector = sa.inspect(conn) + + # Check if alert_policies table exists before dropping (idempotency) + if "alert_policies" in inspector.get_table_names(): + op.drop_index("ix_alert_policies_tenant_agent_type", table_name="alert_policies") + op.drop_index("ix_alert_policies_tenant_type", table_name="alert_policies") + op.drop_index("ix_alert_policies_tenant_agent", table_name="alert_policies") + op.drop_table("alert_policies") diff --git a/storage/migrations/versions/008_add_alert_indexes.py b/storage/migrations/versions/008_add_alert_indexes.py new file mode 100644 index 0000000..c9dea5b --- /dev/null +++ b/storage/migrations/versions/008_add_alert_indexes.py @@ -0,0 +1,112 @@ +"""Add indexes for alert and analytics query optimization. + +Revision ID: 008_add_alert_indexes +Revises: 007_add_alert_policies +Create Date: 2026-04-04 + +""" + +from collections.abc import Sequence +from typing import Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "008_add_alert_indexes" +down_revision: Union[str, None] = "007_add_alert_policies" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + conn = op.get_bind() + inspector = sa.inspect(conn) + + # Get existing indexes to avoid duplicates + def index_exists(table_name: str, index_name: str) -> bool: + existing_indexes = inspector.get_indexes(table_name) + return any(idx["name"] == index_name for idx in existing_indexes) + + # Create indexes for anomaly_alerts table + if not index_exists("anomaly_alerts", "ix_anomaly_alerts_created_at"): + op.create_index("ix_anomaly_alerts_created_at", "anomaly_alerts", ["created_at"]) + + if not index_exists("anomaly_alerts", "ix_anomaly_alerts_severity"): + op.create_index("ix_anomaly_alerts_severity", "anomaly_alerts", ["severity"]) + + if not index_exists("anomaly_alerts", "ix_anomaly_alerts_alert_type"): + op.create_index("ix_anomaly_alerts_alert_type", "anomaly_alerts", ["alert_type"]) + + if not index_exists("anomaly_alerts", "ix_anomaly_alerts_session_id"): + op.create_index("ix_anomaly_alerts_session_id", "anomaly_alerts", ["session_id"]) + + # Status column was added in migration 006, now index it + if not index_exists("anomaly_alerts", "ix_anomaly_alerts_status"): + op.create_index("ix_anomaly_alerts_status", "anomaly_alerts", ["status"]) + + # Create indexes for patterns table (individual columns for additional query patterns) + if not index_exists("patterns", "ix_patterns_status"): + op.create_index("ix_patterns_status", "patterns", ["status"]) + + if not index_exists("patterns", "ix_patterns_pattern_type"): + op.create_index("ix_patterns_pattern_type", "patterns", ["pattern_type"]) + + # Create indexes for sessions table + if not index_exists("sessions", "ix_sessions_started_at"): + op.create_index("ix_sessions_started_at", "sessions", ["started_at"]) + + if not index_exists("sessions", "ix_sessions_agent_name"): + 
op.create_index("ix_sessions_agent_name", "sessions", ["agent_name"]) + + # Create composite index for events table (session_id, timestamp) + # This optimizes queries that filter by session and order by timestamp + if not index_exists("events", "ix_events_session_id_created_at"): + op.create_index("ix_events_session_id_created_at", "events", ["session_id", "timestamp"]) + + # Create index for events.event_type (if not exists) + if not index_exists("events", "ix_events_event_type"): + op.create_index("ix_events_event_type", "events", ["event_type"]) + + +def downgrade() -> None: + conn = op.get_bind() + inspector = sa.inspect(conn) + + def index_exists(table_name: str, index_name: str) -> bool: + existing_indexes = inspector.get_indexes(table_name) + return any(idx["name"] == index_name for idx in existing_indexes) + + # Drop indexes in reverse order + if index_exists("events", "ix_events_event_type"): + op.drop_index("ix_events_event_type", table_name="events") + + if index_exists("events", "ix_events_session_id_created_at"): + op.drop_index("ix_events_session_id_created_at", table_name="events") + + if index_exists("sessions", "ix_sessions_agent_name"): + op.drop_index("ix_sessions_agent_name", table_name="sessions") + + if index_exists("sessions", "ix_sessions_started_at"): + op.drop_index("ix_sessions_started_at", table_name="sessions") + + if index_exists("patterns", "ix_patterns_pattern_type"): + op.drop_index("ix_patterns_pattern_type", table_name="patterns") + + if index_exists("patterns", "ix_patterns_status"): + op.drop_index("ix_patterns_status", table_name="patterns") + + if index_exists("anomaly_alerts", "ix_anomaly_alerts_session_id"): + op.drop_index("ix_anomaly_alerts_session_id", table_name="anomaly_alerts") + + if index_exists("anomaly_alerts", "ix_anomaly_alerts_alert_type"): + op.drop_index("ix_anomaly_alerts_alert_type", table_name="anomaly_alerts") + + if index_exists("anomaly_alerts", "ix_anomaly_alerts_severity"): + 
op.drop_index("ix_anomaly_alerts_severity", table_name="anomaly_alerts") + + if index_exists("anomaly_alerts", "ix_anomaly_alerts_status"): + op.drop_index("ix_anomaly_alerts_status", table_name="anomaly_alerts") + + if index_exists("anomaly_alerts", "ix_anomaly_alerts_created_at"): + op.drop_index("ix_anomaly_alerts_created_at", table_name="anomaly_alerts") diff --git a/storage/models.py b/storage/models.py index 460a43e..2f55d8d 100644 --- a/storage/models.py +++ b/storage/models.py @@ -104,6 +104,12 @@ class AnomalyAlertModel(Base): detection_source: Mapped[str] = mapped_column(String(32)) detection_config: Mapped[dict] = mapped_column(JSON) created_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(timezone.utc)) + # Lifecycle fields + status: Mapped[str] = mapped_column(String(32), default="active", index=True) + acknowledged_at: Mapped[datetime | None] = mapped_column(nullable=True) + resolved_at: Mapped[datetime | None] = mapped_column(nullable=True) + dismissed_at: Mapped[datetime | None] = mapped_column(nullable=True) + resolution_note: Mapped[str | None] = mapped_column(Text, nullable=True) class FailureClusterModel(Base): @@ -159,3 +165,28 @@ class PatternModel(Base): Index("ix_patterns_tenant_severity", "tenant_id", "severity"), Index("ix_patterns_tenant_status", "tenant_id", "status"), ) + + +class AlertPolicyModel(Base): + """SQLAlchemy ORM model for configurable alert policies.""" + + __tablename__ = "alert_policies" + + id: Mapped[str] = mapped_column(String(36), primary_key=True) + tenant_id: Mapped[str] = mapped_column(String(64), nullable=False, default="local", index=True) + agent_name: Mapped[str | None] = mapped_column(String(255), nullable=True, index=True) # null = global policy + alert_type: Mapped[str] = mapped_column(String(64), nullable=False, index=True) + threshold_value: Mapped[float] = mapped_column(Float, nullable=False) + severity_threshold: Mapped[str | None] = mapped_column(String(16), nullable=True) # warning, 
critical, etc. + enabled: Mapped[bool] = mapped_column(default=True, nullable=False) + created_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(timezone.utc)) + updated_at: Mapped[datetime] = mapped_column( + default=lambda: datetime.now(timezone.utc), + onupdate=lambda: datetime.now(timezone.utc), + ) + + __table_args__ = ( + Index("ix_alert_policies_tenant_agent", "tenant_id", "agent_name"), + Index("ix_alert_policies_tenant_type", "tenant_id", "alert_type"), + Index("ix_alert_policies_tenant_agent_type", "tenant_id", "agent_name", "alert_type"), + ) diff --git a/storage/repositories/__init__.py b/storage/repositories/__init__.py index 0bb89fc..99b2b29 100644 --- a/storage/repositories/__init__.py +++ b/storage/repositories/__init__.py @@ -6,6 +6,7 @@ - CheckpointRepository: Checkpoint CRUD operations - AnomalyAlertRepository: Anomaly alert CRUD operations - PatternRepository: Pattern CRUD operations +- AlertPolicyRepository: Alert policy CRUD operations """ from .alert_repo import AnomalyAlertRepository @@ -13,6 +14,7 @@ from .entity_repo import EntityRepository from .event_repo import EventRepository from .pattern_repo import PatternRepository +from .policy_repo import AlertPolicyRepository from .session_repo import SessionRepository __all__ = [ @@ -21,5 +23,6 @@ "CheckpointRepository", "AnomalyAlertRepository", "PatternRepository", + "AlertPolicyRepository", "EntityRepository", ] diff --git a/storage/repositories/alert_repo.py b/storage/repositories/alert_repo.py index 7fd116e..56f196f 100644 --- a/storage/repositories/alert_repo.py +++ b/storage/repositories/alert_repo.py @@ -2,9 +2,13 @@ from __future__ import annotations -from sqlalchemy import select +from datetime import datetime, timedelta, timezone +from typing import Any + +from sqlalchemy import and_, func, select, update from sqlalchemy.ext.asyncio import AsyncSession +from storage.cache import QueryCache from storage.models import AnomalyAlertModel @@ -13,8 +17,20 @@ class 
AnomalyAlertRepository: Provides async methods for alert management using SQLAlchemy async session. All queries are scoped to a specific tenant_id for multi-tenant isolation. + + Queries leverage the following indexes: + - ix_anomaly_alerts_created_at: time-based ordering and filtering + - ix_anomaly_alerts_severity: severity-based filtering + - ix_anomaly_alerts_alert_type: alert type filtering + - ix_anomaly_alerts_session_id: session-based lookups + - ix_anomaly_alerts_tenant_id_status: tenant + status filtering """ + VALID_STATUSES = {"active", "acknowledged", "resolved", "dismissed"} + + # Class-level cache shared across instances (for summary/trending data) + _cache = QueryCache() + def __init__(self, session: AsyncSession, tenant_id: str = "local"): """Initialize the repository with an async session and tenant_id. @@ -38,6 +54,8 @@ async def create_anomaly_alert( The created AnomalyAlertModel instance """ self.session.add(alert) + # Invalidate summary cache when new alert is created + self._invalidate_summary_cache() return alert async def list_anomaly_alerts( @@ -47,12 +65,14 @@ async def list_anomaly_alerts( ) -> list[AnomalyAlertModel]: """List anomaly alerts for a session. + Uses ix_anomaly_alerts_session_id and ix_anomaly_alerts_created_at indexes. + Args: session_id: Session ID to filter alerts by - limit: Maximum number of alerts to return + limit: Maximum number of alerts to return (default: 50) Returns: - List of AnomalyAlertModel instances + List of AnomalyAlertModel instances ordered by creation time (newest first) """ result = await self.session.execute( select(AnomalyAlertModel) @@ -81,3 +101,417 @@ async def get_anomaly_alert(self, alert_id: str) -> AnomalyAlertModel | None: ) ) return result.scalar_one_or_none() + + async def get_alert_summary(self, hours: int = 24) -> dict[str, Any]: + """Get aggregated alert statistics for the recent time window. 
+ + Uses ix_anomaly_alerts_created_at index for time filtering and + ix_anomaly_alerts_severity, ix_anomaly_alerts_alert_type for grouping. + + Results are cached for 60 seconds to reduce database load. + + Args: + hours: Number of hours to look back (default: 24) + + Returns: + Dictionary with total_count, by_severity, by_type, and by_session + """ + cache_key = f"alert_summary:{self.tenant_id}:{hours}h" + cached = self._cache.get(cache_key) + if cached is not None: + return cached + + cutoff = datetime.now(timezone.utc) - timedelta(hours=hours) + + # Count by severity - uses ix_anomaly_alerts_severity + severity_result = await self.session.execute( + select(AnomalyAlertModel.severity, func.count(AnomalyAlertModel.id)) + .where( + AnomalyAlertModel.tenant_id == self.tenant_id, + AnomalyAlertModel.created_at >= cutoff, + ) + .group_by(AnomalyAlertModel.severity) + ) + by_severity = {row[0]: row[1] for row in severity_result.all()} + + # Count by alert type - uses ix_anomaly_alerts_alert_type + type_result = await self.session.execute( + select(AnomalyAlertModel.alert_type, func.count(AnomalyAlertModel.id)) + .where( + AnomalyAlertModel.tenant_id == self.tenant_id, + AnomalyAlertModel.created_at >= cutoff, + ) + .group_by(AnomalyAlertModel.alert_type) + ) + by_type = {row[0]: row[1] for row in type_result.all()} + + # Count by session - uses ix_anomaly_alerts_session_id + session_result = await self.session.execute( + select(AnomalyAlertModel.session_id, func.count(AnomalyAlertModel.id)) + .where( + AnomalyAlertModel.tenant_id == self.tenant_id, + AnomalyAlertModel.created_at >= cutoff, + ) + .group_by(AnomalyAlertModel.session_id) + .order_by(func.count(AnomalyAlertModel.id).desc()) + .limit(10) + ) + by_session = {row[0]: row[1] for row in session_result.all()} + + summary = { + "total_count": sum(by_severity.values()), + "by_severity": by_severity, + "by_type": by_type, + "by_session": by_session, + "period_hours": hours, + } + + # Cache for 60 seconds + 
self._cache.set(cache_key, summary, ttl_seconds=60) + return summary + + async def get_trending_alerts( + self, + hours: int = 24, + limit: int = 10, + ) -> list[dict[str, Any]]: + """Get trending alerts by type for the recent time window. + + Uses ix_anomaly_alerts_alert_type and ix_anomaly_alerts_created_at indexes. + + Results are cached for 60 seconds to reduce database load. + + Args: + hours: Number of hours to look back (default: 24) + limit: Maximum number of trending types to return (default: 10) + + Returns: + List of dicts with alert_type, count, and avg_severity + """ + cache_key = f"trending_alerts:{self.tenant_id}:{hours}h:{limit}" + cached = self._cache.get(cache_key) + if cached is not None: + return cached + + cutoff = datetime.now(timezone.utc) - timedelta(hours=hours) + + result = await self.session.execute( + select( + AnomalyAlertModel.alert_type, + func.count(AnomalyAlertModel.id).label("count"), + func.avg(AnomalyAlertModel.severity).label("avg_severity"), + ) + .where( + AnomalyAlertModel.tenant_id == self.tenant_id, + AnomalyAlertModel.created_at >= cutoff, + ) + .group_by(AnomalyAlertModel.alert_type) + .order_by(func.count(AnomalyAlertModel.id).desc()) + .limit(limit) + ) + + trending = [ + { + "alert_type": row.alert_type, + "count": row.count, + "avg_severity": float(row.avg_severity) if row.avg_severity else 0.0, + } + for row in result.all() + ] + + # Cache for 60 seconds + self._cache.set(cache_key, trending, ttl_seconds=60) + return trending + + def _invalidate_summary_cache(self) -> None: + """Invalidate summary, lifecycle, and trending cache entries for this tenant.""" + self._cache.invalidate(f"alert_summary:{self.tenant_id}:", prefix=True) + self._cache.invalidate(f"trending_alerts:{self.tenant_id}:", prefix=True) + self._cache.invalidate(f"trending:{self.tenant_id}:", prefix=True) + self._cache.invalidate(f"lifecycle_summary:{self.tenant_id}") + + # ------------------------------------------------------------------ + # 
Lifecycle Management Methods + # ------------------------------------------------------------------ + + async def update_alert_status( + self, alert_id: str, status: str, note: str | None = None + ) -> AnomalyAlertModel | None: + """Update the status of a single alert with appropriate timestamp. + + Uses primary key lookup for efficient single-alert update. + + Args: + alert_id: Unique identifier of the alert + status: New status (active/acknowledged/resolved/dismissed) + note: Optional resolution note + + Returns: + Updated AnomalyAlertModel if found, None otherwise + + Raises: + ValueError: If status is not valid + """ + if status not in self.VALID_STATUSES: + raise ValueError(f"Invalid status: {status}. Must be one of {self.VALID_STATUSES}") + + result = await self.session.execute( + select(AnomalyAlertModel).where( + AnomalyAlertModel.id == alert_id, + AnomalyAlertModel.tenant_id == self.tenant_id, + ) + ) + alert = result.scalar_one_or_none() + + if not alert: + return None + + # Update status and note + alert.status = status + if note: + alert.resolution_note = note + + # Update appropriate timestamp + now = datetime.now(timezone.utc) + if status == "acknowledged": + alert.acknowledged_at = now + elif status == "resolved": + alert.resolved_at = now + elif status == "dismissed": + alert.dismissed_at = now + + # Invalidate cache when alert status changes + self._invalidate_summary_cache() + + return alert + + async def bulk_update_status(self, alert_ids: list[str], status: str) -> int: + """Bulk update status for multiple alerts. + + Uses efficient bulk update with WHERE ... IN clause. + + Args: + alert_ids: List of alert IDs to update + status: New status for all alerts + + Returns: + Number of alerts updated + + Raises: + ValueError: If status is not valid + """ + if status not in self.VALID_STATUSES: + raise ValueError(f"Invalid status: {status}. 
Must be one of {self.VALID_STATUSES}") + + now = datetime.now(timezone.utc) + updates = {"status": status} + + # Set appropriate timestamp based on status + if status == "acknowledged": + updates["acknowledged_at"] = now + elif status == "resolved": + updates["resolved_at"] = now + elif status == "dismissed": + updates["dismissed_at"] = now + + stmt = ( + update(AnomalyAlertModel) + .where( + AnomalyAlertModel.id.in_(alert_ids), + AnomalyAlertModel.tenant_id == self.tenant_id, + ) + .values(**updates) + ) + result = await self.session.execute(stmt) + + # Invalidate cache when bulk status changes occur + self._invalidate_summary_cache() + + return result.rowcount + + async def list_alerts_filtered( + self, + agent_name: str | None = None, + severity: float | None = None, + alert_type: str | None = None, + status: str | None = None, + from_date: datetime | None = None, + to_date: datetime | None = None, + limit: int = 50, + ) -> list[AnomalyAlertModel]: + """List alerts with rich filtering options. + + Leverages ix_anomaly_alerts_tenant_id_status, ix_anomaly_alerts_severity, + ix_anomaly_alerts_alert_type, and ix_anomaly_alerts_created_at indexes. 
+ + Args: + agent_name: Optional agent name to filter by (requires join) + severity: Optional minimum severity to filter by + alert_type: Optional alert type to filter by + status: Optional status to filter by + from_date: Optional start date for created_at filter + to_date: Optional end date for created_at filter + limit: Maximum number of alerts to return + + Returns: + List of AnomalyAlertModel instances matching filters + """ + query = select(AnomalyAlertModel).where(AnomalyAlertModel.tenant_id == self.tenant_id) + + # Apply filters + if alert_type: + query = query.where(AnomalyAlertModel.alert_type == alert_type) + if severity is not None: + query = query.where(AnomalyAlertModel.severity >= severity) + if status: + query = query.where(AnomalyAlertModel.status == status) + if from_date: + query = query.where(AnomalyAlertModel.created_at >= from_date) + if to_date: + query = query.where(AnomalyAlertModel.created_at <= to_date) + + # Join with sessions if agent_name filter is provided + if agent_name: + from storage.models import SessionModel + + query = query.join(SessionModel, AnomalyAlertModel.session_id == SessionModel.id).where( + SessionModel.agent_name == agent_name + ) + + query = query.order_by(AnomalyAlertModel.created_at.desc()).limit(limit) + + result = await self.session.execute(query) + return list(result.scalars().all()) + + async def get_alert_lifecycle_summary(self) -> dict[str, Any]: + """Get alert summary statistics grouped by severity, type, and status. + + Uses ix_anomaly_alerts_severity, ix_anomaly_alerts_alert_type, and + ix_anomaly_alerts_tenant_id_status indexes for efficient grouping. + + Results are cached for 60 seconds. 
+ + Returns: + Dictionary with counts by severity, type, and status + """ + cache_key = f"lifecycle_summary:{self.tenant_id}" + cached = self._cache.get(cache_key) + if cached is not None: + return cached + + # Count by status + status_result = await self.session.execute( + select(AnomalyAlertModel.status, func.count(AnomalyAlertModel.id)) + .where(AnomalyAlertModel.tenant_id == self.tenant_id) + .group_by(AnomalyAlertModel.status) + ) + by_status = {status: count for status, count in status_result.all()} + + # Count by alert_type + type_result = await self.session.execute( + select(AnomalyAlertModel.alert_type, func.count(AnomalyAlertModel.id)) + .where(AnomalyAlertModel.tenant_id == self.tenant_id) + .group_by(AnomalyAlertModel.alert_type) + ) + by_type = {alert_type: count for alert_type, count in type_result.all()} + + # Count by severity ranges + critical_result = await self.session.execute( + select(func.count(AnomalyAlertModel.id)).where( + and_( + AnomalyAlertModel.tenant_id == self.tenant_id, + AnomalyAlertModel.severity >= 0.8, + ) + ) + ) + critical_count = critical_result.scalar() or 0 + + high_result = await self.session.execute( + select(func.count(AnomalyAlertModel.id)).where( + and_( + AnomalyAlertModel.tenant_id == self.tenant_id, + AnomalyAlertModel.severity >= 0.5, + AnomalyAlertModel.severity < 0.8, + ) + ) + ) + high_count = high_result.scalar() or 0 + + medium_result = await self.session.execute( + select(func.count(AnomalyAlertModel.id)).where( + and_( + AnomalyAlertModel.tenant_id == self.tenant_id, + AnomalyAlertModel.severity >= 0.3, + AnomalyAlertModel.severity < 0.5, + ) + ) + ) + medium_count = medium_result.scalar() or 0 + + low_result = await self.session.execute( + select(func.count(AnomalyAlertModel.id)).where( + and_( + AnomalyAlertModel.tenant_id == self.tenant_id, + AnomalyAlertModel.severity < 0.3, + ) + ) + ) + low_count = low_result.scalar() or 0 + + summary = { + "by_status": by_status, + "by_type": by_type, + "by_severity": 
{ + "critical": critical_count, + "high": high_count, + "medium": medium_count, + "low": low_count, + }, + "total": sum(by_status.values()), + } + + # Cache for 60 seconds + self._cache.set(cache_key, summary, ttl_seconds=60) + return summary + + async def get_alert_trending(self, days: int = 7) -> list[dict[str, Any]]: + """Get alert volume trend grouped by day. + + Uses ix_anomaly_alerts_created_at index for efficient time-based grouping. + + Results are cached for 60 seconds. + + Args: + days: Number of days to look back (default 7) + + Returns: + List of dicts with date and count + """ + cache_key = f"trending:{self.tenant_id}:{days}d" + cached = self._cache.get(cache_key) + if cached is not None: + return cached + + to_date = datetime.now(timezone.utc) + from_date = to_date - timedelta(days=days) + + result = await self.session.execute( + select( + func.date(AnomalyAlertModel.created_at).label("date"), + func.count(AnomalyAlertModel.id).label("count"), + ) + .where( + and_( + AnomalyAlertModel.tenant_id == self.tenant_id, + AnomalyAlertModel.created_at >= from_date, + ) + ) + .group_by(func.date(AnomalyAlertModel.created_at)) + .order_by(func.date(AnomalyAlertModel.created_at)) + ) + + trending = [{"date": str(row.date), "count": row.count} for row in result.all()] + + # Cache for 60 seconds + self._cache.set(cache_key, trending, ttl_seconds=60) + return trending diff --git a/storage/repositories/policy_repo.py b/storage/repositories/policy_repo.py new file mode 100644 index 0000000..1421a14 --- /dev/null +++ b/storage/repositories/policy_repo.py @@ -0,0 +1,219 @@ +"""Alert policy repository for policy CRUD operations.""" + +from __future__ import annotations + +import uuid +from datetime import datetime, timezone +from typing import Any + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from storage.models import AlertPolicyModel + +_UNSET: Any = object() # Sentinel for "no update" to distinguish from None + + +class 
AlertPolicyRepository: + """Data access layer for alert policy CRUD operations. + + Provides async methods for policy management using SQLAlchemy async session. + All queries are scoped to a specific tenant_id for multi-tenant isolation. + """ + + def __init__(self, session: AsyncSession, tenant_id: str = "local"): + """Initialize the repository with an async session and tenant_id. + + Args: + session: SQLAlchemy AsyncSession instance + tenant_id: Tenant identifier for data isolation (default: "local") + """ + self.session = session + self.tenant_id = tenant_id + + async def create_policy( + self, + agent_name: str | None, + alert_type: str, + threshold_value: float, + severity_threshold: str | None = None, + enabled: bool = True, + ) -> AlertPolicyModel: + """Create a new alert policy. + + Args: + agent_name: Agent name for specific policy, None for global policy + alert_type: Type of alert this policy applies to + threshold_value: Threshold value for the alert + severity_threshold: Optional severity threshold (warning, critical, etc.) + enabled: Whether the policy is enabled + + Returns: + The created AlertPolicyModel instance + """ + policy = AlertPolicyModel( + id=str(uuid.uuid4()), + tenant_id=self.tenant_id, + agent_name=agent_name, + alert_type=alert_type, + threshold_value=threshold_value, + severity_threshold=severity_threshold, + enabled=enabled, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + self.session.add(policy) + return policy + + async def get_policy(self, policy_id: str) -> AlertPolicyModel | None: + """Retrieve an alert policy by ID. 
+ + Args: + policy_id: Unique identifier of the policy + + Returns: + AlertPolicyModel if found, None otherwise + """ + result = await self.session.execute( + select(AlertPolicyModel).where( + AlertPolicyModel.id == policy_id, + AlertPolicyModel.tenant_id == self.tenant_id, + ) + ) + return result.scalar_one_or_none() + + async def list_policies( + self, + agent_name: str | None = None, + limit: int = 100, + ) -> list[AlertPolicyModel]: + """List alert policies, optionally filtered by agent_name. + + Args: + agent_name: Optional agent name filter (None returns all policies including global) + limit: Maximum number of policies to return + + Returns: + List of AlertPolicyModel instances + """ + query = select(AlertPolicyModel).where(AlertPolicyModel.tenant_id == self.tenant_id) + + if agent_name is not None: + # Filter for specific agent OR global (NULL agent_name) policies + query = query.where( + (AlertPolicyModel.agent_name == agent_name) | (AlertPolicyModel.agent_name.is_(None)) + ) + + query = query.order_by(AlertPolicyModel.created_at.desc()).limit(limit) + + result = await self.session.execute(query) + return list(result.scalars().all()) + + async def update_policy( + self, + policy_id: str, + agent_name: str | None | object = _UNSET, + alert_type: str | None | object = _UNSET, + threshold_value: float | None | object = _UNSET, + severity_threshold: str | None | object = _UNSET, + enabled: bool | None | object = _UNSET, + ) -> AlertPolicyModel | None: + """Update an existing alert policy. + + Uses _UNSET sentinel so that explicitly passing None for nullable + fields (e.g. agent_name=None to make a policy global) works correctly. 
+ + Args: + policy_id: Unique identifier of the policy to update + agent_name: New agent name (_UNSET keeps existing, None makes global) + alert_type: New alert type (_UNSET keeps existing) + threshold_value: New threshold value (_UNSET keeps existing) + severity_threshold: New severity threshold (_UNSET keeps existing, None clears) + enabled: New enabled state (_UNSET keeps existing) + + Returns: + Updated AlertPolicyModel if found, None otherwise + """ + policy = await self.get_policy(policy_id) + if not policy: + return None + + if agent_name is not _UNSET: + policy.agent_name = agent_name + if alert_type is not _UNSET: + policy.alert_type = alert_type + if threshold_value is not _UNSET: + policy.threshold_value = threshold_value + if severity_threshold is not _UNSET: + policy.severity_threshold = severity_threshold + if enabled is not _UNSET: + policy.enabled = enabled + + policy.updated_at = datetime.now(timezone.utc) + return policy + + async def delete_policy(self, policy_id: str) -> bool: + """Delete an alert policy by ID. + + Args: + policy_id: Unique identifier of the policy to delete + + Returns: + True if policy was deleted, False if not found + """ + policy = await self.get_policy(policy_id) + if not policy: + return False + + await self.session.delete(policy) + return True + + async def get_active_policy_for( + self, + alert_type: str, + agent_name: str | None = None, + ) -> AlertPolicyModel | None: + """Get the active policy for a specific alert type and agent. + + Returns the most specific policy available: + 1. Agent-specific policy (if agent_name provided) + 2. Global policy for the alert type + 3. 
None if no policy found + + Args: + alert_type: Type of alert to find policy for + agent_name: Optional agent name for agent-specific policy + + Returns: + AlertPolicyModel if found, None otherwise + """ + # First try to find agent-specific policy + if agent_name is not None: + result = await self.session.execute( + select(AlertPolicyModel) + .where( + AlertPolicyModel.tenant_id == self.tenant_id, + AlertPolicyModel.alert_type == alert_type, + AlertPolicyModel.agent_name == agent_name, + AlertPolicyModel.enabled.is_(True), + ) + .order_by(AlertPolicyModel.created_at.desc()) + .limit(1) + ) + policy = result.scalar_one_or_none() + if policy: + return policy + + # Fall back to global policy + result = await self.session.execute( + select(AlertPolicyModel) + .where( + AlertPolicyModel.tenant_id == self.tenant_id, + AlertPolicyModel.alert_type == alert_type, + AlertPolicyModel.agent_name.is_(None), + AlertPolicyModel.enabled.is_(True), + ) + .order_by(AlertPolicyModel.created_at.desc()) + .limit(1) + ) + return result.scalar_one_or_none() diff --git a/storage/repository.py b/storage/repository.py index 23922f2..9894684 100644 --- a/storage/repository.py +++ b/storage/repository.py @@ -590,12 +590,12 @@ async def get_daily_cost_breakdown(self, days: int = 30) -> list[dict]: # Fill missing days with zeros breakdown = [] for i in range(days): - d = (period_start + timedelta(days=i)).date() + d = (period_start + timedelta(days=i)).strftime("%Y-%m-%d") if d in daily_data: breakdown.append(daily_data[d]) else: breakdown.append({ - "date": d.isoformat(), + "date": d, "session_count": 0, "total_cost_usd": 0.0, "total_tokens": 0, diff --git a/storage/search.py b/storage/search.py index 26cb265..9d89a89 100644 --- a/storage/search.py +++ b/storage/search.py @@ -174,6 +174,12 @@ async def search_sessions( Uses bag-of-words cosine similarity against session event embeddings. Searches across event_type, name, error_type, error_message, tool_name, and model fields. 
+ Indexes used: + - ix_sessions_agent_name: for agent_name filtering + - ix_sessions_created_at: for time range filtering (started_at column) + - ix_events_event_type: for event_type subquery filtering + - ix_events_tenant_session: for tenant-scoped event lookups + Args: query: Search query text (supports natural language like "sessions with tool failures") status: Optional session status to filter by (e.g., "error", "completed") @@ -221,6 +227,11 @@ async def search_events( ) -> list[TraceEvent]: """Search events by name or data content. + Indexes used: + - ix_events_session_id_created_at: for session filtering and timestamp ordering + - ix_events_event_type: for event_type filtering + - ix_events_tenant_session: for tenant-scoped lookups + Args: query: Search string to match against event name session_id: Optional session ID to filter by diff --git a/tests/test_alert_lifecycle.py b/tests/test_alert_lifecycle.py new file mode 100644 index 0000000..772e084 --- /dev/null +++ b/tests/test_alert_lifecycle.py @@ -0,0 +1,365 @@ +"""Tests for alert lifecycle management API.""" + +from __future__ import annotations + +from datetime import datetime, timezone + +import pytest +from httpx import ASGITransport, AsyncClient + +from agent_debugger_sdk.core.events import Session +from api.main import create_app +from storage import TraceRepository +from storage.repository import AnomalyAlertCreate + + +def _make_session(session_id: str, agent_name: str = "test-agent") -> Session: + """Create a test session.""" + return Session( + id=session_id, + agent_name=agent_name, + framework="pytest", + started_at=datetime.now(timezone.utc), + config={"mode": "test"}, + ) + + +def _make_alert( + alert_id: str, + session_id: str, + alert_type: str = "error_spike", + severity: float = 0.8, +) -> AnomalyAlertCreate: + """Create a test alert.""" + return AnomalyAlertCreate( + id=alert_id, + session_id=session_id, + alert_type=alert_type, + severity=severity, + signal=f"Test alert {alert_id}", 
+ event_ids=[f"event-{alert_id}"], + detection_source="test_detector", + detection_config={"test": True}, + ) + + +@pytest.mark.asyncio +async def test_update_alert_status_acknowledge(): + """Test updating an alert status to acknowledged.""" + app = create_app() + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + from api import app_context + + # Seed database with test data + async with app_context.require_session_maker()() as db_session: + repo = TraceRepository(db_session) + session = await repo.create_session(_make_session("test-session-ack")) + alert = await repo.create_anomaly_alert( + _make_alert("test-alert-ack", session.id, severity=0.8) + ) + alert_id = alert.id + await db_session.commit() + + response = await client.put( + f"/api/alerts/{alert_id}/status", + json={"status": "acknowledged", "note": "Investigating this issue"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "acknowledged" + assert data["resolution_note"] == "Investigating this issue" + assert data["acknowledged_at"] is not None + + +@pytest.mark.asyncio +async def test_update_alert_status_resolve(): + """Test updating an alert status to resolved.""" + app = create_app() + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + from api import app_context + + # Seed database with test data + async with app_context.require_session_maker()() as db_session: + repo = TraceRepository(db_session) + session = await repo.create_session(_make_session("test-session-resolve")) + alert = await repo.create_anomaly_alert( + _make_alert("test-alert-resolve", session.id, severity=0.7) + ) + alert_id = alert.id + await db_session.commit() + + response = await client.put( + f"/api/alerts/{alert_id}/status", + json={"status": "resolved", "note": "Fixed the root cause"}, + ) + + assert response.status_code == 200 + data = 
response.json() + assert data["status"] == "resolved" + assert data["resolution_note"] == "Fixed the root cause" + assert data["resolved_at"] is not None + + +@pytest.mark.asyncio +async def test_update_alert_status_dismiss(): + """Test updating an alert status to dismissed.""" + app = create_app() + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + from api import app_context + + # Seed database with test data + async with app_context.require_session_maker()() as db_session: + repo = TraceRepository(db_session) + session = await repo.create_session(_make_session("test-session-dismiss")) + alert = await repo.create_anomaly_alert( + _make_alert("test-alert-dismiss", session.id, severity=0.6) + ) + alert_id = alert.id + await db_session.commit() + + response = await client.put( + f"/api/alerts/{alert_id}/status", + json={"status": "dismissed", "note": "False alarm"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "dismissed" + assert data["resolution_note"] == "False alarm" + assert data["dismissed_at"] is not None + + +@pytest.mark.asyncio +async def test_update_alert_status_not_found(): + """Test updating a non-existent alert.""" + app = create_app() + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.put( + "/api/alerts/nonexistent-id/status", + json={"status": "acknowledged"}, + ) + + assert response.status_code == 404 + + +@pytest.mark.asyncio +async def test_bulk_update_alert_status(): + """Test bulk updating alert statuses.""" + app = create_app() + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + from api import app_context + + # Seed database with test data + async with app_context.require_session_maker()() as db_session: + repo = TraceRepository(db_session) + session = await 
repo.create_session(_make_session("test-session-bulk")) + alert_ids = [] + for i in range(3): + alert = await repo.create_anomaly_alert( + _make_alert(f"test-alert-bulk-{i}", session.id, severity=0.5 + i * 0.1) + ) + alert_ids.append(alert.id) + await db_session.commit() + + response = await client.post( + "/api/alerts/bulk-status", + json={"alert_ids": alert_ids, "status": "acknowledged"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["updated"] == 3 + assert data["status"] == "acknowledged" + + +@pytest.mark.asyncio +async def test_get_alert_summary(): + """Test getting alert summary statistics.""" + app = create_app() + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + from api import app_context + + # Get baseline + baseline_resp = await client.get("/api/alerts/summary") + assert baseline_resp.status_code == 200 + baseline = baseline_resp.json() + + # Seed database with test data + async with app_context.require_session_maker()() as db_session: + repo = TraceRepository(db_session) + session = await repo.create_session(_make_session("test-session-summary")) + for i in range(3): + await repo.create_anomaly_alert( + _make_alert(f"test-alert-summary-{i}", session.id, severity=0.5 + i * 0.1) + ) + await db_session.commit() + + response = await client.get("/api/alerts/summary") + + assert response.status_code == 200 + data = response.json() + assert "by_status" in data + assert "by_type" in data + assert "by_severity" in data + assert "total" in data + assert data["total"] >= baseline["total"] + + +@pytest.mark.asyncio +async def test_get_alert_trending(): + """Test getting alert trending data.""" + app = create_app() + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.get("/api/alerts/trending?days=7") + + assert response.status_code == 200 + data = response.json() + assert 
"trending" in data + assert "days" in data + assert data["days"] == 7 + assert isinstance(data["trending"], list) + + +@pytest.mark.asyncio +async def test_list_alerts_filtered_by_status(): + """Test filtering alerts by status.""" + app = create_app() + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + from api import app_context + + # Seed database with test data + async with app_context.require_session_maker()() as db_session: + repo = TraceRepository(db_session) + session = await repo.create_session(_make_session("test-session-filter-status")) + alert_ids = [] + for i in range(3): + alert = await repo.create_anomaly_alert( + _make_alert(f"test-alert-filter-{i}", session.id, severity=0.5 + i * 0.1) + ) + alert_ids.append(alert.id) + await db_session.commit() + + # Acknowledge some alerts + await client.post( + "/api/alerts/bulk-status", + json={"alert_ids": alert_ids[:2], "status": "acknowledged"}, + ) + + # Now filter by status + response = await client.get("/api/alerts?status=active") + + assert response.status_code == 200 + data = response.json() + assert "alerts" in data + assert "total" in data + assert data["filters"]["status"] == "active" + + +@pytest.mark.asyncio +async def test_list_alerts_filtered_by_severity(): + """Test filtering alerts by minimum severity.""" + app = create_app() + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + from api import app_context + + # Seed database with test data + async with app_context.require_session_maker()() as db_session: + repo = TraceRepository(db_session) + session = await repo.create_session(_make_session("test-session-filter-severity")) + for i in range(3): + await repo.create_anomaly_alert( + _make_alert(f"test-alert-sev-{i}", session.id, severity=0.4 + i * 0.2) + ) + await db_session.commit() + + response = await client.get("/api/alerts?severity=0.7") + + assert 
response.status_code == 200 + data = response.json() + assert "alerts" in data + assert data["filters"]["severity"] == 0.7 + # All returned alerts should have severity >= 0.7 + for alert in data["alerts"]: + assert alert["severity"] >= 0.7 + + +@pytest.mark.asyncio +async def test_list_alerts_filtered_by_type(): + """Test filtering alerts by alert type.""" + app = create_app() + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + from api import app_context + + # Seed database with test data + async with app_context.require_session_maker()() as db_session: + repo = TraceRepository(db_session) + session = await repo.create_session(_make_session("test-session-filter-type")) + await repo.create_anomaly_alert( + _make_alert("test-alert-type-1", session.id, alert_type="error_spike") + ) + await repo.create_anomaly_alert( + _make_alert("test-alert-type-2", session.id, alert_type="confidence_drop") + ) + await db_session.commit() + + response = await client.get("/api/alerts?alert_type=error_spike") + + assert response.status_code == 200 + data = response.json() + assert "alerts" in data + assert data["filters"]["alert_type"] == "error_spike" + # All returned alerts should have the specified type + for alert in data["alerts"]: + assert alert["alert_type"] == "error_spike" + + +@pytest.mark.asyncio +async def test_alert_status_transitions(): + """Test alert status transitions: active -> acknowledged -> resolved.""" + app = create_app() + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + from api import app_context + + # Seed database with test data + async with app_context.require_session_maker()() as db_session: + repo = TraceRepository(db_session) + session = await repo.create_session(_make_session("test-session-transitions")) + alert = await repo.create_anomaly_alert( + _make_alert("test-alert-transition", session.id, severity=0.8) + ) + 
alert_id = alert.id + await db_session.commit() + + # Transition to acknowledged + response = await client.put( + f"/api/alerts/{alert_id}/status", + json={"status": "acknowledged", "note": "Looking into it"}, + ) + assert response.status_code == 200 + data = response.json() + assert data["status"] == "acknowledged" + assert data["acknowledged_at"] is not None + + # Transition to resolved + response = await client.put( + f"/api/alerts/{alert_id}/status", + json={"status": "resolved", "note": "Fixed"}, + ) + assert response.status_code == 200 + data = response.json() + assert data["status"] == "resolved" + assert data["resolved_at"] is not None + # Previous timestamp should still be present + assert data["acknowledged_at"] is not None diff --git a/tests/test_alert_policies.py b/tests/test_alert_policies.py new file mode 100644 index 0000000..6f53aae --- /dev/null +++ b/tests/test_alert_policies.py @@ -0,0 +1,268 @@ +"""Tests for alert policy repository and API endpoints.""" + +from __future__ import annotations + +import pytest + +from storage import AlertPolicyRepository + + +@pytest.mark.asyncio +async def test_create_policy(db_session): + """Test creating a new alert policy.""" + repo = AlertPolicyRepository(db_session, tenant_id="test-tenant") + + policy = await repo.create_policy( + agent_name="test-agent", + alert_type="tool_loop", + threshold_value=3.0, + severity_threshold="high", + enabled=True, + ) + + assert policy.id is not None + assert policy.agent_name == "test-agent" + assert policy.alert_type == "tool_loop" + assert policy.threshold_value == 3.0 + assert policy.severity_threshold == "high" + assert policy.enabled is True + assert policy.tenant_id == "test-tenant" + + +@pytest.mark.asyncio +async def test_create_global_policy(db_session): + """Test creating a global policy (agent_name is None).""" + repo = AlertPolicyRepository(db_session, tenant_id="test-tenant") + + policy = await repo.create_policy( + agent_name=None, + 
alert_type="high_error_rate", + threshold_value=0.5, + enabled=True, + ) + + assert policy.agent_name is None + assert policy.alert_type == "high_error_rate" + + +@pytest.mark.asyncio +async def test_get_policy(db_session): + """Test retrieving a policy by ID.""" + repo = AlertPolicyRepository(db_session, tenant_id="test-tenant") + + created = await repo.create_policy( + agent_name="test-agent", + alert_type="tool_loop", + threshold_value=3.0, + ) + await db_session.commit() + + retrieved = await repo.get_policy(created.id) + + assert retrieved is not None + assert retrieved.id == created.id + assert retrieved.alert_type == "tool_loop" + + +@pytest.mark.asyncio +async def test_get_policy_not_found(db_session): + """Test retrieving a non-existent policy returns None.""" + repo = AlertPolicyRepository(db_session, tenant_id="test-tenant") + + retrieved = await repo.get_policy("non-existent-id") + + assert retrieved is None + + +@pytest.mark.asyncio +async def test_list_policies(db_session): + """Test listing all policies.""" + repo = AlertPolicyRepository(db_session, tenant_id="test-tenant") + + # Create multiple policies + await repo.create_policy(agent_name="agent-1", alert_type="tool_loop", threshold_value=3.0) + await repo.create_policy(agent_name="agent-2", alert_type="high_error_rate", threshold_value=0.5) + await repo.create_policy(agent_name=None, alert_type="global_policy", threshold_value=1.0) + await db_session.commit() + + policies = await repo.list_policies() + + assert len(policies) == 3 + alert_types = {p.alert_type for p in policies} + assert "tool_loop" in alert_types + assert "high_error_rate" in alert_types + assert "global_policy" in alert_types + + +@pytest.mark.asyncio +async def test_list_policies_filtered_by_agent(db_session): + """Test listing policies filtered by agent_name.""" + repo = AlertPolicyRepository(db_session, tenant_id="test-tenant") + + # Create policies for different agents + await repo.create_policy(agent_name="agent-1", 
alert_type="tool_loop", threshold_value=3.0) + await repo.create_policy(agent_name="agent-1", alert_type="high_error_rate", threshold_value=0.5) + await repo.create_policy(agent_name="agent-2", alert_type="tool_loop", threshold_value=5.0) + await repo.create_policy(agent_name=None, alert_type="tool_loop", threshold_value=2.0) + await db_session.commit() + + # List policies for agent-1 (should include agent-1 specific and global policies) + policies = await repo.list_policies(agent_name="agent-1") + + assert len(policies) == 3 # 2 agent-1 specific + 1 global + agent_names = {p.agent_name for p in policies} + assert "agent-1" in agent_names + assert None in agent_names # Global policy + + +@pytest.mark.asyncio +async def test_update_policy(db_session): + """Test updating an existing policy.""" + repo = AlertPolicyRepository(db_session, tenant_id="test-tenant") + + created = await repo.create_policy( + agent_name="test-agent", + alert_type="tool_loop", + threshold_value=3.0, + severity_threshold="high", + enabled=True, + ) + await db_session.commit() + + updated = await repo.update_policy( + created.id, + threshold_value=5.0, + severity_threshold="critical", + enabled=False, + ) + await db_session.commit() + + assert updated is not None + assert updated.threshold_value == 5.0 + assert updated.severity_threshold == "critical" + assert updated.enabled is False + # Unchanged fields + assert updated.agent_name == "test-agent" + assert updated.alert_type == "tool_loop" + + +@pytest.mark.asyncio +async def test_update_policy_not_found(db_session): + """Test updating a non-existent policy returns None.""" + repo = AlertPolicyRepository(db_session, tenant_id="test-tenant") + + updated = await repo.update_policy("non-existent-id", threshold_value=5.0) + + assert updated is None + + +@pytest.mark.asyncio +async def test_delete_policy(db_session): + """Test deleting a policy.""" + repo = AlertPolicyRepository(db_session, tenant_id="test-tenant") + + created = await 
repo.create_policy( + agent_name="test-agent", + alert_type="tool_loop", + threshold_value=3.0, + ) + await db_session.commit() + + deleted = await repo.delete_policy(created.id) + await db_session.commit() + + assert deleted is True + + # Verify policy is gone + retrieved = await repo.get_policy(created.id) + assert retrieved is None + + +@pytest.mark.asyncio +async def test_delete_policy_not_found(db_session): + """Test deleting a non-existent policy returns False.""" + repo = AlertPolicyRepository(db_session, tenant_id="test-tenant") + + deleted = await repo.delete_policy("non-existent-id") + + assert deleted is False + + +@pytest.mark.asyncio +async def test_get_active_policy_for_agent_specific(db_session): + """Test getting active policy prefers agent-specific over global.""" + repo = AlertPolicyRepository(db_session, tenant_id="test-tenant") + + # Create both agent-specific and global policies + await repo.create_policy(agent_name="test-agent", alert_type="tool_loop", threshold_value=5.0) + await repo.create_policy(agent_name=None, alert_type="tool_loop", threshold_value=2.0) + await db_session.commit() + + # Should return agent-specific policy + policy = await repo.get_active_policy_for("tool_loop", agent_name="test-agent") + + assert policy is not None + assert policy.threshold_value == 5.0 + assert policy.agent_name == "test-agent" + + +@pytest.mark.asyncio +async def test_get_active_policy_falls_back_to_global(db_session): + """Test getting active policy falls back to global if no agent-specific policy.""" + repo = AlertPolicyRepository(db_session, tenant_id="test-tenant") + + # Create only global policy + await repo.create_policy(agent_name=None, alert_type="tool_loop", threshold_value=2.0) + await db_session.commit() + + # Should return global policy + policy = await repo.get_active_policy_for("tool_loop", agent_name="test-agent") + + assert policy is not None + assert policy.threshold_value == 2.0 + assert policy.agent_name is None + + 
+@pytest.mark.asyncio +async def test_get_active_policy_disabled_not_returned(db_session): + """Test that disabled policies are not returned by get_active_policy_for.""" + repo = AlertPolicyRepository(db_session, tenant_id="test-tenant") + + # Create disabled policy + await repo.create_policy( + agent_name=None, alert_type="tool_loop", threshold_value=2.0, enabled=False + ) + await db_session.commit() + + policy = await repo.get_active_policy_for("tool_loop", agent_name="test-agent") + + assert policy is None + + +@pytest.mark.asyncio +async def test_get_active_policy_no_policy_found(db_session): + """Test get_active_policy_for returns None when no policy exists.""" + repo = AlertPolicyRepository(db_session, tenant_id="test-tenant") + + policy = await repo.get_active_policy_for("non_existent_alert", agent_name="test-agent") + + assert policy is None + + +@pytest.mark.asyncio +async def test_policy_tenant_isolation(db_session): + """Test that policies are isolated by tenant_id.""" + repo1 = AlertPolicyRepository(db_session, tenant_id="tenant-1") + repo2 = AlertPolicyRepository(db_session, tenant_id="tenant-2") + + # Create policy in tenant-1 + await repo1.create_policy(agent_name="test-agent", alert_type="tool_loop", threshold_value=3.0) + await db_session.commit() + + # tenant-2 should not see tenant-1's policy + policies = await repo2.list_policies() + assert len(policies) == 0 + + # tenant-1 should see their policy + policies = await repo1.list_policies() + assert len(policies) == 1 diff --git a/tests/test_query_performance.py b/tests/test_query_performance.py new file mode 100644 index 0000000..86ab7c1 --- /dev/null +++ b/tests/test_query_performance.py @@ -0,0 +1,265 @@ +"""Tests for query performance optimizations. + +Tests verify that: +1. Database indexes exist on model definitions +2. Cache utility works correctly (set/get/invalidation/TTL) +3. Alert repository summary and trending methods use caching +4. 
Query patterns leverage available indexes +""" + +from __future__ import annotations + +import time +import uuid +from datetime import datetime, timedelta, timezone +from unittest.mock import patch + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker + +from storage.cache import QueryCache +from storage.models import AnomalyAlertModel, EventModel, PatternModel, SessionModel + + +class TestQueryCache: + """Test the QueryCache utility.""" + + def test_cache_set_get(self): + """Test basic cache set and get operations.""" + cache = QueryCache() + + cache.set("key1", "value1") + assert cache.get("key1") == "value1" + + cache.set("key2", {"nested": "dict"}) + assert cache.get("key2") == {"nested": "dict"} + + def test_cache_ttl_expiration(self): + """Test that cache entries expire after TTL without real sleeps.""" + cache = QueryCache() + base_time = time.time() + + cache.set("expiring_key", "value", ttl_seconds=1) + assert cache.get("expiring_key") == "value" + + # Advance time past TTL + with patch("storage.cache.time.time", return_value=base_time + 2): + assert cache.get("expiring_key") is None + + def test_cache_invalidation(self): + """Test manual cache invalidation.""" + cache = QueryCache() + + cache.set("key1", "value1") + cache.set("key2", "value2") + + assert cache.get("key1") == "value1" + assert cache.get("key2") == "value2" + + cache.invalidate("key1") + assert cache.get("key1") is None + assert cache.get("key2") == "value2" + + def test_cache_clear(self): + """Test clearing all cache entries.""" + cache = QueryCache() + + cache.set("key1", "value1") + cache.set("key2", "value2") + cache.set("key3", "value3") + + assert cache.size() == 3 + + cache.clear() + assert cache.size() == 0 + assert cache.get("key1") is None + assert cache.get("key2") is None + assert cache.get("key3") is None + + def test_cache_cleanup_expired(self): + """Test cleanup of expired entries without real 
sleeps.""" + cache = QueryCache() + base_time = time.time() + + # Set entries with different TTLs + cache.set("short", "value1", ttl_seconds=1) + cache.set("long", "value2", ttl_seconds=10) + + # Advance time past short TTL + with patch("storage.cache.time.time", return_value=base_time + 2): + removed = cache.cleanup_expired() + assert removed == 1 + assert cache.get("short") is None + assert cache.get("long") == "value2" + + def test_cache_size(self): + """Test cache size tracking.""" + cache = QueryCache() + + assert cache.size() == 0 + + cache.set("key1", "value1") + assert cache.size() == 1 + + cache.set("key2", "value2") + cache.set("key3", "value3") + assert cache.size() == 3 + + def test_cache_prefix_invalidation(self): + """Test prefix-based cache invalidation.""" + cache = QueryCache() + + cache.set("alert_summary:local:24h", {"total": 5}) + cache.set("alert_summary:local:48h", {"total": 10}) + cache.set("trending:local:7d", []) + + removed = cache.invalidate("alert_summary:local:", prefix=True) + assert removed == 2 + assert cache.get("alert_summary:local:24h") is None + assert cache.get("alert_summary:local:48h") is None + assert cache.get("trending:local:7d") is not None + + +class TestModelIndexes: + """Test that indexes exist on model definitions.""" + + def test_anomaly_alert_model_indexes(self): + """Verify AnomalyAlertModel has the expected indexes.""" + # We can't verify migration-created indexes on the model itself, + # but we can verify the columns that should be indexed exist + assert hasattr(AnomalyAlertModel, "created_at") + assert hasattr(AnomalyAlertModel, "severity") + assert hasattr(AnomalyAlertModel, "alert_type") + assert hasattr(AnomalyAlertModel, "session_id") + + def test_pattern_model_indexes(self): + """Verify PatternModel has the expected indexes.""" + # Check that indexed columns exist + assert hasattr(PatternModel, "pattern_type") + assert hasattr(PatternModel, "status") + assert hasattr(PatternModel, "agent_name") + assert 
hasattr(PatternModel, "severity") + + def test_session_model_indexes(self): + """Verify SessionModel has the expected indexes.""" + assert hasattr(SessionModel, "started_at") + assert hasattr(SessionModel, "agent_name") + + def test_event_model_indexes(self): + """Verify EventModel has the expected indexes.""" + assert hasattr(EventModel, "session_id") + assert hasattr(EventModel, "timestamp") + assert hasattr(EventModel, "event_type") + + +class TestAlertRepositoryOptimizations: + """Test alert repository query optimizations.""" + + @pytest.mark.asyncio + async def test_alert_summary_caching(self): + """Test that alert summary results are cached.""" + from storage.repositories.alert_repo import AnomalyAlertRepository + + # Create an in-memory SQLite engine for testing + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + + # Create tables + from storage.models import Base + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + # Create session + async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + async with async_session() as session: + repo = AnomalyAlertRepository(session, tenant_id="test") + + # Clear any existing cache + repo._cache.clear() + + # Create test alerts + now = datetime.now(timezone.utc) + for i in range(5): + alert = AnomalyAlertModel( + id=str(uuid.uuid4()), + tenant_id="test", + session_id=str(uuid.uuid4()), + alert_type=f"test_type_{i % 2}", + severity=0.5 + (i * 0.1), + signal=f"Test alert {i}", + event_ids=[], + detection_source="test", + detection_config={}, + created_at=now - timedelta(hours=i), + ) + session.add(alert) + await session.commit() + + # First call should query database + summary1 = await repo.get_alert_summary(hours=24) + assert summary1["total_count"] == 5 + + # Second call should use cache + summary2 = await repo.get_alert_summary(hours=24) + assert summary2["total_count"] == 5 + + # Verify cache was used + assert repo._cache.size() > 0 + + 
@pytest.mark.asyncio + async def test_alert_list_limit_default(self): + """Test that list queries have reasonable limits.""" + from storage.repositories.alert_repo import AnomalyAlertRepository + + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + + from storage.models import Base + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + async with async_session() as session: + repo = AnomalyAlertRepository(session, tenant_id="test") + + # Create more alerts than the default limit + session_id = str(uuid.uuid4()) + now = datetime.now(timezone.utc) + for i in range(100): + alert = AnomalyAlertModel( + id=str(uuid.uuid4()), + tenant_id="test", + session_id=session_id, + alert_type="test_type", + severity=0.5, + signal=f"Test alert {i}", + event_ids=[], + detection_source="test", + detection_config={}, + created_at=now - timedelta(minutes=i), + ) + session.add(alert) + await session.commit() + + # List should return at most 50 (default limit) + alerts = await repo.list_anomaly_alerts(session_id) + assert len(alerts) <= 50 + + # Verify ordering by created_at desc (newest first) + if len(alerts) > 1: + assert alerts[0].created_at >= alerts[-1].created_at + + +@pytest.mark.integration +class TestMigrationIndexes: + """Integration tests for migration-created indexes.""" + + def test_migration_006_creates_indexes(self): + """Test that migration 006 creates the expected indexes.""" + # This test requires a real database connection + # Skip in unit test environments + pytest.skip("Requires database connection") + + # With a real connection, you would: + # 1. Run migration 006 + # 2. Query pg_indexes or sqlite_master to verify indexes exist + # 3. Verify index columns match expectations