diff --git a/CHANGELOG.md b/CHANGELOG.md index 8a00b5e..2a4b9e9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,21 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [0.2.1] - 2026-05-10 +## [0.2.4] - 2026-05-19 + +### Added +- **Health Check Hardening**: Provider health cross-referencing prevents mass-expiration of connections during transient upstream outages. +- **Bounded Concurrency**: Semaphore + WaitGroup pattern limits goroutine growth in both `HealthWorker` (max 10) and `ConnectionHealthWorker` (max 20). +- **Graceful Shutdown**: `--worker-only` mode now handles `SIGINT`/`SIGTERM` for clean process lifecycle management. +- **Frontend API**: New `GET /connections?workspace_id=` endpoint returns workspace-scoped connection summaries with health status. +- **Token Health Status**: `GET /connections/{id}/token` response now includes `health_status` field. +- **Database Index**: Partial index on `connections(status, last_health_check_at)` optimizes health check polling at scale. + +### Fixed +- `GET /providers/health` returns `[]` instead of `null` for empty provider lists. +- Standardized logging: replaced `fmt.Printf` with `log.Printf` in background workers. + +--- ### Changed - **Service Layer**: Refactored `connection_part2.go` into `credential.go`, separating credential capture, token refresh, and credential validation by responsibility. diff --git a/VERSION b/VERSION index 7179039..abd4105 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.2.3 +0.2.4 diff --git a/docs/healthchecks.md b/docs/healthchecks.md new file mode 100644 index 0000000..0f4511f --- /dev/null +++ b/docs/healthchecks.md @@ -0,0 +1,181 @@ +# Health Checks + +The Nexus Broker continuously monitors integration health across two dimensions: **provider-level** (is the upstream API alive?) 
and **connection-level** (is this user's credential still valid?). Both run as background workers inside the broker process. + +--- + +## Background Workers + +### HealthWorker — Provider-Level (5-minute interval) + +Probes every registered OAuth2 provider by sending a synthetic `invalid_grant` request to its `token_url`. This deliberately bad request tells us whether the provider's API is reachable and responding to OAuth traffic without requiring a real user credential. + +| Provider Response | Status Set | +|-------------------|------------| +| `400 Bad Request` or `401 Unauthorized` | `healthy` — API is alive and rejecting correctly | +| `5xx Server Error` | `unhealthy` — API is down | +| `200 OK` (unexpected for invalid grant) | `degraded` — API behaving abnormally | +| Network error / timeout | `unhealthy` | +| No `token_url` configured | `unknown` | + +For non-OAuth2 providers (API key, basic auth), the worker makes a `HEAD` request to `user_info_endpoint` or `api_base_url`. Any non-5xx response is treated as `healthy`. + +**Concurrency:** max 10 providers checked concurrently (semaphore + WaitGroup). + +--- + +### ConnectionHealthWorker — Connection-Level (1-minute interval) + +Validates individual user connections in batches of up to 100 on a fixed ticker, prioritising those never checked or longest overdue. Only `active` connections that have never been checked, or whose last check is more than one hour old, are selected, so an individual connection is re-validated at most about once per hour even though the worker ticks every minute. Each check has a 15-second timeout. A shared `http.Client` (10-second request timeout) is reused across API-key and basic-auth checks for connection pooling. 
+ +| Auth Type | Check Method | +|-----------|-------------| +| `oauth2` | Attempt a background token refresh via `ConnectionService.Refresh` | +| `api_key` | Decrypt credential, extract `api_key` field, `GET` to `user_info_endpoint` using provider's configured `AuthHeader` | +| `basic_auth` | Decrypt credential, extract `username`/`password`, `GET` to `user_info_endpoint` with `Authorization: Basic` | +| No endpoint configured | Mark `unknown` | + +**OAuth2 status code handling:** The worker inspects `RefreshResponse.StatusCode` to distinguish definitive credential errors from transient failures: + +| Upstream Status | `health_status` set | `connection.status` changed? | +|-----------------|--------------------|-----------------------------| +| Refresh succeeds | `healthy` | No | +| 400 / 401 (invalid_grant, revoked) | `expired` | Yes → `expired` (if provider healthy) | +| 403 (scope issue) | `degraded` | No | +| 5xx (upstream error) | `unhealthy` | No | +| Network error / nil response | `degraded` | No | + +**Provider shielding:** Before expiring a connection, the worker cross-references the upstream provider's `health_status`. If the provider is `unhealthy` or `degraded`, the connection is marked `unhealthy` (retriable) instead of `expired` (terminal). This prevents mass-expiration during transient upstream outages. + +**Error handling:** If `UpdateStatus` fails when expiring a connection, the worker logs the error and skips the `health_status` write to avoid leaving the connection in an inconsistent state. + +**Concurrency:** max 20 connections checked concurrently (semaphore + WaitGroup). 
+ +--- + +## `health_status` Values + +Both `provider_profiles` and `connections` share the same status vocabulary: + +| Value | Meaning | +|-------|---------| +| `healthy` | Last check passed | +| `unhealthy` | Last check failed — retriable (transient upstream or provider-shielded) | +| `degraded` | Partial failure — scope issues, network errors, or internal errors where credential validity is unknown | +| `expired` | Credential confirmed invalid (400/401 on an OAuth2 refresh, or 401/403 from an API-key / basic-auth probe) — user must re-authenticate | +| `unknown` | Not yet checked, or not enough information to check | + +--- + +## API Endpoints + +### `GET /providers/health` +Returns the health status of all registered providers. No credentials are included. + +```http +GET /providers/health +Authorization: X-API-Key +``` + +```json +[ + { + "id": "uuid", + "name": "google", + "health_status": "healthy", + "last_health_check_at": "2026-05-19T07:00:00Z", + "health_message": "" + }, + { + "id": "uuid", + "name": "stripe", + "health_status": "unhealthy", + "last_health_check_at": "2026-05-19T07:05:00Z", + "health_message": "upstream returned 503" + } +] +``` + +Returns `[]` (not `null`) when no providers exist. + +--- + +### `GET /connections?workspace_id={workspace_id}` +Returns all non-pending connections for a workspace with health status. No credentials or tokens are included. + +```http +GET /connections?workspace_id=ws-123 +Authorization: X-API-Key +``` + +```json +[ + { + "id": "uuid", + "provider_id": "uuid", + "provider_name": "google", + "auth_type": "oauth2", + "status": "active", + "scopes": ["email", "calendar.read"], + "health_status": "healthy", + "last_health_check_at": "2026-05-19T07:00:00Z", + "created_at": "2026-05-01T00:00:00Z", + "updated_at": "2026-05-19T07:00:00Z" + } +] +``` + +**Use case:** Rendering a connections dashboard with live health indicators. 
+ +--- + +### `GET /connections/{id}/token` (enhanced) +The existing token endpoint now includes `health_status` in its response alongside credentials and strategy. + +```json +{ + "strategy": { "type": "oauth2" }, + "credentials": { "access_token": "..." }, + "health_status": "healthy" +} +``` + +**Use case:** Showing an inline warning or re-auth prompt when consuming a credential. + +--- + +## Worker Mode + +Health workers run inside the standard broker process. For deployments that need to separate HTTP serving from background polling, pass `--worker-only` to the binary: + +```bash +nexus-broker --worker-only +``` + +In this mode, the HTTP server does not start. The process listens for `SIGINT`/`SIGTERM` and cancels the worker context, signalling in-flight checks to stop. Note: the current implementation does not explicitly wait for worker goroutines to complete before exiting. + +The same Docker image and environment variables are used — just override the container command. + +--- + +## Database Schema + +```sql +-- provider_profiles +ALTER TABLE provider_profiles + ADD COLUMN last_health_check_at TIMESTAMP WITH TIME ZONE, + ADD COLUMN health_status VARCHAR(50) DEFAULT 'unknown', + ADD COLUMN health_message TEXT; + +-- connections +ALTER TABLE connections + ADD COLUMN last_health_check_at TIMESTAMP WITH TIME ZONE, + ADD COLUMN health_status VARCHAR(50) DEFAULT 'unknown'; + +-- Performance index for GetForHealthCheck query +CREATE INDEX IF NOT EXISTS idx_connections_health_check + ON connections (status, last_health_check_at ASC NULLS FIRST) + WHERE status = 'active'; +``` + +Migrations: `13_add_provider_health.sql`, `14_add_connection_health.sql`, `15_add_connection_health_index.sql`. 
\ No newline at end of file diff --git a/docs/services/broker.md b/docs/services/broker.md index 73f1272..d71c714 100644 --- a/docs/services/broker.md +++ b/docs/services/broker.md @@ -28,7 +28,14 @@ To ensure agents never face a "cold start" due to expired tokens: - It performs background refreshes using stored Refresh Tokens. - If a refresh fails permanently (e.g., user revoked access), it transitions the connection to `attention_required`. -### 5. Audit Subsystem +### 5. Health Monitoring +The Broker runs two background workers to continuously monitor integration health: +- **`HealthWorker`** (5-min interval): Probes all registered OAuth2 providers using a synthetic `invalid_grant` request to their `token_url`. A `400`/`401` response confirms the provider is alive; a `5xx` marks it `unhealthy`. +- **`ConnectionHealthWorker`** (1-min interval): Validates each active user connection by attempting a token refresh (OAuth2) or a lightweight API call (API key/basic auth). Uses **provider-shielding** to avoid falsely expiring connections during upstream outages. +- Both workers use bounded concurrency (semaphore + WaitGroup) to prevent goroutine exhaustion. +- In `--worker-only` mode, the binary listens for `SIGINT`/`SIGTERM` for graceful shutdown. + +### 6. Audit Subsystem Every control-plane mutation is recorded in the `audit_events` table via the `audit.Service`: - **`provider.created`** — logged on every successful `POST /providers` call. - **`provider.updated`** — logged on `PUT` and `PATCH` mutations. @@ -42,6 +49,18 @@ Audit events capture the **caller IP** (respecting `X-Forwarded-For`), **User-Ag See the [Audit Log Reference](../reference/audit-log.md) for how to query events. 
+## Key API Endpoints + +| Method | Path | Description | +|--------|------|-------------| +| `GET` | `/providers/health` | Provider health dashboard (all providers, no credentials) | +| `GET` | `/connections?workspace_id=` | All connections for a workspace with health status | +| `GET` | `/connections/{id}/token` | Resolve credentials + `health_status` for a specific connection | +| `POST` | `/connections/{id}/refresh` | Force a token refresh | +| `GET` | `/connections/resolve` | Resolve by `workspace_id` + `provider_name` | + +See [Health Checks Architecture](../healthchecks.md) for details on the monitoring system. + ## Environment Variables | Variable | Description | Default | diff --git a/mkdocs.yml b/mkdocs.yml index 0ee54b4..4a74597 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -95,6 +95,7 @@ nav: - Agent Identity: concepts/agent-identity.md - Security Model: concepts/security-model.md - Client Libraries: concepts/client-libraries.md + - Health Checks: healthchecks.md - Getting Started: - Deploy in Five Minutes: getting-started/quickstart.md - Configuration: getting-started/configuration.md @@ -122,7 +123,7 @@ nav: extra: - version: "0.2.3" + version: "0.2.4" social: - icon: material/web link: https://developers.prescottdata.io diff --git a/nexus-broker/cmd/nexus-broker/main.go b/nexus-broker/cmd/nexus-broker/main.go index 93cf761..50abf2f 100644 --- a/nexus-broker/cmd/nexus-broker/main.go +++ b/nexus-broker/cmd/nexus-broker/main.go @@ -4,6 +4,8 @@ import ( "context" "log" "os" + "os/signal" + "syscall" "time" "github.com/Prescott-Data/nexus-framework/nexus-broker/internal/audit" @@ -25,9 +27,17 @@ import ( var Version = "dev" func main() { - if len(os.Args) > 1 && (os.Args[1] == "-v" || os.Args[1] == "--version") { - log.Printf("Nexus Broker version: %s", Version) - os.Exit(0) + isWorkerOnly := false + if len(os.Args) > 1 { + for _, arg := range os.Args[1:] { + if arg == "-v" || arg == "--version" { + log.Printf("Nexus Broker version: %s", Version) + 
os.Exit(0) + } + if arg == "--worker-only" { + isWorkerOnly = true + } + } } cfg, err := config.Load() @@ -92,6 +102,7 @@ func main() { Audit: auditSvc, }) auditHandler := handlers.NewAuditHandler(db) + connectionsHandler := handlers.NewConnectionsHandler(connSvc) router := srv.Router() router.Get("/auth/callback", callbackHandler.Handle) @@ -107,6 +118,7 @@ func main() { protected.Route("/providers", func(r chi.Router) { r.Post("/", providersHandler.Register) r.Get("/", providersHandler.List) + r.Get("/health", providersHandler.Health) r.Get("/metadata", providersHandler.Metadata) r.Get("/by-name/{name}", providersHandler.GetByName) r.Delete("/by-name/{name}", providersHandler.DeleteByName) @@ -116,6 +128,7 @@ func main() { r.Delete("/{id}", providersHandler.Delete) }) protected.Post("/auth/consent-spec", consentHandler.GetSpec) + protected.Get("/connections", connectionsHandler.List) protected.Get("/connections/resolve", callbackHandler.ResolveToken) protected.Get("/connections/{connectionID}/token", callbackHandler.GetToken) protected.Post("/connections/{connectionID}/refresh", callbackHandler.Refresh) @@ -126,14 +139,35 @@ func main() { defer cleanupCancel() go handlers.StartOrphanTokenCleanup(cleanupCtx, db, 1*time.Hour) + // Start provider health worker (polls every 5m) + healthWorker := provider.NewHealthWorker(store, 5*time.Minute) + go healthWorker.Start(cleanupCtx) + + // Start connection health worker (polls every 1m) + // The store implements ProviderHealthLookup via GetHealthStatus(uuid.UUID) + connHealthWorker := service.NewConnectionHealthWorker(connRepo, connSvc, store, 1*time.Minute) + go connHealthWorker.Start(cleanupCtx) + // Start connection health gauge (polls every 30s) telemetry.NewConnectionGaugeCollector(connRepo, 30*time.Second) - log.Printf("Starting OAuth Broker server on port %s", cfg.Port) - log.Printf("Version: %s", Version) - log.Printf("Base URL: %s", cfg.BaseURL) - - if err := srv.Start(); err != nil { - log.Fatal("Server failed to 
start:", err) + if isWorkerOnly { + log.Printf("Starting Nexus Broker in WORKER-ONLY mode") + log.Printf("Version: %s", Version) + + // Wait for OS signal for graceful shutdown + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + sig := <-sigCh + log.Printf("Received signal %v, shutting down workers...", sig) + cleanupCancel() + } else { + log.Printf("Starting OAuth Broker server on port %s", cfg.Port) + log.Printf("Version: %s", Version) + log.Printf("Base URL: %s", cfg.BaseURL) + + if err := srv.Start(); err != nil { + log.Fatal("Server failed to start:", err) + } } } diff --git a/nexus-broker/internal/domain/models.go b/nexus-broker/internal/domain/models.go index d220cfa..78db167 100644 --- a/nexus-broker/internal/domain/models.go +++ b/nexus-broker/internal/domain/models.go @@ -18,6 +18,8 @@ type Connection struct { ReturnURL string Status string ExpiresAt time.Time + LastHealthCheckAt sql.NullTime + HealthStatus string } // ConnectionWithProvider joins connection and basic provider info @@ -31,6 +33,21 @@ type ConnectionWithProvider struct { ProviderParams *json.RawMessage } +// ConnectionSummary is a lightweight view of a connection for frontend listing. +// It deliberately omits credentials and internal fields. 
+type ConnectionSummary struct { + ID uuid.UUID `json:"id"` + ProviderID uuid.UUID `json:"provider_id"` + ProviderName string `json:"provider_name"` + AuthType string `json:"auth_type"` + Status string `json:"status"` + Scopes []string `json:"scopes"` + HealthStatus string `json:"health_status"` + LastHealthCheckAt *time.Time `json:"last_health_check_at,omitempty"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + // Token represents an encrypted token at rest type Token struct { ConnectionID uuid.UUID diff --git a/nexus-broker/internal/repository/instrumented/instrumented.go b/nexus-broker/internal/repository/instrumented/instrumented.go index a5f36b8..b539e8e 100644 --- a/nexus-broker/internal/repository/instrumented/instrumented.go +++ b/nexus-broker/internal/repository/instrumented/instrumented.go @@ -73,6 +73,22 @@ func (r *ConnectionRepository) GetActiveByWorkspaceAndProvider(ctx context.Conte defer observe("connection", "GetActiveByWorkspaceAndProvider", time.Now()) return r.inner.GetActiveByWorkspaceAndProvider(ctx, workspaceID, providerName) } + +func (r *ConnectionRepository) GetForHealthCheck(ctx context.Context, limit int) ([]*domain.ConnectionWithProvider, error) { + defer observe("connection", "GetForHealthCheck", time.Now()) + return r.inner.GetForHealthCheck(ctx, limit) +} + +func (r *ConnectionRepository) UpdateHealthStatus(ctx context.Context, id uuid.UUID, status string) error { + defer observe("connection", "UpdateHealthStatus", time.Now()) + return r.inner.UpdateHealthStatus(ctx, id, status) +} + +func (r *ConnectionRepository) ListByWorkspace(ctx context.Context, workspaceID string) ([]domain.ConnectionSummary, error) { + defer observe("connection", "ListByWorkspace", time.Now()) + return r.inner.ListByWorkspace(ctx, workspaceID) +} + // --- TokenRepository decorator --- // TokenRepository wraps repository.TokenRepository with latency instrumentation. 
diff --git a/nexus-broker/internal/repository/interfaces.go b/nexus-broker/internal/repository/interfaces.go index caf6923..275a2f4 100644 --- a/nexus-broker/internal/repository/interfaces.go +++ b/nexus-broker/internal/repository/interfaces.go @@ -16,6 +16,9 @@ type ConnectionRepository interface { UpdateStatus(ctx context.Context, id uuid.UUID, status string) error CountByStatus(ctx context.Context) (map[string]int64, error) GetActiveByWorkspaceAndProvider(ctx context.Context, workspaceID, providerName string) (*domain.ConnectionWithProvider, error) + GetForHealthCheck(ctx context.Context, limit int) ([]*domain.ConnectionWithProvider, error) + UpdateHealthStatus(ctx context.Context, id uuid.UUID, status string) error + ListByWorkspace(ctx context.Context, workspaceID string) ([]domain.ConnectionSummary, error) } // TokenRepository handles database operations for tokens diff --git a/nexus-broker/internal/repository/postgres/connection.go b/nexus-broker/internal/repository/postgres/connection.go index 951a423..a7326ee 100644 --- a/nexus-broker/internal/repository/postgres/connection.go +++ b/nexus-broker/internal/repository/postgres/connection.go @@ -44,12 +44,14 @@ func (r *connectionRepository) GetWithProvider(ctx context.Context, id uuid.UUID var conn domain.ConnectionWithProvider err := r.db.QueryRowContext(ctx, ` SELECT c.id, c.provider_id, c.status, c.scopes, c.return_url, - p.name, p.auth_type, COALESCE(p.auth_header, ''), COALESCE(p.api_base_url, ''), COALESCE(p.user_info_endpoint, ''), p.params + p.name, p.auth_type, COALESCE(p.auth_header, ''), COALESCE(p.api_base_url, ''), COALESCE(p.user_info_endpoint, ''), p.params, + COALESCE(c.health_status, 'unknown') FROM connections c JOIN provider_profiles p ON p.id = c.provider_id WHERE c.id = $1`, id). 
Scan(&conn.ID, &conn.ProviderID, &conn.Status, pq.Array(&conn.Scopes), &conn.ReturnURL, - &conn.ProviderName, &conn.AuthType, &conn.AuthHeader, &conn.APIBaseURL, &conn.UserInfoEndpoint, &conn.ProviderParams) + &conn.ProviderName, &conn.AuthType, &conn.AuthHeader, &conn.APIBaseURL, &conn.UserInfoEndpoint, &conn.ProviderParams, + &conn.HealthStatus) if err != nil { return nil, err } @@ -60,14 +62,16 @@ func (r *connectionRepository) GetActiveByWorkspaceAndProvider(ctx context.Conte var conn domain.ConnectionWithProvider err := r.db.QueryRowContext(ctx, ` SELECT c.id, c.provider_id, c.status, c.scopes, c.return_url, - p.name, p.auth_type, COALESCE(p.auth_header, ''), COALESCE(p.api_base_url, ''), COALESCE(p.user_info_endpoint, ''), p.params + p.name, p.auth_type, COALESCE(p.auth_header, ''), COALESCE(p.api_base_url, ''), COALESCE(p.user_info_endpoint, ''), p.params, + COALESCE(c.health_status, 'unknown') FROM connections c JOIN provider_profiles p ON p.id = c.provider_id WHERE c.workspace_id = $1 AND p.name = $2 AND c.status = 'active' ORDER BY c.updated_at DESC LIMIT 1`, workspaceID, providerName). Scan(&conn.ID, &conn.ProviderID, &conn.Status, pq.Array(&conn.Scopes), &conn.ReturnURL, - &conn.ProviderName, &conn.AuthType, &conn.AuthHeader, &conn.APIBaseURL, &conn.UserInfoEndpoint, &conn.ProviderParams) + &conn.ProviderName, &conn.AuthType, &conn.AuthHeader, &conn.APIBaseURL, &conn.UserInfoEndpoint, &conn.ProviderParams, + &conn.HealthStatus) if err != nil { return nil, err } @@ -103,3 +107,101 @@ func (r *connectionRepository) CountByStatus(ctx context.Context) (map[string]in } return counts, rows.Err() } + +func (r *connectionRepository) GetForHealthCheck(ctx context.Context, limit int) ([]*domain.ConnectionWithProvider, error) { + var rows []domain.ConnectionWithProvider + // Fetch active connections that haven't been checked in the last hour, + // or have never been checked, prioritizing the oldest checks first. 
+ query := ` + SELECT c.id, c.workspace_id, c.provider_id, c.scopes, c.return_url, c.status, c.expires_at, + c.last_health_check_at, COALESCE(c.health_status, 'unknown'), + p.name, p.auth_type, COALESCE(p.auth_header, ''), COALESCE(p.api_base_url, ''), COALESCE(p.user_info_endpoint, ''), p.params + FROM connections c + JOIN provider_profiles p ON c.provider_id = p.id + WHERE c.status = 'active' + AND (c.last_health_check_at IS NULL OR c.last_health_check_at < NOW() - INTERVAL '1 hour') + ORDER BY c.last_health_check_at ASC NULLS FIRST + LIMIT $1 + ` + dbRows, err := r.db.QueryContext(ctx, query, limit) + if err != nil { + return nil, err + } + defer dbRows.Close() + + for dbRows.Next() { + var conn domain.ConnectionWithProvider + err := dbRows.Scan( + &conn.ID, &conn.WorkspaceID, &conn.ProviderID, pq.Array(&conn.Scopes), &conn.ReturnURL, &conn.Status, &conn.ExpiresAt, + &conn.LastHealthCheckAt, &conn.HealthStatus, + &conn.ProviderName, &conn.AuthType, &conn.AuthHeader, &conn.APIBaseURL, &conn.UserInfoEndpoint, &conn.ProviderParams, + ) + if err != nil { + return nil, err + } + rows = append(rows, conn) + } + + if err = dbRows.Err(); err != nil { + return nil, err + } + + // Returning pointers as per interface + var ptrRows []*domain.ConnectionWithProvider + for i := range rows { + ptrRows = append(ptrRows, &rows[i]) + } + + return ptrRows, nil +} + +func (r *connectionRepository) UpdateHealthStatus(ctx context.Context, id uuid.UUID, status string) error { + _, err := r.db.ExecContext(ctx, ` + UPDATE connections + SET health_status = $1, last_health_check_at = NOW(), updated_at = NOW() + WHERE id = $2`, status, id) + return err +} + +func (r *connectionRepository) ListByWorkspace(ctx context.Context, workspaceID string) ([]domain.ConnectionSummary, error) { + query := ` + SELECT c.id, c.provider_id, p.name, p.auth_type, c.status, c.scopes, + COALESCE(c.health_status, 'unknown'), c.last_health_check_at, + c.created_at, c.updated_at + FROM connections c + JOIN 
provider_profiles p ON c.provider_id = p.id AND p.deleted_at IS NULL + WHERE c.workspace_id = $1 AND c.status != 'pending' + ORDER BY c.updated_at DESC + ` + + rows, err := r.db.QueryContext(ctx, query, workspaceID) + if err != nil { + return nil, err + } + defer rows.Close() + + var summaries []domain.ConnectionSummary + for rows.Next() { + var s domain.ConnectionSummary + err := rows.Scan( + &s.ID, &s.ProviderID, &s.ProviderName, &s.AuthType, &s.Status, pq.Array(&s.Scopes), + &s.HealthStatus, &s.LastHealthCheckAt, + &s.CreatedAt, &s.UpdatedAt, + ) + if err != nil { + return nil, err + } + summaries = append(summaries, s) + } + + if err = rows.Err(); err != nil { + return nil, err + } + + // Return empty slice instead of nil for clean JSON + if summaries == nil { + summaries = []domain.ConnectionSummary{} + } + + return summaries, nil +} diff --git a/nexus-broker/internal/service/connection.go b/nexus-broker/internal/service/connection.go index 22d9065..7c7fc34 100644 --- a/nexus-broker/internal/service/connection.go +++ b/nexus-broker/internal/service/connection.go @@ -31,6 +31,7 @@ type ConnectionService interface { GetCaptureSchema(ctx context.Context, state string) (string, json.RawMessage, error) SaveCredential(ctx context.Context, state string, credentials map[string]interface{}) (string, error) Refresh(ctx context.Context, connectionID uuid.UUID) (*RefreshResponse, error) + ListConnections(ctx context.Context, workspaceID string) ([]domain.ConnectionSummary, error) } type connectionService struct { @@ -410,10 +411,18 @@ func (s *connectionService) GetToken(ctx context.Context, connectionID uuid.UUID response["strategy"] = strategy response["credentials"] = credentials + response["health_status"] = conn.HealthStatus return response, conn.ProviderName, nil } +func (s *connectionService) ListConnections(ctx context.Context, workspaceID string) ([]domain.ConnectionSummary, error) { + if workspaceID == "" { + return nil, ErrBadRequest("missing_workspace_id", 
"workspace_id is required") + } + return s.connRepo.ListByWorkspace(ctx, workspaceID) +} + // Helpers func (s *connectionService) buildAuthURL(providerAuthURL, clientID, state, codeChallenge string, scopes []string, providerParams *json.RawMessage) (string, error) { diff --git a/nexus-broker/internal/service/connection_health.go b/nexus-broker/internal/service/connection_health.go new file mode 100644 index 0000000..f9ace6f --- /dev/null +++ b/nexus-broker/internal/service/connection_health.go @@ -0,0 +1,257 @@ +package service + +import ( + "context" + "encoding/base64" + "log" + "net/http" + "strings" + "sync" + "time" + + "github.com/google/uuid" + "github.com/Prescott-Data/nexus-framework/nexus-broker/internal/domain" + "github.com/Prescott-Data/nexus-framework/nexus-broker/internal/repository" +) + +// ProviderHealthLookup provides read-only access to provider health status. +// Uses a narrow query that only fetches health_status, avoiding loading +// sensitive fields (client_secret, params, etc.) into worker memory. 
+type ProviderHealthLookup interface {
+	// GetHealthStatus returns the current health_status value of the provider
+	// identified by id.
+	GetHealthStatus(id uuid.UUID) (string, error)
+}
+
+// ConnectionHealthWorker polls for active connections and verifies their health.
+type ConnectionHealthWorker struct {
+	connRepo       repository.ConnectionRepository
+	connSvc        ConnectionService
+	providerHealth ProviderHealthLookup // may be nil; isProviderDown then reports false
+	httpClient     *http.Client         // shared across probes so connections are pooled
+	interval       time.Duration        // delay between polling cycles
+	batchSize      int                  // max connections fetched per cycle
+	maxConcurrency int                  // max probes in flight at once
+}
+
+// NewConnectionHealthWorker builds a worker that polls every interval, fetching
+// up to 100 connections per cycle and probing at most 20 of them concurrently.
+// providerHealth may be nil, in which case provider shielding is disabled.
+func NewConnectionHealthWorker(
+	connRepo repository.ConnectionRepository,
+	connSvc ConnectionService,
+	providerHealth ProviderHealthLookup,
+	interval time.Duration,
+) *ConnectionHealthWorker {
+	return &ConnectionHealthWorker{
+		connRepo:       connRepo,
+		connSvc:        connSvc,
+		providerHealth: providerHealth,
+		httpClient:     &http.Client{Timeout: 10 * time.Second},
+		interval:       interval,
+		batchSize:      100, // Process 100 connections per interval
+		maxConcurrency: 20,  // Limit to 20 concurrent health checks
+	}
+}
+
+// Start runs one polling cycle immediately and then one per interval until ctx
+// is cancelled. It blocks, so callers normally invoke it in its own goroutine.
+func (w *ConnectionHealthWorker) Start(ctx context.Context) {
+	ticker := time.NewTicker(w.interval)
+	defer ticker.Stop()
+
+	// Run once immediately
+	w.runChecks(ctx)
+
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-ticker.C:
+			w.runChecks(ctx)
+		}
+	}
+}
+
+// runChecks executes a single polling cycle: it fetches one batch of connections
+// due for a check, probes them with bounded concurrency, persists each result,
+// and returns once every started probe has finished.
+func (w *ConnectionHealthWorker) runChecks(ctx context.Context) {
+	const (
+		// checkTimeout bounds a single probe (token refresh or test request).
+		checkTimeout = 15 * time.Second
+		// writeTimeout bounds persisting the result. It is deliberately separate
+		// from checkTimeout: if the probe consumed its whole deadline, writing
+		// with the same context would fail, last_health_check_at would never
+		// advance, and the same slow connection would be re-selected at the head
+		// of every subsequent batch.
+		writeTimeout = 5 * time.Second
+	)
+
+	conns, err := w.connRepo.GetForHealthCheck(ctx, w.batchSize)
+	if err != nil {
+		log.Printf("ConnectionHealthWorker: failed to fetch connections: %v", err)
+		return
+	}
+
+	if len(conns) == 0 {
+		return
+	}
+
+	// Use a semaphore to bound concurrency
+	sem := make(chan struct{}, w.maxConcurrency)
+	var wg sync.WaitGroup
+
+dispatch:
+	for _, conn := range conns {
+		// Stop handing out work once shutdown has been requested.
+		if ctx.Err() != nil {
+			break
+		}
+
+		// Acquire a semaphore slot, but don't block past cancellation while
+		// all slots are busy.
+		select {
+		case sem <- struct{}{}:
+		case <-ctx.Done():
+			break dispatch
+		}
+		wg.Add(1)
+
+		go func(c *domain.ConnectionWithProvider) {
+			defer wg.Done()
+			defer func() { <-sem }() // Release semaphore slot
+
+			// A simple timeout context per check
+			checkCtx, cancelCheck := context.WithTimeout(ctx, checkTimeout)
+			status := w.checkConnection(checkCtx, c)
+			cancelCheck()
+
+			// Fresh deadline for the writes, still tied to the worker's lifetime.
+			writeCtx, cancelWrite := context.WithTimeout(ctx, writeTimeout)
+			defer cancelWrite()
+
+			// Only flip the connection's primary status to "expired" when we have
+			// a definitive credential error AND the upstream provider is healthy.
+			// For all other negative outcomes (unhealthy, degraded), we update
+			// health_status but leave the connection's primary status untouched
+			// to avoid overwriting states like "attention" set by the service layer.
+			if status == "expired" {
+				if w.isProviderDown(c.ProviderID) {
+					log.Printf("ConnectionHealthWorker: Connection %s refresh failed but provider %s is unhealthy — marking as unhealthy instead of expired", c.ID, c.ProviderName)
+					status = "unhealthy"
+				} else {
+					log.Printf("ConnectionHealthWorker: Connection %s for provider %s — credential definitively invalid, expiring", c.ID, c.ProviderName)
+					if err := w.connRepo.UpdateStatus(writeCtx, c.ID, "expired"); err != nil {
+						log.Printf("ConnectionHealthWorker: failed to expire connection %s — skipping health update to avoid inconsistent state: %v", c.ID, err)
+						return
+					}
+				}
+			}
+
+			if err := w.connRepo.UpdateHealthStatus(writeCtx, c.ID, status); err != nil {
+				log.Printf("ConnectionHealthWorker: failed to update health status for conn %s: %v", c.ID, err)
+			}
+		}(conn)
+	}
+
+	wg.Wait()
+}
+
+// isProviderDown checks whether the upstream provider is currently experiencing issues.
+// Returns true if the provider's health status is "unhealthy" or "degraded".
+func (w *ConnectionHealthWorker) isProviderDown(providerID uuid.UUID) bool {
+	if w.providerHealth == nil {
+		return false // No lookup available, assume provider is fine
+	}
+
+	status, err := w.providerHealth.GetHealthStatus(providerID)
+	if err != nil {
+		// Fails open: a lookup error is treated as "provider fine", so the
+		// caller proceeds with expiring the connection.
+		return false // Can't look up, assume provider is fine
+	}
+
+	return status == "unhealthy" || status == "degraded"
+}
+
+// checkConnection probes a single connection and returns the health_status to
+// record for it: "healthy", "unhealthy", "degraded", "expired" or "unknown".
+//
+// OAuth2 connections are delegated to checkOAuth2Connection. For every other
+// auth type the stored credential is sent in a GET to the provider's
+// UserInfoEndpoint:
+//
+//	no endpoint configured          → "unknown"
+//	credential unavailable/missing  → "degraded"
+//	request build or network error  → "unhealthy"
+//	401 / 403                       → "expired"
+//	5xx                             → "unhealthy"
+//	anything else                   → "healthy"
+func (w *ConnectionHealthWorker) checkConnection(ctx context.Context, c *domain.ConnectionWithProvider) string {
+	if c.AuthType == "oauth2" {
+		return w.checkOAuth2Connection(ctx, c)
+	}
+
+	// For non-OAuth2 (API keys), we need a UserInfoEndpoint to test against
+	if c.UserInfoEndpoint == "" {
+		return "unknown"
+	}
+
+	// Fetch and decrypt the credentials
+	tokenResp, _, err := w.connSvc.GetToken(ctx, c.ID)
+	if err != nil {
+		// GetToken can fail for internal reasons (decryption error, DB error).
+		// Don't mark as expired — the credential might still be valid.
+		log.Printf("ConnectionHealthWorker: Connection %s — failed to fetch token: %v", c.ID, err)
+		return "degraded"
+	}
+
+	// GetToken returns an envelope whose secret fields live under the
+	// "credentials" key (next to "strategy" and "health_status"). Reading
+	// "api_key"/"username" from the envelope itself never finds them, which
+	// would report every API-key and basic-auth connection as degraded.
+	// Fall back to the top-level map in case a flat payload is ever returned.
+	credentials := tokenResp
+	if nested, ok := tokenResp["credentials"].(map[string]interface{}); ok {
+		credentials = nested
+	}
+
+	// Make a test request to the user_info_endpoint
+	req, err := http.NewRequestWithContext(ctx, "GET", c.UserInfoEndpoint, nil)
+	if err != nil {
+		return "unhealthy"
+	}
+
+	// Apply authentication using the same key extraction as validateCredentials
+	// in credential.go. We use explicit credential keys to avoid accidentally
+	// injecting unrelated map values (e.g., expires_at) as auth headers.
+	switch c.AuthType {
+	case "api_key":
+		apiKey, _ := credentials["api_key"].(string)
+		if apiKey == "" {
+			log.Printf("ConnectionHealthWorker: Connection %s — api_key field missing from credentials", c.ID)
+			return "degraded"
+		}
+		headerName := c.AuthHeader
+		if headerName == "" {
+			headerName = "Authorization"
+		}
+		// Keys sent in the Authorization header use the Bearer scheme; any
+		// custom header carries the raw key.
+		if strings.EqualFold(headerName, "authorization") {
+			req.Header.Set("Authorization", "Bearer "+apiKey)
+		} else {
+			req.Header.Set(headerName, apiKey)
+		}
+
+	case "basic_auth":
+		username, _ := credentials["username"].(string)
+		password, _ := credentials["password"].(string) // an empty password is allowed
+		if username == "" {
+			log.Printf("ConnectionHealthWorker: Connection %s — username field missing from credentials", c.ID)
+			return "degraded"
+		}
+		encoded := base64.StdEncoding.EncodeToString([]byte(username + ":" + password))
+		req.Header.Set("Authorization", "Basic "+encoded)
+	}
+
+	resp, err := w.httpClient.Do(req)
+	if err != nil {
+		return "unhealthy" // Network failure
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
+		return "expired" // The key is dead
+	}
+
+	if resp.StatusCode >= 500 {
+		return "unhealthy" // Provider is having issues, don't mark as expired yet
+	}
+
+	return "healthy"
+}
+
+// checkOAuth2Connection inspects the RefreshResponse from the service layer to
+// distinguish definitive credential errors from transient/internal failures.
+// +// Status code mapping: +// +// Success → "healthy" +// 400/401 → "expired" (invalid_grant, token revoked — definitive) +// 403 → "degraded" (scope issues — credential exists but limited) +// 5xx → "unhealthy" (upstream issue — don't touch connection status) +// Network/internal → "degraded" (can't determine — don't touch connection status) +func (w *ConnectionHealthWorker) checkOAuth2Connection(ctx context.Context, c *domain.ConnectionWithProvider) string { + resp, err := w.connSvc.Refresh(ctx, c.ID) + if err == nil { + return "healthy" + } + + // Refresh returns a *RefreshResponse even on error, containing the upstream + // status code. Use it to make a precise determination. + if resp != nil && resp.StatusCode > 0 { + switch { + case resp.StatusCode == 400 || resp.StatusCode == 401: + // Definitive: invalid_grant, token revoked, client deauthorized. + // The service layer already set connection.status = "attention" for 4xx. + // We return "expired" so runChecks can flip to "expired" if provider is healthy. + return "expired" + case resp.StatusCode == 403: + // Partial revocation or scope downgrade. The refresh token may still be + // valid but scopes are reduced. Don't expire the connection. + return "degraded" + case resp.StatusCode >= 500: + // Upstream server error — transient. Don't touch the connection. + return "unhealthy" + default: + // Unexpected status (e.g., 429 rate limit). Treat as transient. + log.Printf("ConnectionHealthWorker: Connection %s — unexpected refresh status %d", c.ID, resp.StatusCode) + return "degraded" + } + } + + // No response at all — network error, DNS failure, timeout, or internal service + // error (decryption failure, missing provider, etc.). We can't determine whether + // the credential is valid, so mark degraded and leave connection.status untouched. 
+ log.Printf("ConnectionHealthWorker: Connection %s — refresh error with no status code: %v", c.ID, err) + return "degraded" +} diff --git a/nexus-broker/internal/service/connection_health_test.go b/nexus-broker/internal/service/connection_health_test.go new file mode 100644 index 0000000..f39e599 --- /dev/null +++ b/nexus-broker/internal/service/connection_health_test.go @@ -0,0 +1,392 @@ +package service_test + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/Prescott-Data/nexus-framework/nexus-broker/internal/domain" + "github.com/Prescott-Data/nexus-framework/nexus-broker/internal/service" +) + +// Add missing mock methods to MockConnectionRepository +func (m *MockConnectionRepository) GetForHealthCheck(ctx context.Context, limit int) ([]*domain.ConnectionWithProvider, error) { + args := m.Called(ctx, limit) + if args.Get(0) != nil { + return args.Get(0).([]*domain.ConnectionWithProvider), args.Error(1) + } + return nil, args.Error(1) +} + +func (m *MockConnectionRepository) UpdateHealthStatus(ctx context.Context, id uuid.UUID, status string) error { + args := m.Called(ctx, id, status) + return args.Error(0) +} + +// MockConnectionService mocks the ConnectionService +type MockConnectionService struct { + mock.Mock +} + +func (m *MockConnectionService) CreateConsentSpec(ctx context.Context, req service.CreateConsentRequest) (*service.ConsentSpecResponse, error) { + args := m.Called(ctx, req) + if args.Get(0) != nil { + return args.Get(0).(*service.ConsentSpecResponse), args.Error(1) + } + return nil, args.Error(1) +} + +func (m *MockConnectionService) ExchangeCodeForTokens(ctx context.Context, state, code, errorParam, errorDesc string) (string, bool, error) { + args := m.Called(ctx, state, code, errorParam, errorDesc) + return args.String(0), args.Bool(1), args.Error(2) +} + +func (m 
*MockConnectionService) GetToken(ctx context.Context, connectionID uuid.UUID) (map[string]interface{}, string, error) { + args := m.Called(ctx, connectionID) + if args.Get(0) != nil { + return args.Get(0).(map[string]interface{}), args.String(1), args.Error(2) + } + return nil, args.String(1), args.Error(2) +} + +func (m *MockConnectionService) GetTokenByWorkspaceAndProvider(ctx context.Context, workspaceID, providerName string) (map[string]interface{}, string, error) { + args := m.Called(ctx, workspaceID, providerName) + if args.Get(0) != nil { + return args.Get(0).(map[string]interface{}), args.String(1), args.Error(2) + } + return nil, args.String(1), args.Error(2) +} + +func (m *MockConnectionService) GetCaptureSchema(ctx context.Context, state string) (string, json.RawMessage, error) { + args := m.Called(ctx, state) + if args.Get(1) != nil { + return args.String(0), args.Get(1).(json.RawMessage), args.Error(2) + } + return args.String(0), nil, args.Error(2) +} + +func (m *MockConnectionService) SaveCredential(ctx context.Context, state string, credentials map[string]interface{}) (string, error) { + args := m.Called(ctx, state, credentials) + return args.String(0), args.Error(1) +} + +func (m *MockConnectionService) Refresh(ctx context.Context, connectionID uuid.UUID) (*service.RefreshResponse, error) { + args := m.Called(ctx, connectionID) + if args.Get(0) != nil { + return args.Get(0).(*service.RefreshResponse), args.Error(1) + } + return nil, args.Error(1) +} + +func (m *MockConnectionService) ListConnections(ctx context.Context, workspaceID string) ([]domain.ConnectionSummary, error) { + args := m.Called(ctx, workspaceID) + if args.Get(0) != nil { + return args.Get(0).([]domain.ConnectionSummary), args.Error(1) + } + return nil, args.Error(1) +} + +// MockProviderHealthLookup mocks the ProviderHealthLookup interface +type MockProviderHealthLookup struct { + mock.Mock +} + +func (m *MockProviderHealthLookup) GetHealthStatus(id uuid.UUID) (string, error) { + 
args := m.Called(id) + return args.String(0), args.Error(1) +} + +func TestConnectionHealthWorker_OAuth2_Healthy(t *testing.T) { + mockRepo := new(MockConnectionRepository) + mockSvc := new(MockConnectionService) + mockHealth := new(MockProviderHealthLookup) + + connID := uuid.New() + conn := &domain.ConnectionWithProvider{ + Connection: domain.Connection{ + ID: connID, + Status: "active", + }, + AuthType: "oauth2", + } + + mockRepo.On("GetForHealthCheck", mock.Anything, 100).Return([]*domain.ConnectionWithProvider{conn}, nil).Once() + // Should do nothing after the first call since we'll cancel the context + mockRepo.On("GetForHealthCheck", mock.Anything, 100).Return([]*domain.ConnectionWithProvider{}, nil) + + // Mock successful refresh + mockSvc.On("Refresh", mock.Anything, connID).Return(&service.RefreshResponse{}, nil).Once() + + // Should update health to healthy + mockRepo.On("UpdateHealthStatus", mock.Anything, connID, "healthy").Return(nil).Once() + + worker := service.NewConnectionHealthWorker(mockRepo, mockSvc, mockHealth, 10*time.Millisecond) + + ctx, cancel := context.WithCancel(context.Background()) + go worker.Start(ctx) + + time.Sleep(50 * time.Millisecond) // Give it time to run at least once + cancel() + + mockRepo.AssertExpectations(t) + mockSvc.AssertExpectations(t) +} + +func TestConnectionHealthWorker_OAuth2_Expired(t *testing.T) { + mockRepo := new(MockConnectionRepository) + mockSvc := new(MockConnectionService) + mockHealth := new(MockProviderHealthLookup) + + connID := uuid.New() + providerID := uuid.New() + conn := &domain.ConnectionWithProvider{ + Connection: domain.Connection{ + ID: connID, + ProviderID: providerID, + Status: "active", + }, + AuthType: "oauth2", + } + + mockRepo.On("GetForHealthCheck", mock.Anything, 100).Return([]*domain.ConnectionWithProvider{conn}, nil).Once() + mockRepo.On("GetForHealthCheck", mock.Anything, 100).Return([]*domain.ConnectionWithProvider{}, nil) + + // Mock failed refresh — 400 indicates definitive 
credential rejection + mockSvc.On("Refresh", mock.Anything, connID).Return(&service.RefreshResponse{StatusCode: 400}, errors.New("invalid_grant")).Once() + + // Provider is healthy, so the connection should be expired (not shielded) + mockHealth.On("GetHealthStatus", providerID).Return("healthy", nil).Once() + + // Should update connection status to expired + mockRepo.On("UpdateStatus", mock.Anything, connID, "expired").Return(nil).Once() + + // Should update health to expired + mockRepo.On("UpdateHealthStatus", mock.Anything, connID, "expired").Return(nil).Once() + + worker := service.NewConnectionHealthWorker(mockRepo, mockSvc, mockHealth, 10*time.Millisecond) + + ctx, cancel := context.WithCancel(context.Background()) + go worker.Start(ctx) + + time.Sleep(50 * time.Millisecond) + cancel() + + mockRepo.AssertExpectations(t) + mockSvc.AssertExpectations(t) + mockHealth.AssertExpectations(t) +} + +func TestConnectionHealthWorker_OAuth2_ProviderDown_ShieldsExpiration(t *testing.T) { + mockRepo := new(MockConnectionRepository) + mockSvc := new(MockConnectionService) + mockHealth := new(MockProviderHealthLookup) + + connID := uuid.New() + providerID := uuid.New() + conn := &domain.ConnectionWithProvider{ + Connection: domain.Connection{ + ID: connID, + ProviderID: providerID, + Status: "active", + }, + ProviderName: "google", + AuthType: "oauth2", + } + + mockRepo.On("GetForHealthCheck", mock.Anything, 100).Return([]*domain.ConnectionWithProvider{conn}, nil).Once() + mockRepo.On("GetForHealthCheck", mock.Anything, 100).Return([]*domain.ConnectionWithProvider{}, nil) + + // Mock failed refresh — 401 is a definitive credential error, but should be + // shielded because the provider is unhealthy + mockSvc.On("Refresh", mock.Anything, connID).Return(&service.RefreshResponse{StatusCode: 401}, errors.New("token revoked")).Once() + + // Provider is unhealthy → should shield the connection from being expired + mockHealth.On("GetHealthStatus", providerID).Return("unhealthy", 
nil).Once() + + // Should NOT call UpdateStatus (no expiration) + // Should update health to "unhealthy" instead of "expired" + mockRepo.On("UpdateHealthStatus", mock.Anything, connID, "unhealthy").Return(nil).Once() + + worker := service.NewConnectionHealthWorker(mockRepo, mockSvc, mockHealth, 10*time.Millisecond) + + ctx, cancel := context.WithCancel(context.Background()) + go worker.Start(ctx) + + time.Sleep(50 * time.Millisecond) + cancel() + + mockRepo.AssertExpectations(t) + mockSvc.AssertExpectations(t) + mockHealth.AssertExpectations(t) + // Verify UpdateStatus was NOT called — connection should not be expired + mockRepo.AssertNotCalled(t, "UpdateStatus", mock.Anything, connID, "expired") +} + +func TestConnectionHealthWorker_APIKey_Expired(t *testing.T) { + mockRepo := new(MockConnectionRepository) + mockSvc := new(MockConnectionService) + mockHealth := new(MockProviderHealthLookup) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "Bearer secret-key", r.Header.Get("Authorization")) + w.WriteHeader(http.StatusUnauthorized) + })) + defer server.Close() + + connID := uuid.New() + providerID := uuid.New() + conn := &domain.ConnectionWithProvider{ + Connection: domain.Connection{ + ID: connID, + ProviderID: providerID, + Status: "active", + }, + AuthType: "api_key", + UserInfoEndpoint: server.URL, + } + + mockRepo.On("GetForHealthCheck", mock.Anything, 100).Return([]*domain.ConnectionWithProvider{conn}, nil).Once() + mockRepo.On("GetForHealthCheck", mock.Anything, 100).Return([]*domain.ConnectionWithProvider{}, nil) + + // Mock getting token + creds := map[string]interface{}{"api_key": "secret-key"} + mockSvc.On("GetToken", mock.Anything, connID).Return(creds, "api_key_strategy", nil).Once() + + // Provider is healthy, so expiration should proceed + mockHealth.On("GetHealthStatus", providerID).Return("healthy", nil).Once() + + // Should update connection status to expired + 
mockRepo.On("UpdateStatus", mock.Anything, connID, "expired").Return(nil).Once() + + // Should update health to expired + mockRepo.On("UpdateHealthStatus", mock.Anything, connID, "expired").Return(nil).Once() + + worker := service.NewConnectionHealthWorker(mockRepo, mockSvc, mockHealth, 10*time.Millisecond) + + ctx, cancel := context.WithCancel(context.Background()) + go worker.Start(ctx) + + time.Sleep(50 * time.Millisecond) + cancel() + + mockRepo.AssertExpectations(t) + mockSvc.AssertExpectations(t) + mockHealth.AssertExpectations(t) +} + +func TestConnectionHealthWorker_OAuth2_Upstream5xx_MarksUnhealthy(t *testing.T) { + mockRepo := new(MockConnectionRepository) + mockSvc := new(MockConnectionService) + mockHealth := new(MockProviderHealthLookup) + + connID := uuid.New() + conn := &domain.ConnectionWithProvider{ + Connection: domain.Connection{ + ID: connID, + Status: "active", + }, + AuthType: "oauth2", + } + + mockRepo.On("GetForHealthCheck", mock.Anything, 100).Return([]*domain.ConnectionWithProvider{conn}, nil).Once() + mockRepo.On("GetForHealthCheck", mock.Anything, 100).Return([]*domain.ConnectionWithProvider{}, nil) + + // 502 from upstream — transient server error, not a credential issue + mockSvc.On("Refresh", mock.Anything, connID).Return(&service.RefreshResponse{StatusCode: 502}, errors.New("bad gateway")).Once() + + // Should set health_status to "unhealthy", NOT "expired" + // Should NOT call UpdateStatus — connection status stays "active" + mockRepo.On("UpdateHealthStatus", mock.Anything, connID, "unhealthy").Return(nil).Once() + + worker := service.NewConnectionHealthWorker(mockRepo, mockSvc, mockHealth, 10*time.Millisecond) + + ctx, cancel := context.WithCancel(context.Background()) + go worker.Start(ctx) + + time.Sleep(50 * time.Millisecond) + cancel() + + mockRepo.AssertExpectations(t) + mockSvc.AssertExpectations(t) + mockRepo.AssertNotCalled(t, "UpdateStatus", mock.Anything, connID, "expired") +} + +func 
TestConnectionHealthWorker_OAuth2_403_MarksDegraded(t *testing.T) { + mockRepo := new(MockConnectionRepository) + mockSvc := new(MockConnectionService) + mockHealth := new(MockProviderHealthLookup) + + connID := uuid.New() + conn := &domain.ConnectionWithProvider{ + Connection: domain.Connection{ + ID: connID, + Status: "active", + }, + AuthType: "oauth2", + } + + mockRepo.On("GetForHealthCheck", mock.Anything, 100).Return([]*domain.ConnectionWithProvider{conn}, nil).Once() + mockRepo.On("GetForHealthCheck", mock.Anything, 100).Return([]*domain.ConnectionWithProvider{}, nil) + + // 403 — scope issue, credential exists but limited + mockSvc.On("Refresh", mock.Anything, connID).Return(&service.RefreshResponse{StatusCode: 403}, errors.New("forbidden")).Once() + + // Should set health_status to "degraded", NOT "expired" + mockRepo.On("UpdateHealthStatus", mock.Anything, connID, "degraded").Return(nil).Once() + + worker := service.NewConnectionHealthWorker(mockRepo, mockSvc, mockHealth, 10*time.Millisecond) + + ctx, cancel := context.WithCancel(context.Background()) + go worker.Start(ctx) + + time.Sleep(50 * time.Millisecond) + cancel() + + mockRepo.AssertExpectations(t) + mockSvc.AssertExpectations(t) + mockRepo.AssertNotCalled(t, "UpdateStatus", mock.Anything, connID, "expired") +} + +func TestConnectionHealthWorker_OAuth2_NetworkError_MarksDegraded(t *testing.T) { + mockRepo := new(MockConnectionRepository) + mockSvc := new(MockConnectionService) + mockHealth := new(MockProviderHealthLookup) + + connID := uuid.New() + conn := &domain.ConnectionWithProvider{ + Connection: domain.Connection{ + ID: connID, + Status: "active", + }, + AuthType: "oauth2", + } + + mockRepo.On("GetForHealthCheck", mock.Anything, 100).Return([]*domain.ConnectionWithProvider{conn}, nil).Once() + mockRepo.On("GetForHealthCheck", mock.Anything, 100).Return([]*domain.ConnectionWithProvider{}, nil) + + // Nil response — network error, timeout, DNS failure etc. 
+ mockSvc.On("Refresh", mock.Anything, connID).Return((*service.RefreshResponse)(nil), errors.New("connection refused")).Once() + + // Should set health_status to "degraded" (we don't know if credential is valid) + mockRepo.On("UpdateHealthStatus", mock.Anything, connID, "degraded").Return(nil).Once() + + worker := service.NewConnectionHealthWorker(mockRepo, mockSvc, mockHealth, 10*time.Millisecond) + + ctx, cancel := context.WithCancel(context.Background()) + go worker.Start(ctx) + + time.Sleep(50 * time.Millisecond) + cancel() + + mockRepo.AssertExpectations(t) + mockSvc.AssertExpectations(t) + mockRepo.AssertNotCalled(t, "UpdateStatus", mock.Anything, connID, "expired") +} diff --git a/nexus-broker/internal/service/connection_test.go b/nexus-broker/internal/service/connection_test.go index d05dddb..fdecb8f 100644 --- a/nexus-broker/internal/service/connection_test.go +++ b/nexus-broker/internal/service/connection_test.go @@ -74,6 +74,14 @@ func (m *MockConnectionRepository) GetActiveByWorkspaceAndProvider(ctx context.C return nil, args.Error(1) } +func (m *MockConnectionRepository) ListByWorkspace(ctx context.Context, workspaceID string) ([]domain.ConnectionSummary, error) { + args := m.Called(ctx, workspaceID) + if args.Get(0) != nil { + return args.Get(0).([]domain.ConnectionSummary), args.Error(1) + } + return nil, args.Error(1) +} + // MockTokenRepository is a mock of repository.TokenRepository type MockTokenRepository struct { mock.Mock @@ -149,6 +157,22 @@ func (m *MockProfileStorer) ListProfiles() ([]provider.ProfileList, error) { return nil, args.Error(1) } +func (m *MockProfileStorer) GetAllProfiles() ([]provider.Profile, error) { + args := m.Called() + if args.Get(0) != nil { + return args.Get(0).([]provider.Profile), args.Error(1) + } + return nil, args.Error(1) +} + +func (m *MockProfileStorer) GetAllHealthStatuses() ([]provider.ProviderHealthSummary, error) { + args := m.Called() + if args.Get(0) != nil { + return 
args.Get(0).([]provider.ProviderHealthSummary), args.Error(1) + } + return nil, args.Error(1) +} + func (m *MockProfileStorer) GetMetadata() (map[string]map[string]interface{}, error) { args := m.Called() if args.Get(0) != nil { diff --git a/nexus-broker/migrations/13_add_provider_health.sql b/nexus-broker/migrations/13_add_provider_health.sql new file mode 100644 index 0000000..c29d3a1 --- /dev/null +++ b/nexus-broker/migrations/13_add_provider_health.sql @@ -0,0 +1,6 @@ +-- Add health tracking fields to provider_profiles + +ALTER TABLE provider_profiles +ADD COLUMN last_health_check_at TIMESTAMP WITH TIME ZONE, +ADD COLUMN health_status VARCHAR(50) DEFAULT 'unknown', +ADD COLUMN health_message TEXT; diff --git a/nexus-broker/migrations/14_add_connection_health.sql b/nexus-broker/migrations/14_add_connection_health.sql new file mode 100644 index 0000000..cc8377d --- /dev/null +++ b/nexus-broker/migrations/14_add_connection_health.sql @@ -0,0 +1,5 @@ +-- Add health tracking fields to connections + +ALTER TABLE connections +ADD COLUMN last_health_check_at TIMESTAMP WITH TIME ZONE, +ADD COLUMN health_status VARCHAR(50) DEFAULT 'unknown'; diff --git a/nexus-broker/migrations/15_add_connection_health_index.sql b/nexus-broker/migrations/15_add_connection_health_index.sql new file mode 100644 index 0000000..e6b8403 --- /dev/null +++ b/nexus-broker/migrations/15_add_connection_health_index.sql @@ -0,0 +1,7 @@ +-- Add composite index for health check query performance +-- Covers the WHERE status = 'active' AND last_health_check_at < ... 
pattern +-- used by GetForHealthCheck + +CREATE INDEX IF NOT EXISTS idx_connections_health_check +ON connections (status, last_health_check_at ASC NULLS FIRST) +WHERE status = 'active'; diff --git a/nexus-broker/openapi.yaml b/nexus-broker/openapi.yaml index 356476d..bc89dfb 100644 --- a/nexus-broker/openapi.yaml +++ b/nexus-broker/openapi.yaml @@ -1,7 +1,7 @@ openapi: 3.0.3 info: title: Nexus Broker API - version: 0.2.2 + version: 0.2.4 description: | Internal API for the Nexus Broker service. This service handles OAuth 2.0 and OIDC flows, encrypts tokens, and manages provider configurations. @@ -151,6 +151,62 @@ components: type: object description: Decrypted credentials map additionalProperties: true + health_status: + type: string + enum: [healthy, unhealthy, degraded, expired, unknown] + description: Current health status of the connection + + ProviderHealthStatus: + type: object + properties: + id: + type: string + format: uuid + name: + type: string + health_status: + type: string + enum: [healthy, unhealthy, degraded, unknown] + last_health_check_at: + type: string + format: date-time + nullable: true + health_message: + type: string + description: Human-readable detail when status is not healthy + + ConnectionSummary: + type: object + properties: + id: + type: string + format: uuid + provider_id: + type: string + format: uuid + provider_name: + type: string + auth_type: + type: string + enum: [oauth2, api_key, basic_auth] + status: + type: string + scopes: + type: array + items: { type: string } + health_status: + type: string + enum: [healthy, unhealthy, degraded, expired, unknown] + last_health_check_at: + type: string + format: date-time + nullable: true + created_at: + type: string + format: date-time + updated_at: + type: string + format: date-time MetadataResponse: type: object @@ -388,6 +444,44 @@ paths: '404': description: Connection not found + /providers/health: + get: + summary: Get health status of all providers + description: Returns the 
current health status of all registered providers. No credentials are included. + security: [{ ApiKeyAuth: [] }] + responses: + '200': + description: Provider health statuses. Returns `[]` when no providers exist. + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/ProviderHealthStatus' + + /connections: + get: + summary: List connections for a workspace + description: Returns all non-pending connections for a workspace with their health status. No credentials or tokens are included. + security: [{ ApiKeyAuth: [] }] + parameters: + - in: query + name: workspace_id + required: true + schema: { type: string } + description: Workspace ID to filter connections by + responses: + '200': + description: List of connection summaries + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/ConnectionSummary' + '400': + description: Missing workspace_id parameter + /connections/resolve: get: summary: Resolve and retrieve token by workspace and provider diff --git a/nexus-broker/pkg/handlers/connections.go b/nexus-broker/pkg/handlers/connections.go new file mode 100644 index 0000000..bc3f735 --- /dev/null +++ b/nexus-broker/pkg/handlers/connections.go @@ -0,0 +1,37 @@ +package handlers + +import ( + "net/http" + + "github.com/Prescott-Data/nexus-framework/nexus-broker/internal/service" + "github.com/Prescott-Data/nexus-framework/nexus-broker/pkg/httputil" +) + +// ConnectionsHandler handles connection-related API requests +type ConnectionsHandler struct { + svc service.ConnectionService +} + +// NewConnectionsHandler creates a new connections handler +func NewConnectionsHandler(svc service.ConnectionService) *ConnectionsHandler { + return &ConnectionsHandler{svc: svc} +} + +// List handles GET /connections?workspace_id=ws-123 +// Returns all non-pending connections for a workspace with health status. 
+func (h *ConnectionsHandler) List(w http.ResponseWriter, r *http.Request) { + workspaceID := r.URL.Query().Get("workspace_id") + + if workspaceID == "" { + httputil.WriteError(w, http.StatusBadRequest, "missing_workspace_id", "workspace_id query parameter is required") + return + } + + connections, err := h.svc.ListConnections(r.Context(), workspaceID) + if err != nil { + writeServiceError(w, err) + return + } + + httputil.WriteJSON(w, http.StatusOK, connections) +} diff --git a/nexus-broker/pkg/handlers/connections_test.go b/nexus-broker/pkg/handlers/connections_test.go new file mode 100644 index 0000000..ec8bc8a --- /dev/null +++ b/nexus-broker/pkg/handlers/connections_test.go @@ -0,0 +1,105 @@ +package handlers + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/Prescott-Data/nexus-framework/nexus-broker/internal/domain" +) + +func TestConnectionsList_Success(t *testing.T) { + mockSvc := new(MockConnectionService) + handler := NewConnectionsHandler(mockSvc) + + now := time.Now() + expected := []domain.ConnectionSummary{ + { + ID: uuid.New(), + ProviderID: uuid.New(), + ProviderName: "google", + AuthType: "oauth2", + Status: "active", + Scopes: []string{"email", "calendar.read"}, + HealthStatus: "healthy", + CreatedAt: now, + UpdatedAt: now, + }, + { + ID: uuid.New(), + ProviderID: uuid.New(), + ProviderName: "stripe", + AuthType: "api_key", + Status: "active", + HealthStatus: "unhealthy", + CreatedAt: now, + UpdatedAt: now, + }, + } + + mockSvc.On("ListConnections", mock.Anything, "ws-123").Return(expected, nil).Once() + + req := httptest.NewRequest("GET", "/connections?workspace_id=ws-123", nil) + rr := httptest.NewRecorder() + + handler.List(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), "google") + assert.Contains(t, rr.Body.String(), "stripe") + assert.Contains(t, 
rr.Body.String(), `"health_status":"healthy"`) + assert.Contains(t, rr.Body.String(), `"health_status":"unhealthy"`) + mockSvc.AssertExpectations(t) +} + +func TestConnectionsList_EmptyResult(t *testing.T) { + mockSvc := new(MockConnectionService) + handler := NewConnectionsHandler(mockSvc) + + mockSvc.On("ListConnections", mock.Anything, "ws-empty").Return([]domain.ConnectionSummary{}, nil).Once() + + req := httptest.NewRequest("GET", "/connections?workspace_id=ws-empty", nil) + rr := httptest.NewRecorder() + + handler.List(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + assert.Equal(t, "[]", rr.Body.String()) + mockSvc.AssertExpectations(t) +} + +func TestConnectionsList_MissingWorkspaceID(t *testing.T) { + mockSvc := new(MockConnectionService) + handler := NewConnectionsHandler(mockSvc) + + req := httptest.NewRequest("GET", "/connections", nil) + rr := httptest.NewRecorder() + + handler.List(rr, req) + + assert.Equal(t, http.StatusBadRequest, rr.Code) + assert.Contains(t, rr.Body.String(), "missing_workspace_id") + // Service should never be called + mockSvc.AssertNotCalled(t, "ListConnections") +} + +func TestConnectionsList_ServiceError(t *testing.T) { + mockSvc := new(MockConnectionService) + handler := NewConnectionsHandler(mockSvc) + + mockSvc.On("ListConnections", mock.Anything, "ws-broken").Return(nil, errors.New("database unreachable")).Once() + + req := httptest.NewRequest("GET", "/connections?workspace_id=ws-broken", nil) + rr := httptest.NewRecorder() + + handler.List(rr, req) + + assert.Equal(t, http.StatusInternalServerError, rr.Code) + mockSvc.AssertExpectations(t) +} diff --git a/nexus-broker/pkg/handlers/consent_test.go b/nexus-broker/pkg/handlers/consent_test.go index 7c67c95..e20f3c2 100644 --- a/nexus-broker/pkg/handlers/consent_test.go +++ b/nexus-broker/pkg/handlers/consent_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + 
"github.com/Prescott-Data/nexus-framework/nexus-broker/internal/domain" "github.com/Prescott-Data/nexus-framework/nexus-broker/internal/service" ) @@ -69,6 +70,14 @@ func (m *MockConnectionService) Refresh(ctx context.Context, connectionID uuid.U return nil, args.Error(1) } +func (m *MockConnectionService) ListConnections(ctx context.Context, workspaceID string) ([]domain.ConnectionSummary, error) { + args := m.Called(ctx, workspaceID) + if args.Get(0) != nil { + return args.Get(0).([]domain.ConnectionSummary), args.Error(1) + } + return nil, args.Error(1) +} + func TestGetSpec_Success(t *testing.T) { mockSvc := new(MockConnectionService) handler := NewConsentHandler(ConsentHandlerConfig{ diff --git a/nexus-broker/pkg/handlers/providers.go b/nexus-broker/pkg/handlers/providers.go index bc607f9..de214a4 100644 --- a/nexus-broker/pkg/handlers/providers.go +++ b/nexus-broker/pkg/handlers/providers.go @@ -203,6 +203,17 @@ func (h *ProvidersHandler) List(w http.ResponseWriter, r *http.Request) { httputil.WriteJSON(w, http.StatusOK, rows) } +// Health handles GET /providers/health to list provider health statuses +func (h *ProvidersHandler) Health(w http.ResponseWriter, r *http.Request) { + summaries, err := h.store.GetAllHealthStatuses() + if err != nil { + httputil.WriteError(w, http.StatusInternalServerError, "health_failed", "Failed to retrieve provider health statuses") + return + } + + httputil.WriteJSON(w, http.StatusOK, summaries) +} + // GetByName handles GET /providers/by-name/{name} func (h *ProvidersHandler) GetByName(w http.ResponseWriter, r *http.Request) { name := chi.URLParam(r, "name") diff --git a/nexus-broker/pkg/handlers/providers_test.go b/nexus-broker/pkg/handlers/providers_test.go index da4f0b4..e813333 100644 --- a/nexus-broker/pkg/handlers/providers_test.go +++ b/nexus-broker/pkg/handlers/providers_test.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/google/uuid" "github.com/stretchr/testify/assert" @@ 
-75,6 +76,22 @@ func (m *MockStore) ListProfiles() ([]provider.ProfileList, error) { return args.Get(0).([]provider.ProfileList), args.Error(1) } +func (m *MockStore) GetAllProfiles() ([]provider.Profile, error) { + args := m.Called() + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]provider.Profile), args.Error(1) +} + +func (m *MockStore) GetAllHealthStatuses() ([]provider.ProviderHealthSummary, error) { + args := m.Called() + if args.Get(0) == nil { + return nil, args.Error(1) + } + return args.Get(0).([]provider.ProviderHealthSummary), args.Error(1) +} + func (m *MockStore) GetMetadata() (map[string]map[string]interface{}, error) { args := m.Called() if args.Get(0) == nil { @@ -270,3 +287,74 @@ func TestPatchProvider_AuditRedactsSecrets(t *testing.T) { return true }), mock.AnythingOfType("*http.Request")) } + +func TestHealth_Success(t *testing.T) { + mockStore := new(MockStore) + handler := NewProvidersHandler(mockStore, nil) + + now := time.Now() + msg := "token_url returned 503" + summaries := []provider.ProviderHealthSummary{ + { + ID: uuid.New(), + Name: "google", + HealthStatus: "healthy", + LastHealthCheckAt: &now, + HealthMessage: nil, + }, + { + ID: uuid.New(), + Name: "stripe", + HealthStatus: "unhealthy", + LastHealthCheckAt: &now, + HealthMessage: &msg, + }, + } + + mockStore.On("GetAllHealthStatuses").Return(summaries, nil).Once() + + req := httptest.NewRequest("GET", "/providers/health", nil) + rr := httptest.NewRecorder() + + handler.Health(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + assert.Contains(t, rr.Body.String(), `"health_status":"healthy"`) + assert.Contains(t, rr.Body.String(), `"health_status":"unhealthy"`) + assert.Contains(t, rr.Body.String(), `"health_message":"token_url returned 503"`) + assert.Contains(t, rr.Body.String(), "google") + assert.Contains(t, rr.Body.String(), "stripe") + mockStore.AssertExpectations(t) +} + +func TestHealth_EmptyList(t *testing.T) { + mockStore := 
new(MockStore) + handler := NewProvidersHandler(mockStore, nil) + + mockStore.On("GetAllHealthStatuses").Return([]provider.ProviderHealthSummary{}, nil).Once() + + req := httptest.NewRequest("GET", "/providers/health", nil) + rr := httptest.NewRecorder() + + handler.Health(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) + assert.Equal(t, "[]", rr.Body.String()) + mockStore.AssertExpectations(t) +} + +func TestHealth_StoreError(t *testing.T) { + mockStore := new(MockStore) + handler := NewProvidersHandler(mockStore, nil) + + mockStore.On("GetAllHealthStatuses").Return(nil, errors.New("connection refused")).Once() + + req := httptest.NewRequest("GET", "/providers/health", nil) + rr := httptest.NewRecorder() + + handler.Health(rr, req) + + assert.Equal(t, http.StatusInternalServerError, rr.Code) + assert.Contains(t, rr.Body.String(), "health_failed") + mockStore.AssertExpectations(t) +} diff --git a/nexus-broker/pkg/handlers/soc2_compliance_test.go b/nexus-broker/pkg/handlers/soc2_compliance_test.go index 2fab623..10406dc 100644 --- a/nexus-broker/pkg/handlers/soc2_compliance_test.go +++ b/nexus-broker/pkg/handlers/soc2_compliance_test.go @@ -93,8 +93,8 @@ func TestSOC2_CC61_EncryptionAtRest(t *testing.T) { // 2. Mock database expectations — parameterized queries only mock.ExpectQuery("SELECT c.id, c.provider_id"). WithArgs(connID). - WillReturnRows(sqlmock.NewRows([]string{"id", "provider_id", "status", "scopes", "return_url", "name", "auth_type", "auth_header", "api_base_url", "user_info_endpoint", "params"}). - AddRow(connID.String(), providerID.String(), "active", "{}", "http://localhost/return", "TestProvider", "api_key", "", "", "", nil)) + WillReturnRows(sqlmock.NewRows([]string{"id", "provider_id", "status", "scopes", "return_url", "name", "auth_type", "auth_header", "api_base_url", "user_info_endpoint", "params", "health_status"}). 
+ AddRow(connID.String(), providerID.String(), "active", "{}", "http://localhost/return", "TestProvider", "api_key", "", "", "", nil, "unknown")) mock.ExpectExec("INSERT INTO tokens"). WithArgs(connID, sqlmock.AnyArg(), sqlmock.AnyArg()). diff --git a/nexus-broker/pkg/provider/health.go b/nexus-broker/pkg/provider/health.go new file mode 100644 index 0000000..31ffb02 --- /dev/null +++ b/nexus-broker/pkg/provider/health.go @@ -0,0 +1,200 @@ +package provider + +import ( + "context" + "fmt" + "io" + "log" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/Prescott-Data/nexus-framework/nexus-broker/pkg/discovery" +) + +// HealthWorker periodically checks the health of all registered providers +type HealthWorker struct { + store *Store + client *http.Client + interval time.Duration + maxConcurrency int +} + +func NewHealthWorker(store *Store, interval time.Duration) *HealthWorker { + return &HealthWorker{ + store: store, + client: &http.Client{ + Timeout: 10 * time.Second, // Prevent hanging requests + }, + interval: interval, + maxConcurrency: 10, // Limit to 10 concurrent provider checks + } +} + +func (w *HealthWorker) Start(ctx context.Context) { + ticker := time.NewTicker(w.interval) + defer ticker.Stop() + + // Run once immediately on start + w.runChecks(ctx) + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + w.runChecks(ctx) + } + } +} + +func (w *HealthWorker) runChecks(ctx context.Context) { + profiles, err := w.store.GetAllProfiles() + if err != nil { + log.Printf("HealthWorker: failed to get profiles: %v", err) + return + } + + // Use a semaphore to bound concurrency + sem := make(chan struct{}, w.maxConcurrency) + var wg sync.WaitGroup + + for _, p := range profiles { + wg.Add(1) + sem <- struct{}{} // Acquire semaphore slot + + // Run each check in a goroutine to prevent one slow provider from blocking others + go func(profile Profile) { + defer wg.Done() + defer func() { <-sem }() // Release semaphore slot + + 
status, message := w.checkProvider(ctx, profile) + + var msgPtr *string + if message != "" { + msgPtr = &message + } + + if err := w.store.UpdateHealthStatus(profile.ID, status, msgPtr); err != nil { + log.Printf("HealthWorker: failed to update status for %s: %v", profile.Name, err) + } + }(p) + } + + wg.Wait() +} + +func (w *HealthWorker) checkProvider(ctx context.Context, p Profile) (string, string) { + // If it's not OAuth2, perform a Tier 1 Reachability Check + if p.AuthType != "oauth2" && p.AuthType != "" { + return w.checkReachability(ctx, p) + } + + tokenURL := "" + + // Tier 1 / Discovery setup + if p.EnableDiscovery { + var issuer string + if p.Issuer != nil { + issuer = *p.Issuer + } + + // Attempt OIDC Discovery + md, err := discovery.Discover(ctx, w.client, discovery.Hint{ + Issuer: issuer, + AuthURL: "", // We might not have AuthURL if discovery is enabled + }) + + if err != nil { + return "unhealthy", fmt.Sprintf("OIDC Discovery failed: %v", err) + } + tokenURL = md.TokenEndpoint + } else if p.TokenURL != nil { + tokenURL = *p.TokenURL + } + + if tokenURL == "" { + return "unknown", "No token URL available to check" + } + + // Tier 2: Configuration Validation (Deep Check) + // We make a dummy token request. 
+ + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("code", "dummy_code_for_health_check") + + if p.ClientID != nil { + form.Set("client_id", *p.ClientID) + } + if p.ClientSecret != nil { + form.Set("client_secret", *p.ClientSecret) + } + + req, err := http.NewRequestWithContext(ctx, "POST", tokenURL, strings.NewReader(form.Encode())) + if err != nil { + return "unhealthy", fmt.Sprintf("Failed to create request: %v", err) + } + + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := w.client.Do(req) + if err != nil { + return "unhealthy", fmt.Sprintf("Network error reaching token endpoint: %v", err) + } + defer resp.Body.Close() + + // Read a snippet of the body for the message + bodyBytes, _ := io.ReadAll(io.LimitReader(resp.Body, 256)) + bodyStr := string(bodyBytes) + + // If the provider is healthy, it should recognize the valid format but reject the dummy code + // Usually this means a 400 Bad Request or 401 Unauthorized with an OAuth error + if resp.StatusCode == http.StatusBadRequest || resp.StatusCode == http.StatusUnauthorized { + return "healthy", "Token endpoint reachable and returning expected OAuth error" + } + + // If we get 5xx, the provider's API is down + if resp.StatusCode >= 500 { + return "unhealthy", fmt.Sprintf("Provider returned Server Error (%d)", resp.StatusCode) + } + + // Any other status code is unexpected but not definitively "down" + return "degraded", fmt.Sprintf("Unexpected status code %d: %s", resp.StatusCode, bodyStr) +} + +func (w *HealthWorker) checkReachability(ctx context.Context, p Profile) (string, string) { + targetURL := "" + if p.UserInfoEndpoint != "" { + targetURL = p.UserInfoEndpoint + } else if p.APIBaseURL != "" { + targetURL = p.APIBaseURL + } + + if targetURL == "" { + return "unknown", "No API Base URL or User Info Endpoint configured for reachability check" + } + + // Try a HEAD request first, some APIs might reject it with 405 Method Not Allowed, + // 
but even a 405 indicates the server is up. + req, err := http.NewRequestWithContext(ctx, "HEAD", targetURL, nil) + if err != nil { + return "unhealthy", fmt.Sprintf("Failed to create request: %v", err) + } + + resp, err := w.client.Do(req) + if err != nil { + return "unhealthy", fmt.Sprintf("Network error reaching endpoint: %v", err) + } + defer resp.Body.Close() + + // 5xx indicates a server error + if resp.StatusCode >= 500 { + return "unhealthy", fmt.Sprintf("Provider returned Server Error (%d)", resp.StatusCode) + } + + // 2xx, 3xx, 4xx all indicate the server is actively responding + return "healthy", fmt.Sprintf("Endpoint reachable (status %d)", resp.StatusCode) +} diff --git a/nexus-broker/pkg/provider/health_test.go b/nexus-broker/pkg/provider/health_test.go new file mode 100644 index 0000000..4d02a51 --- /dev/null +++ b/nexus-broker/pkg/provider/health_test.go @@ -0,0 +1,145 @@ +package provider + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// stringPtr is a helper to get a pointer to a string +func stringPtr(s string) *string { + return &s +} + +func TestHealthWorker_CheckProvider_NonOAuth2_MissingURLs(t *testing.T) { + worker := NewHealthWorker(nil, time.Minute) + + status, msg := worker.checkProvider(context.Background(), Profile{ + AuthType: "api_key", + }) + assert.Equal(t, "unknown", status) + assert.Contains(t, msg, "No API Base URL or User Info Endpoint configured") +} + +func TestHealthWorker_CheckProvider_NonOAuth2_Healthy(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "HEAD", r.Method) + w.WriteHeader(http.StatusUnauthorized) // 401 is healthy for a reachability check + })) + defer server.Close() + + worker := NewHealthWorker(nil, time.Minute) + + status, msg := worker.checkProvider(context.Background(), Profile{ + AuthType: "api_key", + APIBaseURL: server.URL, + }) + assert.Equal(t, 
"healthy", status) + assert.Contains(t, msg, "Endpoint reachable (status 401)") +} + +func TestHealthWorker_CheckProvider_NonOAuth2_Unhealthy(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadGateway) // 502 + })) + defer server.Close() + + worker := NewHealthWorker(nil, time.Minute) + + status, msg := worker.checkProvider(context.Background(), Profile{ + AuthType: "api_key", + UserInfoEndpoint: server.URL, + }) + assert.Equal(t, "unhealthy", status) + assert.Contains(t, msg, "Server Error (502)") +} + +func TestHealthWorker_CheckProvider_MissingTokenURL(t *testing.T) { + worker := NewHealthWorker(nil, time.Minute) + + status, msg := worker.checkProvider(context.Background(), Profile{ + AuthType: "oauth2", + TokenURL: nil, // explicitly nil + }) + assert.Equal(t, "unknown", status) + assert.Contains(t, msg, "No token URL available") +} + +func TestHealthWorker_CheckProvider_Healthy(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + assert.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type")) + + err := r.ParseForm() + assert.NoError(t, err) + assert.Equal(t, "authorization_code", r.FormValue("grant_type")) + assert.Equal(t, "dummy_code_for_health_check", r.FormValue("code")) + + w.WriteHeader(http.StatusBadRequest) // 400 is expected for dummy code + _, _ = w.Write([]byte(`{"error": "invalid_grant"}`)) + })) + defer server.Close() + + worker := NewHealthWorker(nil, time.Minute) + + status, msg := worker.checkProvider(context.Background(), Profile{ + AuthType: "oauth2", + TokenURL: stringPtr(server.URL), + ClientID: stringPtr("test-client"), + ClientSecret: stringPtr("test-secret"), + }) + + assert.Equal(t, "healthy", status) + assert.Contains(t, msg, "expected OAuth error") +} + +func TestHealthWorker_CheckProvider_Unhealthy_500(t *testing.T) { + server := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + worker := NewHealthWorker(nil, time.Minute) + + status, msg := worker.checkProvider(context.Background(), Profile{ + AuthType: "oauth2", + TokenURL: stringPtr(server.URL), + }) + + assert.Equal(t, "unhealthy", status) + assert.Contains(t, msg, "Server Error (500)") +} + +func TestHealthWorker_CheckProvider_Degraded_200(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"access_token": "wait_what"}`)) + })) + defer server.Close() + + worker := NewHealthWorker(nil, time.Minute) + + status, msg := worker.checkProvider(context.Background(), Profile{ + AuthType: "oauth2", + TokenURL: stringPtr(server.URL), + }) + + assert.Equal(t, "degraded", status) + assert.Contains(t, msg, "Unexpected status code 200") +} + +func TestHealthWorker_CheckProvider_NetworkError(t *testing.T) { + worker := NewHealthWorker(nil, time.Minute) + + status, msg := worker.checkProvider(context.Background(), Profile{ + AuthType: "oauth2", + TokenURL: stringPtr("http://localhost:12345/nonexistent-server-so-this-fails-to-connect"), + }) + + assert.Equal(t, "unhealthy", status) + assert.Contains(t, msg, "Network error reaching token endpoint") +} diff --git a/nexus-broker/pkg/provider/interfaces.go b/nexus-broker/pkg/provider/interfaces.go index 14e1747..61a538a 100644 --- a/nexus-broker/pkg/provider/interfaces.go +++ b/nexus-broker/pkg/provider/interfaces.go @@ -22,5 +22,7 @@ type ProfileStorer interface { // ... 
DeleteProfileByName(name string) (int64, error) ListProfiles() ([]ProfileList, error) + GetAllProfiles() ([]Profile, error) + GetAllHealthStatuses() ([]ProviderHealthSummary, error) GetMetadata() (map[string]map[string]interface{}, error) } diff --git a/nexus-broker/pkg/provider/store.go b/nexus-broker/pkg/provider/store.go index 8629a89..61d7948 100644 --- a/nexus-broker/pkg/provider/store.go +++ b/nexus-broker/pkg/provider/store.go @@ -42,6 +42,9 @@ type Profile struct { UserInfoEndpoint string `json:"user_info_endpoint,omitempty" db:"user_info_endpoint"` Params *json.RawMessage `json:"params,omitempty" db:"params"` DeletedAt *time.Time `json:"-" db:"deleted_at"` + LastHealthCheckAt *time.Time `json:"last_health_check_at,omitempty" db:"last_health_check_at"` + HealthStatus string `json:"health_status" db:"health_status"` + HealthMessage *string `json:"health_message,omitempty" db:"health_message"` } // RegisterProfile registers a new provider profile from JSON @@ -153,10 +156,10 @@ func (s *Store) RegisterProfile(profileJSON string) (*Profile, error) { // GetProfile retrieves a provider profile by ID func (s *Store) GetProfile(id uuid.UUID) (*Profile, error) { var p Profile - query := `SELECT id, name, client_id, client_secret, auth_url, token_url, issuer, enable_discovery, scopes, auth_type, COALESCE(auth_header, ''), COALESCE(api_base_url, ''), COALESCE(user_info_endpoint, ''), params, COALESCE(description, ''), COALESCE(category, '') FROM provider_profiles WHERE id = $1 AND deleted_at IS NULL` + query := `SELECT id, name, client_id, client_secret, auth_url, token_url, issuer, enable_discovery, scopes, auth_type, COALESCE(auth_header, ''), COALESCE(api_base_url, ''), COALESCE(user_info_endpoint, ''), params, COALESCE(description, ''), COALESCE(category, ''), last_health_check_at, COALESCE(health_status, 'unknown'), health_message FROM provider_profiles WHERE id = $1 AND deleted_at IS NULL` row := s.db.QueryRow(query, id) - err := row.Scan(&p.ID, &p.Name, 
&p.ClientID, &p.ClientSecret, &p.AuthURL, &p.TokenURL, &p.Issuer, &p.EnableDiscovery, pq.Array(&p.Scopes), &p.AuthType, &p.AuthHeader, &p.APIBaseURL, &p.UserInfoEndpoint, &p.Params, &p.Description, &p.Category) + err := row.Scan(&p.ID, &p.Name, &p.ClientID, &p.ClientSecret, &p.AuthURL, &p.TokenURL, &p.Issuer, &p.EnableDiscovery, pq.Array(&p.Scopes), &p.AuthType, &p.AuthHeader, &p.APIBaseURL, &p.UserInfoEndpoint, &p.Params, &p.Description, &p.Category, &p.LastHealthCheckAt, &p.HealthStatus, &p.HealthMessage) if err != nil { return nil, fmt.Errorf("failed to get provider profile: %w", err) } @@ -174,7 +177,8 @@ func (s *Store) GetProfileByName(name string) (*Profile, error) { SELECT id, name, client_id, client_secret, auth_url, token_url, issuer, enable_discovery, scopes, auth_type, COALESCE(auth_header, ''), COALESCE(api_base_url, ''), COALESCE(user_info_endpoint, ''), params, - COALESCE(description, ''), COALESCE(category, '') + COALESCE(description, ''), COALESCE(category, ''), last_health_check_at, + COALESCE(health_status, 'unknown'), health_message FROM provider_profiles WHERE LOWER(name) = $1 AND deleted_at IS NULL ` @@ -192,6 +196,7 @@ func (s *Store) GetProfileByName(name string) (*Profile, error) { &p.ID, &p.Name, &p.ClientID, &p.ClientSecret, &p.AuthURL, &p.TokenURL, &p.Issuer, &p.EnableDiscovery, pq.Array(&p.Scopes), &p.AuthType, &p.AuthHeader, &p.APIBaseURL, &p.UserInfoEndpoint, &p.Params, &p.Description, &p.Category, + &p.LastHealthCheckAt, &p.HealthStatus, &p.HealthMessage, ) if err != nil { return nil, fmt.Errorf("failed to scan provider profile: %w", err) @@ -369,8 +374,119 @@ func (s *Store) ListProfiles() ([]ProfileList, error) { return rows, nil } +// GetAllProfiles retrieves all non-deleted provider profiles in full +func (s *Store) GetAllProfiles() ([]Profile, error) { + query := ` + SELECT id, name, client_id, client_secret, auth_url, token_url, issuer, + enable_discovery, scopes, auth_type, COALESCE(auth_header, ''), + COALESCE(api_base_url, 
''), COALESCE(user_info_endpoint, ''), params, + COALESCE(description, ''), COALESCE(category, ''), last_health_check_at, + COALESCE(health_status, 'unknown'), health_message + FROM provider_profiles + WHERE deleted_at IS NULL + ` + + rows, err := s.db.Query(query) + if err != nil { + return nil, fmt.Errorf("failed to query all profiles: %w", err) + } + defer rows.Close() + + var profiles []Profile + for rows.Next() { + var p Profile + err := rows.Scan( + &p.ID, &p.Name, &p.ClientID, &p.ClientSecret, &p.AuthURL, &p.TokenURL, + &p.Issuer, &p.EnableDiscovery, pq.Array(&p.Scopes), &p.AuthType, + &p.AuthHeader, &p.APIBaseURL, &p.UserInfoEndpoint, &p.Params, &p.Description, &p.Category, + &p.LastHealthCheckAt, &p.HealthStatus, &p.HealthMessage, + ) + if err != nil { + return nil, fmt.Errorf("failed to scan provider profile: %w", err) + } + profiles = append(profiles, p) + } + + if err = rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating all profiles: %w", err) + } + + return profiles, nil +} + +// UpdateHealthStatus updates the health fields for a given provider profile +func (s *Store) UpdateHealthStatus(id uuid.UUID, status string, message *string) error { + query := `UPDATE provider_profiles SET health_status = $1, health_message = $2, last_health_check_at = NOW() WHERE id = $3 AND deleted_at IS NULL` + _, err := s.db.Exec(query, status, message, id) + if err != nil { + return fmt.Errorf("failed to update provider health status: %w", err) + } + return nil +} + +// GetHealthStatus returns only the health_status for a provider. +// This is a narrow query intended for background workers that need to +// cross-reference provider health without loading sensitive fields. 
+func (s *Store) GetHealthStatus(id uuid.UUID) (string, error) { + var status string + err := s.db.QueryRow( + `SELECT COALESCE(health_status, 'unknown') FROM provider_profiles WHERE id = $1 AND deleted_at IS NULL`, + id, + ).Scan(&status) + if err != nil { + return "", fmt.Errorf("failed to get provider health status: %w", err) + } + return status, nil +} + +// ProviderHealthSummary is a lightweight view containing only health-related fields. +// Used by the /providers/health endpoint to avoid loading sensitive columns. +type ProviderHealthSummary struct { + ID uuid.UUID `json:"id"` + Name string `json:"name"` + HealthStatus string `json:"health_status"` + LastHealthCheckAt *time.Time `json:"last_health_check_at,omitempty"` + HealthMessage *string `json:"health_message,omitempty"` +} + +// GetAllHealthStatuses returns health-only summaries for all active providers. +// This is a narrow query that avoids selecting sensitive fields (client_secret, params, etc.). +func (s *Store) GetAllHealthStatuses() ([]ProviderHealthSummary, error) { + query := ` + SELECT id, name, COALESCE(health_status, 'unknown'), last_health_check_at, health_message + FROM provider_profiles + WHERE deleted_at IS NULL + ` + + rows, err := s.db.Query(query) + if err != nil { + return nil, fmt.Errorf("failed to query provider health statuses: %w", err) + } + defer rows.Close() + + var summaries []ProviderHealthSummary + for rows.Next() { + var s ProviderHealthSummary + if err := rows.Scan(&s.ID, &s.Name, &s.HealthStatus, &s.LastHealthCheckAt, &s.HealthMessage); err != nil { + return nil, fmt.Errorf("failed to scan provider health summary: %w", err) + } + summaries = append(summaries, s) + } + + if err = rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating provider health statuses: %w", err) + } + + // Return empty slice instead of nil for clean JSON serialization + if summaries == nil { + summaries = []ProviderHealthSummary{} + } + + return summaries, nil +} // GetMetadata retrieves 
integration metadata for all providers, grouped by auth_type func (s *Store) GetMetadata() (map[string]map[string]interface{}, error) { + query := ` SELECT id, diff --git a/nexus-broker/pkg/provider/store_test.go b/nexus-broker/pkg/provider/store_test.go index c7b6c02..f45a2af 100644 --- a/nexus-broker/pkg/provider/store_test.go +++ b/nexus-broker/pkg/provider/store_test.go @@ -4,6 +4,7 @@ import ( "database/sql" "encoding/json" "testing" + "time" "github.com/google/uuid" "github.com/jmoiron/sqlx" @@ -191,10 +192,10 @@ func TestGetProfile_NullValues(t *testing.T) { providerID := uuid.New() rows := sqlmock.NewRows([]string{ "id", "name", "client_id", "client_secret", "auth_url", "token_url", "issuer", - "enable_discovery", "scopes", "auth_type", "auth_header", "api_base_url", "user_info_endpoint", "params", "description", "category", + "enable_discovery", "scopes", "auth_type", "auth_header", "api_base_url", "user_info_endpoint", "params", "description", "category", "last_health_check_at", "health_status", "health_message", }).AddRow( providerID.String(), "null-provider", nil, nil, nil, nil, nil, - false, []byte("{}"), "api_key", "", "", "", nil, "", "", + false, []byte("{}"), "api_key", "", "", "", nil, "", "", nil, "unknown", nil, ) mock.ExpectQuery(`SELECT .* FROM provider_profiles WHERE id = \$1`). 
@@ -208,3 +209,266 @@ func TestGetProfile_NullValues(t *testing.T) { assert.Equal(t, "null-provider", profile.Name) } } + +func TestGetAllProfiles_Success(t *testing.T) { + db, mock, err := sqlmock.New() + assert.NoError(t, err) + defer db.Close() + + sqlxDB := sqlx.NewDb(db, "sqlmock") + store := NewStore(sqlxDB) + + id1 := uuid.New() + id2 := uuid.New() + now := time.Now() + msg := "timeout reaching token_endpoint" + + // Must match the exact 19-column order in GetAllProfiles SELECT: + // id, name, client_id, client_secret, auth_url, token_url, issuer, + // enable_discovery, scopes, auth_type, auth_header, + // api_base_url, user_info_endpoint, params, description, category, + // last_health_check_at, health_status, health_message + rows := sqlmock.NewRows([]string{ + "id", "name", "client_id", "client_secret", "auth_url", "token_url", "issuer", + "enable_discovery", "scopes", "auth_type", "auth_header", + "api_base_url", "user_info_endpoint", "params", "description", "category", + "last_health_check_at", "health_status", "health_message", + }).AddRow( + id1.String(), "google", ptr("cid"), ptr("csec"), ptr("https://auth"), ptr("https://token"), nil, + true, []byte("{email,profile}"), "oauth2", "", + "https://api.google.com", "/userinfo", nil, "Google OAuth", "Identity", + now, "healthy", nil, + ).AddRow( + id2.String(), "stripe", nil, nil, nil, nil, nil, + false, []byte("{}"), "api_key", "Authorization", + "https://api.stripe.com", "/v1/account", nil, "Stripe API", "Payments", + now, "unhealthy", &msg, + ) + + mock.ExpectQuery(`SELECT .* FROM provider_profiles`).WillReturnRows(rows) + + profiles, err := store.GetAllProfiles() + assert.NoError(t, err) + assert.Len(t, profiles, 2) + + // Verify first profile health fields + assert.Equal(t, id1, profiles[0].ID) + assert.Equal(t, "google", profiles[0].Name) + assert.Equal(t, "healthy", profiles[0].HealthStatus) + assert.NotNil(t, profiles[0].LastHealthCheckAt) + assert.Nil(t, profiles[0].HealthMessage) + + // Verify 
second profile health fields + assert.Equal(t, id2, profiles[1].ID) + assert.Equal(t, "stripe", profiles[1].Name) + assert.Equal(t, "unhealthy", profiles[1].HealthStatus) + assert.NotNil(t, profiles[1].HealthMessage) + assert.Equal(t, "timeout reaching token_endpoint", *profiles[1].HealthMessage) + + assert.NoError(t, mock.ExpectationsWereMet()) +} + +func TestGetAllProfiles_Empty(t *testing.T) { + db, mock, err := sqlmock.New() + assert.NoError(t, err) + defer db.Close() + + sqlxDB := sqlx.NewDb(db, "sqlmock") + store := NewStore(sqlxDB) + + rows := sqlmock.NewRows([]string{ + "id", "name", "client_id", "client_secret", "auth_url", "token_url", "issuer", + "enable_discovery", "scopes", "auth_type", "auth_header", + "api_base_url", "user_info_endpoint", "params", "description", "category", + "last_health_check_at", "health_status", "health_message", + }) + + mock.ExpectQuery(`SELECT .* FROM provider_profiles`).WillReturnRows(rows) + + profiles, err := store.GetAllProfiles() + assert.NoError(t, err) + assert.Nil(t, profiles) // append to nil slice returns nil + + assert.NoError(t, mock.ExpectationsWereMet()) +} + +func TestGetAllProfiles_QueryError(t *testing.T) { + db, mock, err := sqlmock.New() + assert.NoError(t, err) + defer db.Close() + + sqlxDB := sqlx.NewDb(db, "sqlmock") + store := NewStore(sqlxDB) + + mock.ExpectQuery(`SELECT .* FROM provider_profiles`). 
+ WillReturnError(sql.ErrConnDone) + + profiles, err := store.GetAllProfiles() + assert.Error(t, err) + assert.Nil(t, profiles) + assert.Contains(t, err.Error(), "failed to query all profiles") + + assert.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUpdateHealthStatus_Success(t *testing.T) { + db, mock, err := sqlmock.New() + assert.NoError(t, err) + defer db.Close() + + sqlxDB := sqlx.NewDb(db, "sqlmock") + store := NewStore(sqlxDB) + + providerID := uuid.New() + msg := "token_url 503" + + // Verify the UPDATE is called with (status, message, id) in correct order + mock.ExpectExec(`UPDATE provider_profiles SET health_status = \$1, health_message = \$2, last_health_check_at = NOW\(\) WHERE id = \$3`). + WithArgs("unhealthy", &msg, providerID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + err = store.UpdateHealthStatus(providerID, "unhealthy", &msg) + assert.NoError(t, err) + + assert.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUpdateHealthStatus_NilMessage(t *testing.T) { + db, mock, err := sqlmock.New() + assert.NoError(t, err) + defer db.Close() + + sqlxDB := sqlx.NewDb(db, "sqlmock") + store := NewStore(sqlxDB) + + providerID := uuid.New() + + mock.ExpectExec(`UPDATE provider_profiles SET health_status = \$1, health_message = \$2, last_health_check_at = NOW\(\) WHERE id = \$3`). + WithArgs("healthy", nil, providerID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + err = store.UpdateHealthStatus(providerID, "healthy", nil) + assert.NoError(t, err) + + assert.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUpdateHealthStatus_DBError(t *testing.T) { + db, mock, err := sqlmock.New() + assert.NoError(t, err) + defer db.Close() + + sqlxDB := sqlx.NewDb(db, "sqlmock") + store := NewStore(sqlxDB) + + providerID := uuid.New() + + mock.ExpectExec(`UPDATE provider_profiles SET health_status`). + WithArgs("unhealthy", nil, providerID). 
+ WillReturnError(sql.ErrConnDone) + + err = store.UpdateHealthStatus(providerID, "unhealthy", nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to update provider health status") + + assert.NoError(t, mock.ExpectationsWereMet()) +} + +func TestGetHealthStatus_Success(t *testing.T) { + db, mock, err := sqlmock.New() + assert.NoError(t, err) + defer db.Close() + + sqlxDB := sqlx.NewDb(db, "sqlmock") + store := NewStore(sqlxDB) + + providerID := uuid.New() + rows := sqlmock.NewRows([]string{"health_status"}).AddRow("unhealthy") + + mock.ExpectQuery(`SELECT COALESCE\(health_status, 'unknown'\) FROM provider_profiles WHERE id = \$1`). + WithArgs(providerID). + WillReturnRows(rows) + + status, err := store.GetHealthStatus(providerID) + assert.NoError(t, err) + assert.Equal(t, "unhealthy", status) + + assert.NoError(t, mock.ExpectationsWereMet()) +} + +func TestGetHealthStatus_NotFound(t *testing.T) { + db, mock, err := sqlmock.New() + assert.NoError(t, err) + defer db.Close() + + sqlxDB := sqlx.NewDb(db, "sqlmock") + store := NewStore(sqlxDB) + + providerID := uuid.New() + + mock.ExpectQuery(`SELECT COALESCE\(health_status, 'unknown'\) FROM provider_profiles WHERE id = \$1`). + WithArgs(providerID). + WillReturnError(sql.ErrNoRows) + + status, err := store.GetHealthStatus(providerID) + assert.Error(t, err) + assert.Equal(t, "", status) + assert.Contains(t, err.Error(), "failed to get provider health status") + + assert.NoError(t, mock.ExpectationsWereMet()) +} + +func TestGetAllHealthStatuses_Success(t *testing.T) { + db, mock, err := sqlmock.New() + assert.NoError(t, err) + defer db.Close() + + sqlxDB := sqlx.NewDb(db, "sqlmock") + store := NewStore(sqlxDB) + + id1 := uuid.New() + id2 := uuid.New() + now := time.Now() + msg := "503 from token_url" + + rows := sqlmock.NewRows([]string{"id", "name", "health_status", "last_health_check_at", "health_message"}). + AddRow(id1.String(), "google", "healthy", now, nil). 
+ AddRow(id2.String(), "stripe", "unhealthy", now, &msg) + + mock.ExpectQuery(`SELECT id, name, COALESCE\(health_status, 'unknown'\), last_health_check_at, health_message FROM provider_profiles`). + WillReturnRows(rows) + + summaries, err := store.GetAllHealthStatuses() + assert.NoError(t, err) + assert.Len(t, summaries, 2) + + assert.Equal(t, "google", summaries[0].Name) + assert.Equal(t, "healthy", summaries[0].HealthStatus) + assert.Nil(t, summaries[0].HealthMessage) + + assert.Equal(t, "stripe", summaries[1].Name) + assert.Equal(t, "unhealthy", summaries[1].HealthStatus) + assert.Equal(t, "503 from token_url", *summaries[1].HealthMessage) + + assert.NoError(t, mock.ExpectationsWereMet()) +} + +func TestGetAllHealthStatuses_Empty(t *testing.T) { + db, mock, err := sqlmock.New() + assert.NoError(t, err) + defer db.Close() + + sqlxDB := sqlx.NewDb(db, "sqlmock") + store := NewStore(sqlxDB) + + rows := sqlmock.NewRows([]string{"id", "name", "health_status", "last_health_check_at", "health_message"}) + mock.ExpectQuery(`SELECT id, name, COALESCE\(health_status, 'unknown'\), last_health_check_at, health_message FROM provider_profiles`). 
+ WillReturnRows(rows) + + summaries, err := store.GetAllHealthStatuses() + assert.NoError(t, err) + assert.NotNil(t, summaries) // Should return [] not nil + assert.Len(t, summaries, 0) + + assert.NoError(t, mock.ExpectationsWereMet()) +} diff --git a/nexus-broker/pkg/storage/pg.go b/nexus-broker/pkg/storage/pg.go index b09275d..48d0de4 100644 --- a/nexus-broker/pkg/storage/pg.go +++ b/nexus-broker/pkg/storage/pg.go @@ -42,6 +42,9 @@ type ProviderProfile struct { Params *json.RawMessage `db:"params" json:"params,omitempty"` CreatedAt time.Time `db:"created_at" json:"created_at"` UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + LastHealthCheckAt *time.Time `db:"last_health_check_at" json:"last_health_check_at,omitempty"` + HealthStatus string `db:"health_status" json:"health_status"` + HealthMessage *string `db:"health_message" json:"health_message,omitempty"` } // Connection represents an OAuth connection flow @@ -53,9 +56,11 @@ type Connection struct { CodeVerifier *string `db:"code_verifier" json:"code_verifier,omitempty"` Scopes []string `db:"scopes" json:"scopes"` ReturnURL string `db:"return_url" json:"return_url"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - ExpiresAt time.Time `db:"expires_at" json:"expires_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + ExpiresAt time.Time `db:"expires_at" json:"expires_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + LastHealthCheckAt *time.Time `db:"last_health_check_at" json:"last_health_check_at,omitempty"` + HealthStatus string `db:"health_status" json:"health_status"` } // Token represents encrypted OAuth tokens