diff --git a/.env.example b/.env.example index 3650284f38..f907792d13 100644 --- a/.env.example +++ b/.env.example @@ -1071,7 +1071,6 @@ OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4317 # OAuth Configuration # OAUTH_REQUEST_TIMEOUT=30 # OAUTH_MAX_RETRIES=3 -# OAUTH_DEFAULT_TIMEOUT=3600 # OAuth Security Settings # When MCP servers require OAuth authorization code flow, diff --git a/.secrets.baseline b/.secrets.baseline index 48e3add9d3..fc262d9d1b 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -3,7 +3,7 @@ "files": "(?x)( package-lock\\.json$ |Cargo\\.lock$ |uv\\.lock$ |go\\.sum$ |mcpgateway/sri_hashes\\.json$ )|^.secrets.baseline$", "lines": null }, - "generated_at": "2026-04-26T13:45:17Z", + "generated_at": "2026-04-26T16:16:40Z", "plugins_used": [ { "name": "AWSKeyDetector" @@ -132,7 +132,7 @@ "hashed_secret": "ac371b6dcce28a86c90d12bc57d946a800eebf17", "is_secret": false, "is_verified": false, - "line_number": 1149, + "line_number": 1148, "type": "Secret Keyword", "verified_result": null }, @@ -140,7 +140,7 @@ "hashed_secret": "0b6ec68df700dec4dcd64babd0eda1edccddace1", "is_secret": false, "is_verified": false, - "line_number": 1154, + "line_number": 1153, "type": "Secret Keyword", "verified_result": null }, @@ -148,7 +148,7 @@ "hashed_secret": "4ad6f0082ee224001beb3ca5c3e81c8ceea5ed86", "is_secret": false, "is_verified": false, - "line_number": 1159, + "line_number": 1158, "type": "Secret Keyword", "verified_result": null }, @@ -156,7 +156,7 @@ "hashed_secret": "cb32747fcfb55eaa194c8cd8e4ba7d49ada08a94", "is_secret": false, "is_verified": false, - "line_number": 1165, + "line_number": 1164, "type": "Secret Keyword", "verified_result": null }, @@ -164,7 +164,7 @@ "hashed_secret": "6c178d51b13520496dbc767ed3d9d7aa5803ac72", "is_secret": false, "is_verified": false, - "line_number": 1177, + "line_number": 1176, "type": "Secret Keyword", "verified_result": null }, @@ -172,7 +172,7 @@ "hashed_secret": "ca45060a53fd8a255d1a83ee8d2f025283ccc66e", "is_secret": false, "is_verified": false, - "line_number": 1195, + "line_number": 1194, "type": "Secret Keyword", "verified_result": null }, @@ -180,7 +180,7 @@ "hashed_secret": "910fbf00f58e9bcb095ea26a75cc1d9a3355e671", "is_secret": false, "is_verified": false, - "line_number": 1256, + "line_number": 1255, "type": "Secret Keyword", "verified_result": null } @@ -546,7 +546,7 @@ "hashed_secret": "bc1a7c4dc707a3b61e2e9a345eec9e23674efa11", "is_secret": false, "is_verified": false, - "line_number": 1067, + "line_number": 1066, "type": "Secret Keyword", "verified_result": null }, @@ -554,7 +554,7 @@ "hashed_secret": "2df14e4719f299249cd9a97cf68cc87232a27cbb", "is_secret": false, "is_verified": false, - "line_number": 1374, + "line_number": 1373, "type": "Hex High Entropy String", "verified_result": null } @@ -618,7 +618,7 @@ "hashed_secret": "25ab86bed149ca6ca9c1c0d5db7c9a91388ddeab", "is_secret": false, "is_verified": false, - "line_number": 969, + "line_number": 968, "type": "Basic Auth Credentials", "verified_result": null }, @@ -626,7 +626,7 @@ "hashed_secret": "d08f88df745fa7950b104e4a707a31cfce7b5841", "is_secret": false, "is_verified": false, - "line_number": 1072, + "line_number": 1071, "type": "Secret Keyword", "verified_result": null }, @@ -634,7 +634,7 @@ "hashed_secret": "7288edd0fc3ffcbe93a0cf06e3568e28521687bc", "is_secret": false, "is_verified": false, - "line_number": 1075, + "line_number": 1074, "type": "Secret Keyword", "verified_result": null }, @@ -642,7 +642,7 @@ "hashed_secret": "8674c9b302d20800e4ab3808f139704d8641a6e3", "is_secret": false, "is_verified": false, - "line_number": 1241, + "line_number": 1240, "type": "Secret Keyword", "verified_result": null }, @@ -650,7 +650,7 @@ "hashed_secret": "cff0d14e4337fa8bdb68dfa906f04b0df6fad72f", "is_secret": false, "is_verified": false, - "line_number": 1280, + "line_number": 1279, "type": "Secret Keyword", "verified_result": null }, @@ -658,7 +658,7 @@ "hashed_secret": "f865b53623b121fd34ee5426c792e5c33af8c227", "is_secret": false, "is_verified": false, - "line_number": 1329, + "line_number": 1328, "type": "Secret Keyword", "verified_result": null }, @@ -666,7 +666,7 @@ "hashed_secret": "acde39840735314af1300688b6c2324ea89770a3", "is_secret": false, "is_verified": false, - "line_number": 1424, + "line_number": 1423, "type": "Secret Keyword", "verified_result": null }, @@ -674,7 +674,7 @@ "hashed_secret": "fa9beb99e4029ad5a6615399e7bbae21356086b3", "is_secret": false, "is_verified": false, - "line_number": 1795, + "line_number": 1794, "type": "Secret Keyword", "verified_result": null } @@ -2348,7 +2348,7 @@ "hashed_secret": "4a0a2df96d4c9a13a282268cab33ac4b8cbb2c72", "is_secret": false, "is_verified": false, - "line_number": 196, + "line_number": 195, "type": "Secret Keyword", "verified_result": null }, @@ -2356,7 +2356,7 @@ "hashed_secret": "bc2f74c22f98f7b6ffbc2f67453dbfa99bce9a32", "is_secret": false, "is_verified": false, - "line_number": 477, + "line_number": 476, "type": "Secret Keyword", "verified_result": null }, @@ -2364,7 +2364,7 @@ "hashed_secret": "a2166e0b243f4e192e14000a6aa8f46cde6bfc20", "is_secret": false, "is_verified": false, - "line_number": 843, + "line_number": 842, "type": "Secret Keyword", "verified_result": null }, @@ -2372,7 +2372,7 @@ "hashed_secret": "fa9beb99e4029ad5a6615399e7bbae21356086b3", "is_secret": false, "is_verified": false, - "line_number": 1260, + "line_number": 1259, "type": "Basic Auth Credentials", "verified_result": null }, @@ -2380,7 +2380,7 @@ "hashed_secret": "fa9beb99e4029ad5a6615399e7bbae21356086b3", "is_secret": false, "is_verified": false, - "line_number": 1264, + "line_number": 1263, "type": "Secret Keyword", "verified_result": null } @@ -2594,7 +2594,7 @@ "hashed_secret": "29b8dca3de5ff27bcf8bd3b622adf9970f29381c", "is_secret": false, "is_verified": false, - "line_number": 469, + "line_number": 468, "type": "Secret Keyword", "verified_result": null }, @@ -2602,7 +2602,7 @@ "hashed_secret": "d62bd0a3fef6be1bb3bb9078b323d87ac9f37f75", "is_secret": false, "is_verified": false, - "line_number": 572, + "line_number": 570, "type": "Secret Keyword", "verified_result": null } @@ -2612,7 +2612,7 @@ "hashed_secret": "e42972502db194ed2d6521cd5acf594dc1e1c503", "is_secret": false, "is_verified": false, - "line_number": 61, + "line_number": 60, "type": "Secret Keyword", "verified_result": null }, @@ -2620,7 +2620,7 @@ "hashed_secret": "cfac92b56e517a2186d60c9f812878a82a8c210c", "is_secret": false, "is_verified": false, - "line_number": 104, + "line_number": 103, "type": "Secret Keyword", "verified_result": null }, @@ -2628,7 +2628,7 @@ "hashed_secret": "2736fab291f04e69b62d490c3c09361f5b82461a", "is_secret": false, "is_verified": false, - "line_number": 135, + "line_number": 134, "type": "Secret Keyword", "verified_result": null }, @@ -2636,7 +2636,7 @@ "hashed_secret": "9b466094ec991a03cb95c489c19c4d75635f0ae5", "is_secret": false, "is_verified": false, - "line_number": 137, + "line_number": 136, "type": "Secret Keyword", "verified_result": null } @@ -4862,7 +4862,7 @@ "hashed_secret": "ff37a98a9963d347e9749a5c1b3936a4a245a6ff", "is_secret": false, "is_verified": false, - "line_number": 2414, + "line_number": 2413, "type": "Secret Keyword", "verified_result": null } @@ -8296,7 +8296,7 @@ "hashed_secret": "34e587c8f9ba011db386d719d66ffe3cfaea5447", "is_secret": false, "is_verified": false, - "line_number": 372, + "line_number": 421, "type": "Secret Keyword", "verified_result": null }, @@ -8304,7 +8304,7 @@ "hashed_secret": "a0f4ea7d91495df92bbac2e2149dfb850fe81396", "is_secret": false, "is_verified": false, - "line_number": 391, + "line_number": 440, "type": "Secret Keyword", "verified_result": null }, @@ -8312,7 +8312,7 @@ "hashed_secret": "920a25ef686c4f7ca6ad23dd109d3ad653161832", "is_secret": false, "is_verified": false, - "line_number": 640, + "line_number": 689, "type": "Secret Keyword", "verified_result": null }, @@ -8320,7 +8320,7 @@ "hashed_secret": "a62f2225bf70bfaccbc7f1ef2a397836717377de", "is_secret": false, "is_verified": false, - "line_number": 716, + "line_number": 765, "type": "Secret Keyword", "verified_result": null }, @@ -8328,7 +8328,7 @@ "hashed_secret": "355e7ab792a8403301eb0732bab9d2b3950ac048", "is_secret": false, "is_verified": false, - "line_number": 719, + "line_number": 768, "type": "Secret Keyword", "verified_result": null } @@ -8338,7 +8338,7 @@ "hashed_secret": "e582429052cdb908833560f6f7582d232de37c4d", "is_secret": false, "is_verified": false, - "line_number": 58, + "line_number": 62, "type": "Base64 High Entropy String", "verified_result": null }, @@ -8346,7 +8346,7 @@ "hashed_secret": "fe1bae27cb7c1fb823f496f286e78f1d2ae87734", "is_secret": false, "is_verified": false, - "line_number": 352, + "line_number": 366, "type": "Secret Keyword", "verified_result": null }, @@ -8354,7 +8354,7 @@ "hashed_secret": "72cb70dbbafe97e5ea13ad88acd65d08389439b0", "is_secret": false, "is_verified": false, - "line_number": 475, + "line_number": 489, "type": "Secret Keyword", "verified_result": null } @@ -8542,7 +8542,7 @@ "hashed_secret": "920a25ef686c4f7ca6ad23dd109d3ad653161832", "is_secret": false, "is_verified": false, - "line_number": 292, + "line_number": 331, "type": "Secret Keyword", "verified_result": null }, @@ -8550,7 +8550,7 @@ "hashed_secret": "48004e013423b89217e65eca07df9574fcd092a6", "is_secret": false, "is_verified": false, - "line_number": 359, + "line_number": 441, "type": "Secret Keyword", "verified_result": null } @@ -9298,7 +9298,7 @@ "hashed_secret": "00942f4668670f34c5943cf52c7ef3139fe2b8d6", "is_secret": false, "is_verified": false, - "line_number": 1269, + "line_number": 1313, "type": "Secret Keyword", "verified_result": null }, @@ -9306,7 +9306,7 @@ "hashed_secret": "233243ef95e736679cb1d5664a4c71ba89c10664", "is_secret": false, "is_verified": false, - "line_number": 1461, + "line_number": 1505, "type": "Secret Keyword", "verified_result": null }, @@ -9314,7 +9314,7 @@ "hashed_secret": "72cb70dbbafe97e5ea13ad88acd65d08389439b0", "is_secret": false, "is_verified": false, - "line_number": 1929, + "line_number": 1973, "type": "Secret Keyword", "verified_result": null } diff --git a/charts/mcp-stack/README.md b/charts/mcp-stack/README.md index 06bff37333..1c567fce82 100644 --- a/charts/mcp-stack/README.md +++ b/charts/mcp-stack/README.md @@ -814,7 +814,6 @@ For detailed guidance on resource limits and process management, see `docs/docs/ | mcpContextForge.secret.PROXY_USER_HEADER | string | `"X-Authenticated-User"` | | | mcpContextForge.secret.OAUTH_REQUEST_TIMEOUT | string | `"30"` | | | mcpContextForge.secret.OAUTH_MAX_RETRIES | string | `"3"` | | -| mcpContextForge.secret.OAUTH_DEFAULT_TIMEOUT | string | `"3600"` | | | mcpContextForge.secret.DCR_ENABLED | string | `"true"` | | | mcpContextForge.secret.DCR_AUTO_REGISTER_ON_MISSING_CREDENTIALS | string | `"true"` | | | mcpContextForge.secret.DCR_DEFAULT_SCOPES | string | `"[\"mcp:read\"]"` | | diff --git a/charts/mcp-stack/values.schema.json b/charts/mcp-stack/values.schema.json index d31203b7e6..c4a38c4bca 100644 --- a/charts/mcp-stack/values.schema.json +++ b/charts/mcp-stack/values.schema.json @@ -2004,9 +2004,6 @@ "MIN_SECRET_LENGTH": { "type": "string" }, - "OAUTH_DEFAULT_TIMEOUT": { - "type": "string" - }, "OAUTH_DISCOVERY_ENABLED": { "type": "string" }, diff --git a/charts/mcp-stack/values.yaml b/charts/mcp-stack/values.yaml index 495de8f1db..aa0677ca0b 100644 --- a/charts/mcp-stack/values.yaml +++ b/charts/mcp-stack/values.yaml @@ -880,7 +880,6 @@ mcpContextForge: # ─ OAuth Configuration ─ OAUTH_REQUEST_TIMEOUT: "30" # OAuth request timeout in seconds OAUTH_MAX_RETRIES: "3" # maximum retries for OAuth token requests - OAUTH_DEFAULT_TIMEOUT: "3600" # default OAuth token timeout in seconds # ─ OAuth Dynamic Client Registration (DCR) & PKCE ─ DCR_ENABLED: "true" # enable Dynamic Client Registration (RFC 7591) diff --git a/docs/docs/architecture/oauth-authorization-code-ui-design.md b/docs/docs/architecture/oauth-authorization-code-ui-design.md index 3a4336ce25..8101854faa 100644 --- a/docs/docs/architecture/oauth-authorization-code-ui-design.md +++ b/docs/docs/architecture/oauth-authorization-code-ui-design.md @@ -233,7 +233,6 @@ For user authentication and RBAC configuration, see [RBAC Configuration](../mana | --- | --- | --- | | `OAUTH_REQUEST_TIMEOUT` | `30` | Timeout for OAuth HTTP requests | | `OAUTH_MAX_RETRIES` | `3` | Retry count for token exchanges | -| `OAUTH_DEFAULT_TIMEOUT` | `3600` | Default `expires_in` when provider omits it | | `AUTH_ENCRYPTION_SECRET` | `my-test-salt` | Encrypts OAuth tokens and signs state | | `CACHE_TYPE` | `database` | `redis`, `database`, `memory`, or `none` | | `REDIS_URL` | `redis://localhost:6379` | Required when `CACHE_TYPE=redis` | diff --git a/docs/docs/manage/configuration.md b/docs/docs/manage/configuration.md index dfccfe3e4e..2fdcd9399f 100644 --- a/docs/docs/manage/configuration.md +++ b/docs/docs/manage/configuration.md @@ -182,7 +182,6 @@ ContextForge supports multiple database backends with full feature parity across | `AUTH_ENCRYPTION_SECRET` | Passphrase used to derive AES key for encrypting tool auth headers | `my-test-salt` | string | | `OAUTH_REQUEST_TIMEOUT` | OAuth request timeout in seconds | `30` | int > 0 | | `OAUTH_MAX_RETRIES` | Maximum retries for OAuth token requests | `3` | int > 0 | -| `OAUTH_DEFAULT_TIMEOUT` | Default OAuth token timeout in seconds | `3600` | int > 0 | | `INSECURE_ALLOW_QUERYPARAM_AUTH` | Enable query parameter authentication for gateways (see security warning) | `false` | bool | | `INSECURE_QUERYPARAM_AUTH_ALLOWED_HOSTS` | JSON array of hosts allowed to use query param auth | `[]` | JSON array | diff --git a/docs/docs/manage/oauth-troubleshooting.md b/docs/docs/manage/oauth-troubleshooting.md index 2553652f33..4bc9e7d4b2 100644 --- a/docs/docs/manage/oauth-troubleshooting.md +++ b/docs/docs/manage/oauth-troubleshooting.md @@ -459,7 +459,6 @@ REDIS_RETRY_INTERVAL_MS=2000 # OAuth settings OAUTH_REQUEST_TIMEOUT=30 OAUTH_MAX_RETRIES=3 -OAUTH_DEFAULT_TIMEOUT=3600 # PKCE settings OAUTH_DISCOVERY_ENABLED=true @@ -509,7 +508,6 @@ DCR_METADATA_CACHE_TTL=3600 | `AUTH_ENCRYPTION_SECRET` | `my-test-salt` | Secret for HMAC signing states (change in production!) | | `OAUTH_REQUEST_TIMEOUT` | `30` | Timeout for OAuth requests (seconds) | | `OAUTH_MAX_RETRIES` | `3` | Max retries for token requests | -| `OAUTH_DEFAULT_TIMEOUT` | `3600` | Default token expiration (seconds) | | `LOG_LEVEL` | `INFO` | Set to `DEBUG` for troubleshooting | --- diff --git a/docs/docs/manage/oauth.md b/docs/docs/manage/oauth.md index 54b837b6e6..de9a4703cf 100644 --- a/docs/docs/manage/oauth.md +++ b/docs/docs/manage/oauth.md @@ -55,7 +55,6 @@ See the flow details and security model in the architecture docs. # OAuth HTTP behavior OAUTH_REQUEST_TIMEOUT=30 # Seconds OAUTH_MAX_RETRIES=3 # Retries for transient failures -OAUTH_DEFAULT_TIMEOUT=3600 # Default OAuth token timeout in seconds # Secret encryption for stored OAuth client secrets (and tokens if enabled) AUTH_ENCRYPTION_SECRET= diff --git a/mcpgateway/config.py b/mcpgateway/config.py index 7ab6ec1ae4..6b0be3e20a 100644 --- a/mcpgateway/config.py +++ b/mcpgateway/config.py @@ -644,7 +644,6 @@ class Settings(BaseSettings): # OAuth Configuration oauth_request_timeout: int = Field(default=30, description="OAuth request timeout in seconds") oauth_max_retries: int = Field(default=3, description="Maximum retries for OAuth token requests") - oauth_default_timeout: int = Field(default=3600, description="Default OAuth token timeout in seconds") # =================================== # Dynamic Client Registration (DCR) - Client Mode diff --git a/mcpgateway/services/oauth_manager.py b/mcpgateway/services/oauth_manager.py index 91f790ec9a..84dca18b7c 100644 --- a/mcpgateway/services/oauth_manager.py +++ b/mcpgateway/services/oauth_manager.py @@ -785,7 +785,7 @@ async def complete_authorization_code_flow( app_user_email=app_user_email, # User from state access_token=token_response["access_token"], refresh_token=token_response.get("refresh_token"), - expires_in=token_response.get("expires_in", self.settings.oauth_default_timeout), + expires_in=parse_expires_in(token_response), scopes=token_response.get("scope", "").split(), ) @@ -1660,3 +1660,68 @@ def __init__(self, message: str, *, server_id: str = "") -> None: """ super().__init__(message) self.server_id = server_id + + +def parse_expires_in(token_response: Dict[str, Any]) -> Optional[int]: + """Parse and validate the ``expires_in`` field from an OAuth token response. + + RFC 6749 §5.1 marks ``expires_in`` as RECOMMENDED (not REQUIRED). When the + field is absent or null, the gateway records ``expires_at`` as ``None`` and + the token is treated as having no known local expiry, subject to the + stale-token cleanup policy in + :meth:`mcpgateway.services.token_storage_service.TokenStorageService.cleanup_expired_tokens`. + + Args: + token_response: Raw OAuth token response dict from the provider. + + Returns: + ``int`` lifetime in seconds when the provider supplied a non-negative + integer (or numeric string convertible to one), or ``None`` when the + field is absent or explicitly null. + + Raises: + OAuthError: If ``expires_in`` is present but malformed (negative, + non-numeric, or a non-scalar type). + + Examples: + >>> parse_expires_in({"expires_in": 3600}) + 3600 + >>> parse_expires_in({"expires_in": "3600"}) + 3600 + >>> parse_expires_in({"expires_in": 0}) + 0 + >>> parse_expires_in({}) is None + True + >>> parse_expires_in({"expires_in": None}) is None + True + >>> try: + ... parse_expires_in({"expires_in": -1}) + ... except OAuthError as exc: + ... "negative" in str(exc) + True + >>> try: + ... parse_expires_in({"expires_in": "garbage"}) + ... except OAuthError as exc: + ... "Invalid expires_in" in str(exc) + True + """ + raw = token_response.get("expires_in") + if raw is None: + return None + # Reject bools (True/False are int subclasses in Python) and any non-scalar types. + if isinstance(raw, bool) or not isinstance(raw, (int, float, str)): + raise OAuthError(f"Invalid expires_in from OAuth provider: {raw!r}") + # Sign-check the original numeric BEFORE int() truncation, otherwise int(-0.5) == 0 + # would bypass the negative check. + if isinstance(raw, (int, float)) and raw < 0: + raise OAuthError(f"Invalid expires_in from OAuth provider (negative): {raw}") + # RFC 6749 §5.1 specifies integer seconds; reject non-integral floats explicitly. + if isinstance(raw, float) and not raw.is_integer(): + raise OAuthError(f"Invalid expires_in from OAuth provider (non-integer): {raw}") + try: + value = int(raw) + except (TypeError, ValueError) as exc: + raise OAuthError(f"Invalid expires_in from OAuth provider: {raw!r}") from exc + if value < 0: + raise OAuthError(f"Invalid expires_in from OAuth provider (negative): {value}") + return value diff --git a/mcpgateway/services/token_storage_service.py b/mcpgateway/services/token_storage_service.py index d4ad2a6b03..2efff3ab9d 100644 --- a/mcpgateway/services/token_storage_service.py +++ b/mcpgateway/services/token_storage_service.py @@ -16,7 +16,7 @@ from typing import Any, Dict, List, Optional # Third-Party -from sqlalchemy import delete, select +from sqlalchemy import and_, delete, or_, select from sqlalchemy.orm import Session # First-Party @@ -29,6 +29,48 @@ logger = logging.getLogger(__name__) +def _preserve_prior_ttl(token_record: OAuthToken) -> Optional[int]: + """Compute the token's prior TTL in seconds, or ``None`` if not derivable. + + Used when an OAuth refresh response omits ``expires_in`` but the token + previously had a finite lifetime - the gateway preserves the original + issuance TTL by computing ``expires_at - updated_at`` from the existing + record. Returns ``None`` when either timestamp is missing or the difference + is non-positive (clock skew or already-expired records). + + Args: + token_record: Existing OAuth token row, before the refresh applies. + + Returns: + Positive integer seconds of prior TTL, or ``None``. + + Examples: + >>> from types import SimpleNamespace + >>> from datetime import datetime, timedelta, timezone + >>> issued = datetime(2026, 1, 1, tzinfo=timezone.utc) + >>> rec = SimpleNamespace(expires_at=issued + timedelta(hours=1), updated_at=issued) + >>> _preserve_prior_ttl(rec) + 3600 + >>> _preserve_prior_ttl(SimpleNamespace(expires_at=None, updated_at=issued)) is None + True + >>> _preserve_prior_ttl(SimpleNamespace(expires_at=issued, updated_at=issued + timedelta(hours=1))) is None + True + """ + prev_expires_at = token_record.expires_at + prev_updated_at = token_record.updated_at + if prev_expires_at is None or prev_updated_at is None: + return None + # Normalize naive timestamps to UTC for the subtraction. + if prev_expires_at.tzinfo is None: + prev_expires_at = prev_expires_at.replace(tzinfo=timezone.utc) + if prev_updated_at.tzinfo is None: + prev_updated_at = prev_updated_at.replace(tzinfo=timezone.utc) + prev_ttl = int((prev_expires_at - prev_updated_at).total_seconds()) + if prev_ttl <= 0: + return None + return prev_ttl + + class TokenStorageService: """Manages OAuth token storage and retrieval. @@ -74,7 +116,7 @@ def __init__(self, db: Session): logger.warning("OAuth encryption not available, using plain text storage") self.encryption = None - async def store_tokens(self, gateway_id: str, user_id: str, app_user_email: str, access_token: str, refresh_token: Optional[str], expires_in: int, scopes: List[str]) -> OAuthToken: + async def store_tokens(self, gateway_id: str, user_id: str, app_user_email: str, access_token: str, refresh_token: Optional[str], expires_in: Optional[int], scopes: List[str]) -> OAuthToken: """Store OAuth tokens for a gateway-user combination. Args: @@ -83,7 +125,7 @@ async def store_tokens(self, gateway_id: str, user_id: str, app_user_email: str, app_user_email: ContextForge user email (required) access_token: Access token from OAuth provider refresh_token: Refresh token from OAuth provider (optional) - expires_in: Token expiration time in seconds + expires_in: Token expiration time in seconds, or None if the provider does not specify expiration scopes: List of OAuth scopes granted Returns: @@ -102,8 +144,15 @@ async def store_tokens(self, gateway_id: str, user_id: str, app_user_email: str, if refresh_token: encrypted_refresh = await self.encryption.encrypt_secret_async(refresh_token) - # Calculate expiration - expires_at = datetime.now(timezone.utc) + timedelta(seconds=int(expires_in)) + # Calculate expiration (None if provider does not specify expires_in) + if expires_in is not None: + expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in) + else: + logger.info( + "No expires_in from OAuth provider for gateway %s; token will not auto-expire", + SecurityValidator.sanitize_log_message(gateway_id), + ) + expires_at = None # Create or update token record - now scoped by app_user_email token_record = self.db.execute(select(OAuthToken).where(OAuthToken.gateway_id == gateway_id, OAuthToken.app_user_email == app_user_email)).scalar_one_or_none() @@ -293,7 +342,7 @@ def normalize_resource(url: str, *, preserve_query: bool = False) -> str | None: # Use OAuthManager to refresh the token # First-Party - from mcpgateway.services.oauth_manager import OAuthManager # pylint: disable=import-outside-toplevel + from mcpgateway.services.oauth_manager import OAuthManager, parse_expires_in # pylint: disable=import-outside-toplevel oauth_manager = OAuthManager() @@ -309,7 +358,9 @@ def normalize_resource(url: str, *, preserve_query: bool = False) -> str | None: # Update stored tokens with new values new_access_token = token_response["access_token"] new_refresh_token = token_response.get("refresh_token", refresh_token) # Some providers return new refresh token - expires_in = token_response.get("expires_in", 3600) + # Reuse the same parsing as the initial-auth path so refresh and + # callback flows agree on what "missing expires_in" means. + expires_in = parse_expires_in(token_response) # Encrypt new tokens if encryption is available encrypted_access = new_access_token @@ -321,8 +372,30 @@ def normalize_resource(url: str, *, preserve_query: bool = False) -> str | None: # Update the token record token_record.access_token = encrypted_access token_record.refresh_token = encrypted_refresh - token_record.expires_at = datetime.now(timezone.utc) + timedelta(seconds=int(expires_in)) - token_record.updated_at = datetime.now(timezone.utc) + now = datetime.now(timezone.utc) + if expires_in is not None: + token_record.expires_at = now + timedelta(seconds=expires_in) + else: + # Refresh response omitted expires_in. If the token previously had a finite + # expiry, preserve the prior TTL (expires_at - updated_at) so proactive + # refresh keeps working - clearing it outright would cause _is_token_expired + # to return False forever and stop the refresh loop. If there was no prior + # expiry, leave it as None (provider-level "no known lifetime"). + preserved_ttl = _preserve_prior_ttl(token_record) + if preserved_ttl is not None: + logger.info( + "No expires_in on refresh response for gateway %s; preserving prior TTL of %d seconds", + SecurityValidator.sanitize_log_message(token_record.gateway_id), + preserved_ttl, + ) + token_record.expires_at = now + timedelta(seconds=preserved_ttl) + else: + logger.info( + "No expires_in on refresh response for gateway %s; no prior TTL to preserve", + SecurityValidator.sanitize_log_message(token_record.gateway_id), + ) + token_record.expires_at = None + token_record.updated_at = now self.db.commit() logger.info(f"Successfully refreshed token for gateway {token_record.gateway_id}, user {token_record.app_user_email}") @@ -341,6 +414,14 @@ def normalize_resource(url: str, *, preserve_query: bool = False) -> str | None: def _is_token_expired(self, token_record: OAuthToken, threshold_seconds: int = 300) -> bool: """Check if token is expired or near expiration. + Tokens with ``expires_at IS NULL`` are returned as non-expired by + design: when the OAuth provider omits ``expires_in`` (RFC 6749 §5.1 + marks it RECOMMENDED, not REQUIRED — see e.g. GitHub OAuth Apps), + the gateway has no local lifetime to check against. Stale-token + accumulation is bounded by + :meth:`cleanup_expired_tokens`, which ages out NULL-expiry rows + once ``created_at`` exceeds ``max_age_days``. + Args: token_record: OAuth token record to check threshold_seconds: Seconds before expiry to consider token expired @@ -366,6 +447,7 @@ def _is_token_expired(self, token_record: OAuthToken, threshold_seconds: int = 3 False """ if not token_record.expires_at: + # No provider-supplied lifetime; treat as non-expired (see contract above). return False expires_at = token_record.expires_at if expires_at.tzinfo is None: @@ -467,14 +549,26 @@ async def revoke_user_tokens(self, gateway_id: str, app_user_email: str) -> bool return False async def cleanup_expired_tokens(self, max_age_days: int = 30) -> int: - """Clean up expired OAuth tokens older than specified days. - - Uses a single SQL DELETE statement instead of loading tokens into memory - and deleting them one by one. This is more efficient and avoids memory - issues when many tokens expire at once. + """Clean up stale OAuth tokens older than ``max_age_days``. + + Two cohorts are deleted in a single SQL ``DELETE`` so the table doesn't + accumulate dead rows: + + 1. Tokens whose ``expires_at`` is older than the cutoff (the original + "expired more than N days ago" behaviour). + 2. Tokens with ``expires_at IS NULL`` (provider omitted ``expires_in``) + whose ``updated_at`` is older than the cutoff. ``NULL < `` + evaluates to ``NULL`` in SQL three-valued logic, so without this + branch those rows would never age out. ``updated_at`` (rather than + ``created_at``) is the right freshness signal because + ``store_tokens`` advances it on re-authorization, so a recently + re-authorized token isn't deleted just because its original row was + old. Args: - max_age_days: Maximum age of tokens to keep + max_age_days: Maximum age of tokens to keep, measured from + ``expires_at`` for tokens with a known expiry and from + ``updated_at`` for tokens with no provider-supplied expiry. Returns: Number of tokens cleaned up @@ -491,17 +585,21 @@ async def cleanup_expired_tokens(self, max_age_days: int = 30) -> int: try: cutoff_date = datetime.now(tz=timezone.utc) - timedelta(days=max_age_days) - result = self.db.execute(delete(OAuthToken).where(OAuthToken.expires_at < cutoff_date)) + stale_filter = or_( + OAuthToken.expires_at < cutoff_date, + and_(OAuthToken.expires_at.is_(None), OAuthToken.updated_at < cutoff_date), + ) + result = self.db.execute(delete(OAuthToken).where(stale_filter)) count = result.rowcount self.db.commit() if count > 0: - logger.info(f"Cleaned up {count} expired OAuth tokens") + logger.info("Cleaned up %d stale OAuth tokens", count) return count except Exception as e: self.db.rollback() - logger.error(f"Failed to cleanup expired tokens: {str(e)}") + logger.error("Failed to cleanup expired tokens: %s", e) return 0 diff --git a/tests/unit/mcpgateway/services/test_oauth_manager.py b/tests/unit/mcpgateway/services/test_oauth_manager.py index f43f006bdc..dd6a9b0ab8 100644 --- a/tests/unit/mcpgateway/services/test_oauth_manager.py +++ b/tests/unit/mcpgateway/services/test_oauth_manager.py @@ -11,7 +11,7 @@ import pytest # First-Party -from mcpgateway.services.oauth_manager import OAuthError, OAuthManager +from mcpgateway.services.oauth_manager import OAuthError, OAuthManager, parse_expires_in @pytest.fixture @@ -21,7 +21,6 @@ def oauth_manager(): auth_encryption_secret=MagicMock(get_secret_value=MagicMock(return_value="test-secret")), cache_type="memory", redis_url=None, - oauth_default_timeout=3600, ) mgr = OAuthManager(request_timeout=10, max_retries=1) return mgr @@ -46,6 +45,56 @@ def test_init_custom(): assert mgr.token_storage == "store" +# ---------- parse_expires_in ---------- + + +@pytest.mark.parametrize( + "value,expected", + [ + (3600, 3600), + ("3600", 3600), + (3600.0, 3600), + (0, 0), + ], +) +def test_parse_expires_in_valid(value, expected): + assert parse_expires_in({"expires_in": value}) == expected + + +def test_parse_expires_in_missing_returns_none(): + assert parse_expires_in({}) is None + + +def test_parse_expires_in_explicit_null_returns_none(): + assert parse_expires_in({"expires_in": None}) is None + + +@pytest.mark.parametrize("bad_value", [-1, -3600, "-1", -0.5, -3600.7]) +def test_parse_expires_in_negative_raises(bad_value): + """Negative numerics raise even when int() would silently truncate (-0.5 -> 0).""" + with pytest.raises(OAuthError, match="negative"): + parse_expires_in({"expires_in": bad_value}) + + +@pytest.mark.parametrize("bad_value", ["garbage", "3600s", ""]) +def test_parse_expires_in_garbage_string_raises(bad_value): + with pytest.raises(OAuthError, match="Invalid expires_in"): + parse_expires_in({"expires_in": bad_value}) + + +@pytest.mark.parametrize("bad_value", [3600.5, 0.5, 3600.7]) +def test_parse_expires_in_non_integer_float_raises(bad_value): + """RFC 6749 §5.1 specifies integer seconds; non-integer floats are rejected.""" + with pytest.raises(OAuthError, match="non-integer"): + parse_expires_in({"expires_in": bad_value}) + + +@pytest.mark.parametrize("bad_value", [True, False, [3600], {"seconds": 3600}, object()]) +def test_parse_expires_in_non_scalar_raises(bad_value): + with pytest.raises(OAuthError, match="Invalid expires_in"): + parse_expires_in({"expires_in": bad_value}) + + # ---------- _generate_pkce_params ---------- @@ -1177,6 +1226,7 @@ def test_create_authorization_url_with_pkce_no_scopes(oauth_manager): @pytest.mark.asyncio async def test_resolve_gateway_id_from_state_uses_legacy_fallback(oauth_manager): + # First-Party import mcpgateway.services.oauth_manager as om with ( @@ -1193,6 +1243,7 @@ async def test_resolve_gateway_id_from_state_uses_legacy_fallback(oauth_manager) @pytest.mark.asyncio async def test_resolve_gateway_id_from_state_skips_legacy_fallback_when_disabled(oauth_manager): + # First-Party import mcpgateway.services.oauth_manager as om with ( @@ -1221,6 +1272,7 @@ def test_oauth_error(): @pytest.mark.asyncio async def test_get_redis_client_already_initialized(): + # First-Party import mcpgateway.services.oauth_manager as om original_init = om._REDIS_INITIALIZED @@ -1237,6 +1289,7 @@ async def test_get_redis_client_already_initialized(): @pytest.mark.asyncio async def test_get_redis_client_no_redis(): + # First-Party import mcpgateway.services.oauth_manager as om original_init = om._REDIS_INITIALIZED diff --git a/tests/unit/mcpgateway/services/test_oauth_manager_pkce.py b/tests/unit/mcpgateway/services/test_oauth_manager_pkce.py index d2ce16eec3..c24319f025 100644 --- a/tests/unit/mcpgateway/services/test_oauth_manager_pkce.py +++ b/tests/unit/mcpgateway/services/test_oauth_manager_pkce.py @@ -7,12 +7,16 @@ Tests will FAIL until implementation is complete. """ +# Standard from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, patch +# Third-Party import httpx import pytest -from mcpgateway.services.oauth_manager import OAuthManager, OAuthError + +# First-Party +from mcpgateway.services.oauth_manager import OAuthError, OAuthManager class TestPKCEGeneration: @@ -71,6 +75,7 @@ def test_generate_pkce_params_is_unique(self): def test_generate_pkce_params_challenge_is_sha256_of_verifier(self): """Test that code_challenge is SHA256 hash of code_verifier.""" + # Standard import base64 import hashlib @@ -178,9 +183,12 @@ async def test_validate_and_retrieve_state_returns_code_verifier(self, monkeypat state = "test-state" # Mock in-memory state storage - from mcpgateway.services.oauth_manager import _oauth_states, _state_lock + # Standard from datetime import datetime, timedelta, timezone + # First-Party + from mcpgateway.services.oauth_manager import _oauth_states, _state_lock + state_key = f"oauth:state:{gateway_id}:{state}" expires_at = datetime.now(timezone.utc) + timedelta(seconds=300) @@ -203,9 +211,12 @@ async def test_validate_and_retrieve_state_returns_none_if_expired(self): gateway_id = "test-gateway-123" state = "test-state" - from mcpgateway.services.oauth_manager import _oauth_states, _state_lock + # Standard from datetime import datetime, timedelta, timezone + # First-Party + from mcpgateway.services.oauth_manager import _oauth_states, _state_lock + state_key = f"oauth:state:{gateway_id}:{state}" expires_at = datetime.now(timezone.utc) - timedelta(seconds=60) # Expired @@ -224,9 +235,12 @@ async def test_validate_and_retrieve_state_single_use(self, monkeypatch): gateway_id = "test-gateway-123" state = "test-state" - from mcpgateway.services.oauth_manager import _oauth_states, _state_lock + # Standard from datetime import datetime, timedelta, timezone + # First-Party + from mcpgateway.services.oauth_manager import _oauth_states, _state_lock + state_key = f"oauth:state:{gateway_id}:{state}" expires_at = datetime.now(timezone.utc) + timedelta(seconds=300) @@ -517,6 +531,7 @@ def _make_response(*, status_code=200, headers=None, text="", json_data=None, js class TestOAuthManagerRedisClient: @pytest.mark.asyncio async def test_get_redis_client_cached(self, monkeypatch): + # First-Party import mcpgateway.services.oauth_manager as om monkeypatch.setattr(om, "_REDIS_INITIALIZED", False) @@ -541,6 +556,7 @@ async def test_get_redis_client_cached(self, monkeypatch): @pytest.mark.asyncio async def test_get_redis_client_error_falls_back(self, monkeypatch): + # First-Party import mcpgateway.services.oauth_manager as om monkeypatch.setattr(om, "_REDIS_INITIALIZED", False) @@ -766,6 +782,7 @@ async def test_store_authorization_state_redis(self, monkeypatch): @pytest.mark.asyncio async def test_store_authorization_state_redis_failure_falls_back(self, monkeypatch): + # First-Party import mcpgateway.services.oauth_manager as om manager = OAuthManager() @@ -781,6 +798,7 @@ async def test_store_authorization_state_redis_failure_falls_back(self, monkeypa @pytest.mark.asyncio async def test_store_authorization_state_memory_cleanup(self, monkeypatch): + # First-Party import mcpgateway.services.oauth_manager as om manager = OAuthManager() @@ -804,9 +822,7 @@ async def test_store_authorization_state_memory_cleanup(self, monkeypatch): async def test_validate_authorization_state_redis(self, monkeypatch): manager = OAuthManager() redis = AsyncMock() - redis.getdel = AsyncMock( - return_value=b'{"state":"s","gateway_id":"gw","code_verifier":"v","expires_at":"2099-01-01T00:00:00","used":false}' - ) + redis.getdel = AsyncMock(return_value=b'{"state":"s","gateway_id":"gw","code_verifier":"v","expires_at":"2099-01-01T00:00:00","used":false}') monkeypatch.setattr("mcpgateway.services.oauth_manager.get_settings", lambda: SimpleNamespace(cache_type="redis", redis_url="redis://localhost")) monkeypatch.setattr("mcpgateway.services.oauth_manager._get_redis_client", AsyncMock(return_value=redis)) @@ -821,13 +837,12 @@ async def test_validate_authorization_state_redis_missing_and_used(self, monkeyp monkeypatch.setattr("mcpgateway.services.oauth_manager._get_redis_client", AsyncMock(return_value=redis)) assert await manager._validate_authorization_state("gw", "missing") is False - redis.getdel = AsyncMock( - return_value=b'{"state":"s","gateway_id":"gw","code_verifier":"v","expires_at":"2099-01-01T00:00:00","used":true}' - ) + redis.getdel = AsyncMock(return_value=b'{"state":"s","gateway_id":"gw","code_verifier":"v","expires_at":"2099-01-01T00:00:00","used":true}') assert await manager._validate_authorization_state("gw", "s") is False @pytest.mark.asyncio async def test_validate_authorization_state_in_memory_expired(self, monkeypatch): + # First-Party import mcpgateway.services.oauth_manager as om manager = OAuthManager() @@ -846,9 +861,7 @@ async def test_validate_authorization_state_in_memory_expired(self, monkeypatch) async def test_validate_and_retrieve_state_redis(self, monkeypatch): manager = OAuthManager() redis = AsyncMock() - redis.getdel = AsyncMock( - return_value=b'{"state":"s","gateway_id":"gw","code_verifier":"v","expires_at":"2099-01-01T00:00:00","used":false}' - ) + redis.getdel = AsyncMock(return_value=b'{"state":"s","gateway_id":"gw","code_verifier":"v","expires_at":"2099-01-01T00:00:00","used":false}') monkeypatch.setattr("mcpgateway.services.oauth_manager.get_settings", lambda: SimpleNamespace(cache_type="redis", redis_url="redis://localhost")) monkeypatch.setattr("mcpgateway.services.oauth_manager._get_redis_client", AsyncMock(return_value=redis)) @@ -972,6 +985,7 @@ class TestGetRedisClientNonRedis: @pytest.mark.asyncio async def test_non_redis_cache_type(self, monkeypatch): + # First-Party import mcpgateway.services.oauth_manager as om monkeypatch.setattr(om, "_REDIS_INITIALIZED", False) @@ -985,6 +999,7 @@ async def test_non_redis_cache_type(self, monkeypatch): @pytest.mark.asyncio async def test_redis_returns_none_client(self, monkeypatch): """Redis factory returns None (line 68->76 partial: client is None).""" + # First-Party import mcpgateway.services.oauth_manager as om monkeypatch.setattr(om, "_REDIS_INITIALIZED", False) @@ -1133,6 +1148,7 @@ class TestInitiateFlowNoStorage: @pytest.mark.asyncio async def test_no_token_storage_skips_store(self): + # Third-Party from pydantic import SecretStr with patch("mcpgateway.services.oauth_manager.get_settings") as mock_get_settings: @@ -1155,14 +1171,15 @@ class TestCompleteFlowHMACBranches: @pytest.mark.asyncio async def test_invalid_hmac_signature_falls_back(self): """Invalid HMAC triggers fallback to legacy format (line 570 via except).""" + # Standard import base64 + # Third-Party from pydantic import SecretStr with patch("mcpgateway.services.oauth_manager.get_settings") as mock_settings: settings = MagicMock() settings.auth_encryption_secret = SecretStr("test-secret") - settings.oauth_default_timeout = 3600 mock_settings.return_value = settings manager = OAuthManager(token_storage=None) @@ -1172,28 +1189,34 @@ async def test_invalid_hmac_signature_falls_back(self): bad_signature = b"\x00" * 32 state = base64.urlsafe_b64encode(state_bytes + bad_signature).decode() - with patch.object(manager, "_validate_and_retrieve_state", return_value={"code_verifier": "v"}), patch.object(manager, "_exchange_code_for_tokens", return_value={"access_token": "tok"}), patch.object(manager, "_extract_user_id", return_value="user1"): + with ( + patch.object(manager, "_validate_and_retrieve_state", return_value={"code_verifier": "v"}), + patch.object(manager, "_exchange_code_for_tokens", return_value={"access_token": "tok"}), + patch.object(manager, "_extract_user_id", return_value="user1"), + ): result = await manager.complete_authorization_code_flow("gw1", "code", state, {"client_id": "cid"}) assert result["success"] is True @pytest.mark.asyncio async def test_gateway_mismatch_is_rejected(self): """Legacy state with mismatched gateway_id is rejected.""" + # Standard import base64 import hashlib import hmac as hmac_mod + # Third-Party from pydantic import SecretStr with patch("mcpgateway.services.oauth_manager.get_settings") as mock_settings: settings = MagicMock() settings.auth_encryption_secret = SecretStr("test-secret") - settings.oauth_default_timeout = 3600 mock_settings.return_value = settings manager = OAuthManager(token_storage=None) # Create valid HMAC state but with different gateway_id + # Third-Party import orjson state_data = {"gateway_id": "other_gw", "app_user_email": "user@test.com", "nonce": "abc", "timestamp": "2025-01-01T00:00:00"} @@ -1202,15 +1225,21 @@ async def test_gateway_mismatch_is_rejected(self): sig = hmac_mod.new(secret_key, state_bytes, hashlib.sha256).digest() state = base64.urlsafe_b64encode(state_bytes + sig).decode() - with patch.object(manager, "_validate_and_retrieve_state", return_value={"code_verifier": "v"}), patch.object(manager, "_exchange_code_for_tokens", return_value={"access_token": "tok"}), patch.object(manager, "_extract_user_id", return_value="user1"): + with ( + patch.object(manager, "_validate_and_retrieve_state", return_value={"code_verifier": "v"}), + patch.object(manager, "_exchange_code_for_tokens", return_value={"access_token": "tok"}), + patch.object(manager, "_extract_user_id", return_value="user1"), + ): with pytest.raises(OAuthError, match="gateway mismatch"): await manager.complete_authorization_code_flow("gw1", "code", state, {"client_id": "cid"}) @pytest.mark.asyncio async def test_no_email_with_storage_raises(self): """Token storage present but no email raises OAuthError (line 595).""" + # Standard import base64 + # Third-Party from pydantic import SecretStr with patch("mcpgateway.services.oauth_manager.get_settings") as mock_settings: @@ -1224,7 +1253,11 @@ async def test_no_email_with_storage_raises(self): # Create invalid state that triggers fallback → app_user_email = None state = base64.urlsafe_b64encode(b"invalid_data_for_parsing").decode() - with patch.object(manager, "_validate_and_retrieve_state", return_value={"code_verifier": "v"}), patch.object(manager, "_exchange_code_for_tokens", return_value={"access_token": "tok"}), patch.object(manager, "_extract_user_id", return_value="user1"): + with ( + patch.object(manager, "_validate_and_retrieve_state", return_value={"code_verifier": "v"}), + patch.object(manager, "_exchange_code_for_tokens", return_value={"access_token": "tok"}), + patch.object(manager, "_extract_user_id", return_value="user1"), + ): with pytest.raises(OAuthError, match="User context required"): await manager.complete_authorization_code_flow("gw1", "code", state, {"client_id": "cid"}) @@ -1235,11 +1268,15 @@ async def test_complete_flow_uses_server_side_user_context_without_state_decodin manager = OAuthManager(token_storage=MagicMock()) manager.token_storage.store_tokens = AsyncMock(return_value=SimpleNamespace(expires_at=None)) - with patch.object( - manager, - "_validate_and_retrieve_state", - return_value={"code_verifier": "v", "app_user_email": "user@test.com"}, - ), patch.object(manager, "_exchange_code_for_tokens", return_value={"access_token": "tok", "expires_in": 3600}), patch.object(manager, "_extract_user_id", return_value="user-1"): + with ( + patch.object( + manager, + "_validate_and_retrieve_state", + return_value={"code_verifier": "v", "app_user_email": "user@test.com"}, + ), + patch.object(manager, "_exchange_code_for_tokens", return_value={"access_token": "tok", "expires_in": 3600}), + patch.object(manager, "_extract_user_id", return_value="user-1"), + ): result = await manager.complete_authorization_code_flow( "gw1", "code", @@ -1328,9 +1365,12 @@ async def test_resolve_gateway_id_from_state_database_success(monkeypatch): @pytest.mark.asyncio async def test_resolve_gateway_id_from_state_cleans_expired_in_memory_entries(monkeypatch): - import mcpgateway.services.oauth_manager as om + # Standard from datetime import datetime, timedelta, timezone + # First-Party + import mcpgateway.services.oauth_manager as om + manager = OAuthManager() monkeypatch.setattr("mcpgateway.services.oauth_manager.get_settings", lambda: SimpleNamespace(cache_type="memory")) @@ -1389,6 +1429,7 @@ async def test_database_storage_success(self, monkeypatch): @pytest.mark.asyncio async def test_database_storage_failure_falls_back(self, monkeypatch): + # First-Party import mcpgateway.services.oauth_manager as om manager = OAuthManager() @@ -1412,6 +1453,7 @@ class TestValidateAuthorizationStateRedisEdgeCases: @pytest.mark.asyncio async def test_redis_datetime_fallback_parse(self, monkeypatch): """Datetime fromisoformat fails, fallback to strptime (lines 743-745).""" + # Third-Party import orjson manager = OAuthManager() @@ -1429,6 +1471,7 @@ async def test_redis_datetime_fallback_parse(self, monkeypatch): @pytest.mark.asyncio async def test_redis_expired_state(self, monkeypatch): """Expired state in Redis (lines 753-754).""" + # Third-Party import orjson manager = OAuthManager() @@ -1444,6 +1487,7 @@ async def test_redis_expired_state(self, monkeypatch): @pytest.mark.asyncio async def test_redis_naive_datetime(self, monkeypatch): """Naive datetime in Redis requires UTC assumption (branch 747->752).""" + # Third-Party import orjson manager = OAuthManager() @@ -1460,6 +1504,7 @@ async def test_redis_naive_datetime(self, monkeypatch): @pytest.mark.asyncio async def test_redis_exception_falls_back_to_memory(self, monkeypatch): """Redis exception falls back (lines 763-764).""" + # First-Party import mcpgateway.services.oauth_manager as om manager = OAuthManager() @@ -1479,6 +1524,7 @@ class TestValidateAuthorizationStateDatabasePath: @pytest.mark.asyncio async def test_database_valid_state(self, monkeypatch): + # Standard from datetime import datetime, timezone manager = OAuthManager() @@ -1515,6 +1561,7 @@ async def test_database_not_found(self, monkeypatch): @pytest.mark.asyncio async def test_database_expired_state(self, monkeypatch): + # Standard from datetime import datetime, timezone manager = OAuthManager() @@ -1536,6 +1583,7 @@ async def test_database_expired_state(self, monkeypatch): @pytest.mark.asyncio async def test_database_used_state(self, monkeypatch): + # Standard from datetime import datetime, timezone manager = OAuthManager() @@ -1556,6 +1604,7 @@ async def test_database_used_state(self, monkeypatch): @pytest.mark.asyncio async def test_database_naive_datetime(self, monkeypatch): + # Standard from datetime import datetime manager = OAuthManager() @@ -1576,6 +1625,7 @@ async def test_database_naive_datetime(self, monkeypatch): @pytest.mark.asyncio async def test_database_exception_falls_back(self, monkeypatch): + # First-Party import mcpgateway.services.oauth_manager as om manager = OAuthManager() @@ -1591,6 +1641,7 @@ class TestValidateAuthorizationStateMemoryUsed: @pytest.mark.asyncio async def test_memory_used_state(self, monkeypatch): + # First-Party import mcpgateway.services.oauth_manager as om manager = OAuthManager() @@ -1613,6 +1664,7 @@ class TestValidateAndRetrieveStateRedisEdgeCases: @pytest.mark.asyncio async def test_redis_datetime_fallback(self, monkeypatch): """Fallback datetime parse (lines 866-867).""" + # Third-Party import orjson manager = OAuthManager() @@ -1632,6 +1684,7 @@ async def test_redis_datetime_fallback(self, monkeypatch): @pytest.mark.asyncio async def test_redis_expired(self, monkeypatch): """Expired state returns None (line 873).""" + # Third-Party import orjson manager = OAuthManager() @@ -1647,6 +1700,7 @@ async def test_redis_expired(self, monkeypatch): @pytest.mark.asyncio async def test_redis_naive_datetime(self, monkeypatch): """Naive datetime requires UTC assumption (branch 869->872).""" + # Third-Party import orjson manager = OAuthManager() @@ -1664,6 +1718,7 @@ async def test_redis_naive_datetime(self, monkeypatch): @pytest.mark.asyncio async def test_redis_exception(self, monkeypatch): """Redis exception falls back (lines 876-877).""" + # First-Party import mcpgateway.services.oauth_manager as om manager = OAuthManager() @@ -1682,6 +1737,7 @@ class TestValidateAndRetrieveStateDatabasePath: @pytest.mark.asyncio async def test_database_valid(self, monkeypatch): + # Standard from datetime import datetime, timezone manager = OAuthManager() @@ -1722,6 +1778,7 @@ async def test_database_not_found(self, monkeypatch): @pytest.mark.asyncio async def test_database_expired(self, monkeypatch): + # Standard from datetime import datetime, timezone manager = OAuthManager() @@ -1743,6 +1800,7 @@ async def test_database_expired(self, monkeypatch): @pytest.mark.asyncio async def test_database_used(self, monkeypatch): + # Standard from datetime import datetime, timezone manager = OAuthManager() @@ -1763,6 +1821,7 @@ async def test_database_used(self, monkeypatch): @pytest.mark.asyncio async def test_database_naive_datetime(self, monkeypatch): + # Standard from datetime import datetime manager = OAuthManager() @@ -1787,6 +1846,7 @@ async def test_database_naive_datetime(self, monkeypatch): @pytest.mark.asyncio async def test_database_exception(self, monkeypatch): + # First-Party import mcpgateway.services.oauth_manager as om manager = OAuthManager() @@ -1802,6 +1862,7 @@ class TestValidateAndRetrieveStateMemoryNaiveDatetime: @pytest.mark.asyncio async def test_memory_naive_datetime(self, monkeypatch): + # First-Party import mcpgateway.services.oauth_manager as om manager = OAuthManager() @@ -1825,14 +1886,26 @@ class TestCreateAuthUrlWithPKCEResource: def test_single_string_resource(self): manager = OAuthManager() - credentials = {"client_id": "cid", "authorization_url": "https://auth.example.com/authorize", "redirect_uri": "https://app.example.com/callback", "scopes": [], "resource": "https://api.example.com"} + credentials = { + "client_id": "cid", + "authorization_url": "https://auth.example.com/authorize", + "redirect_uri": "https://app.example.com/callback", + "scopes": [], + "resource": "https://api.example.com", + } url = manager._create_authorization_url_with_pkce(credentials, "state", "challenge", "S256") assert "resource=https" in url def test_list_resource(self): manager = OAuthManager() - credentials = {"client_id": "cid", "authorization_url": "https://auth.example.com/authorize", "redirect_uri": "https://app.example.com/callback", "scopes": [], "resource": ["https://api1.example.com", "https://api2.example.com"]} + credentials = { + "client_id": "cid", + "authorization_url": "https://auth.example.com/authorize", + "redirect_uri": "https://app.example.com/callback", + "scopes": [], + "resource": ["https://api1.example.com", "https://api2.example.com"], + } url = manager._create_authorization_url_with_pkce(credentials, "state", "challenge", "S256") assert "resource=" in url @@ -1878,7 +1951,13 @@ async def test_no_client_secret(self, monkeypatch): async def test_single_string_resource(self, monkeypatch): """Single string resource param (line 1065).""" manager = OAuthManager(max_retries=1) - credentials = {"client_id": "cid", "client_secret": "secret", "token_url": "https://auth.example.com/token", "redirect_uri": "https://app.example.com/callback", "resource": "https://api.example.com"} + credentials = { + "client_id": "cid", + "client_secret": "secret", + "token_url": "https://auth.example.com/token", + "redirect_uri": "https://app.example.com/callback", + "resource": "https://api.example.com", + } response = _make_response(json_data={"access_token": "tok"}, headers={"content-type": "application/json"}) client = AsyncMock() @@ -1894,7 +1973,12 @@ async def test_single_string_resource(self, monkeypatch): async def test_resource_list_with_empty(self, monkeypatch): """Resource list with falsy entry (branch 1061->1060).""" manager = OAuthManager(max_retries=1) - credentials = {"client_id": "cid", "token_url": "https://auth.example.com/token", "redirect_uri": "https://app.example.com/callback", "resource": ["https://api1.example.com", "", "https://api2.example.com"]} + credentials = { + "client_id": "cid", + "token_url": "https://auth.example.com/token", + "redirect_uri": "https://app.example.com/callback", + "resource": ["https://api1.example.com", "", "https://api2.example.com"], + } response = _make_response(json_data={"access_token": "tok"}, headers={"content-type": "application/json"}) client = AsyncMock() @@ -2164,6 +2248,7 @@ class TestRedisNoneFallthroughPaths: @pytest.mark.asyncio async def test_store_state_redis_none_falls_to_memory(self, monkeypatch): """_store_authorization_state: redis is None, falls through (branch 664->676).""" + # First-Party import mcpgateway.services.oauth_manager as om manager = OAuthManager() @@ -2177,6 +2262,7 @@ async def test_store_state_redis_none_falls_to_memory(self, monkeypatch): @pytest.mark.asyncio async def test_validate_state_redis_none_falls_to_memory(self, monkeypatch): """_validate_authorization_state: redis is None, falls through (branch 728->767).""" + # First-Party import mcpgateway.services.oauth_manager as om manager = OAuthManager() @@ -2189,6 +2275,7 @@ async def test_validate_state_redis_none_falls_to_memory(self, monkeypatch): @pytest.mark.asyncio async def test_validate_and_retrieve_redis_none_falls_to_memory(self, monkeypatch): """_validate_and_retrieve_state: redis is None, falls through (branch 854->880).""" + # First-Party import mcpgateway.services.oauth_manager as om manager = OAuthManager() @@ -2242,7 +2329,10 @@ async def test_missing_email_in_state_hard_errors(self): with self._patches(manager)[0], self._patches(manager)[1], self._patches(manager)[2]: with pytest.raises(OAuthError, match="User context required"): await manager.complete_authorization_code_flow( - "gw1", "code", "state-tok", {"client_id": "cid"}, + "gw1", + "code", + "state-tok", + {"client_id": "cid"}, ) manager.token_storage.store_tokens.assert_not_awaited() @@ -2251,11 +2341,16 @@ async def test_no_token_storage_succeeds_without_email(self): """When token_storage is None, missing email is acceptable (no binding risk).""" manager = OAuthManager(token_storage=None) - with patch.object(manager, "_validate_and_retrieve_state", return_value=self._state_without_email()), \ - patch.object(manager, "_exchange_code_for_tokens", return_value=self._token_response()), \ - patch.object(manager, "_extract_user_id", return_value="user-1"): + with ( + patch.object(manager, "_validate_and_retrieve_state", return_value=self._state_without_email()), + patch.object(manager, "_exchange_code_for_tokens", return_value=self._token_response()), + patch.object(manager, "_extract_user_id", return_value="user-1"), + ): result = await manager.complete_authorization_code_flow( - "gw1", "code", "state-tok", {"client_id": "cid"}, + "gw1", + "code", + "state-tok", + {"client_id": "cid"}, ) assert result["success"] is True assert result["expires_at"] is None @@ -2272,8 +2367,10 @@ class TestLegacyStateGatewayMismatch: @pytest.mark.asyncio async def test_base64_legacy_state_gateway_mismatch_raises(self): """Legacy base64 state with different gateway_id must raise OAuthError.""" + # Standard import base64 + # Third-Party import orjson manager = OAuthManager(token_storage=AsyncMock()) @@ -2290,7 +2387,10 @@ async def test_base64_legacy_state_gateway_mismatch_raises(self): ): with pytest.raises(OAuthError, match="State parameter gateway mismatch"): await manager.complete_authorization_code_flow( - "gw1", "code", state, {"client_id": "cid"}, + "gw1", + "code", + state, + {"client_id": "cid"}, ) @pytest.mark.asyncio @@ -2308,7 +2408,10 @@ async def test_underscore_legacy_state_gateway_mismatch_raises(self): ): with pytest.raises(OAuthError, match="State parameter gateway mismatch"): await manager.complete_authorization_code_flow( - "gw1", "code", state, {"client_id": "cid"}, + "gw1", + "code", + state, + {"client_id": "cid"}, ) @pytest.mark.asyncio @@ -2317,8 +2420,10 @@ async def test_matching_gateway_id_does_not_raise_mismatch(self): It should fall through to the missing-email guard instead. """ + # Standard import base64 + # Third-Party import orjson manager = OAuthManager(token_storage=AsyncMock()) @@ -2336,7 +2441,10 @@ async def test_matching_gateway_id_does_not_raise_mismatch(self): # Should NOT raise "gateway mismatch"; should raise "User context required" instead with pytest.raises(OAuthError, match="User context required"): await manager.complete_authorization_code_flow( - "gw1", "code", state, {"client_id": "cid"}, + "gw1", + "code", + state, + {"client_id": "cid"}, ) @@ -2349,8 +2457,10 @@ class TestLegacyStatePayloadStripsIdentity: def test_base64_payload_with_forged_email_is_stripped(self): """Crafted base64 legacy state with app_user_email must not return it.""" + # Standard import base64 + # Third-Party import orjson manager = OAuthManager() @@ -2367,8 +2477,10 @@ def test_base64_payload_with_forged_email_is_stripped(self): def test_base64_payload_returns_none_for_identity_only(self): """If payload has only identity fields and no gateway_id, return None.""" + # Standard import base64 + # Third-Party import orjson manager = OAuthManager() @@ -2391,8 +2503,10 @@ def test_gateway_suffix_format_has_no_email(self): @pytest.mark.asyncio async def test_complete_flow_ignores_forged_email_in_legacy_state(self): """complete_authorization_code_flow must not use email from unsigned legacy state.""" + # Standard import base64 + # Third-Party import orjson manager = OAuthManager(token_storage=MagicMock()) @@ -2405,20 +2519,27 @@ async def test_complete_flow_ignores_forged_email_in_legacy_state(self): fake_sig = b"\x00" * 32 state = base64.urlsafe_b64encode(payload_bytes + fake_sig).decode() - with patch.object( - manager, - "_validate_and_retrieve_state", - return_value={"code_verifier": "v"}, - ), patch.object( - manager, - "_exchange_code_for_tokens", - return_value={"access_token": "tok"}, - ), patch.object(manager, "_extract_user_id", return_value="user1"): + with ( + patch.object( + manager, + "_validate_and_retrieve_state", + return_value={"code_verifier": "v"}, + ), + patch.object( + manager, + "_exchange_code_for_tokens", + return_value={"access_token": "tok"}, + ), + patch.object(manager, "_extract_user_id", return_value="user1"), + ): # With token_storage set, missing app_user_email should raise, # proving the forged email was NOT accepted. with pytest.raises(OAuthError, match="User context required"): await manager.complete_authorization_code_flow( - "gw1", "code", state, {"client_id": "cid"}, + "gw1", + "code", + state, + {"client_id": "cid"}, ) diff --git a/tests/unit/mcpgateway/services/test_token_storage_service.py b/tests/unit/mcpgateway/services/test_token_storage_service.py index b7ef89609a..f0690ed8f9 100644 --- a/tests/unit/mcpgateway/services/test_token_storage_service.py +++ b/tests/unit/mcpgateway/services/test_token_storage_service.py @@ -156,6 +156,45 @@ async def test_store_tokens_exception(service, mock_db): mock_db.rollback.assert_called_once() +@pytest.mark.asyncio +async def test_store_tokens_no_expires_in_persists_null(service, mock_db, caplog): + mock_db.execute.return_value.scalar_one_or_none.return_value = None + # Standard + import logging + + with caplog.at_level(logging.INFO): + await service.store_tokens( + gateway_id="gw-1", + user_id="user-1", + app_user_email="user@test.com", + access_token="access123", + refresh_token="refresh123", + expires_in=None, + scopes=["read"], + ) + mock_db.add.assert_called_once() + new_record = mock_db.add.call_args.args[0] + assert new_record.expires_at is None + assert any("token will not auto-expire" in msg for msg in caplog.messages) + + +@pytest.mark.asyncio +async def test_store_tokens_update_existing_clears_expires_at_when_no_expires_in(service, mock_db): + existing = _make_token_record() + assert existing.expires_at is not None + mock_db.execute.return_value.scalar_one_or_none.return_value = existing + await service.store_tokens( + gateway_id="gw-1", + user_id="user-1", + app_user_email="user@test.com", + access_token="new_access", + refresh_token="new_refresh", + expires_in=None, + scopes=["read"], + ) + assert existing.expires_at is None + + # ---------- get_user_token ---------- @@ -299,6 +338,49 @@ async def test_refresh_success(service, mock_db): mock_db.commit.assert_called() +@pytest.mark.asyncio +async def test_refresh_without_expires_in_preserves_prior_ttl(service, mock_db): + """Refresh response missing expires_in: preserve the prior TTL so proactive refresh keeps working.""" + gw = MagicMock(oauth_config={"token_url": "https://token", "client_id": "cid"}, url="https://gw.com") + mock_db.query.return_value.filter.return_value.first.return_value = gw + mock_oauth_manager = MagicMock() + mock_oauth_manager.refresh_token = AsyncMock(return_value={"access_token": "new_access", "refresh_token": "new_refresh"}) + + # Token issued 100 seconds ago with a 1-hour TTL. + record = _make_token_record() + issued = datetime.now(tz=timezone.utc) - timedelta(seconds=100) + record.updated_at = issued + record.expires_at = issued + timedelta(seconds=3600) + + with patch("mcpgateway.services.oauth_manager.OAuthManager", return_value=mock_oauth_manager): + result = await service._refresh_access_token(record) + + assert result == "new_access" + # Prior TTL preserved: new expires_at should be ~3600s after the refresh moment (now). + assert record.expires_at is not None + delta = (record.expires_at - record.updated_at).total_seconds() + assert 3599 <= delta <= 3601 + + +@pytest.mark.asyncio +async def test_refresh_without_expires_in_no_prior_ttl_stays_none(service, mock_db): + """Refresh response missing expires_in AND no prior TTL: expires_at stays None.""" + gw = MagicMock(oauth_config={"token_url": "https://token", "client_id": "cid"}, url="https://gw.com") + mock_db.query.return_value.filter.return_value.first.return_value = gw + mock_oauth_manager = MagicMock() + mock_oauth_manager.refresh_token = AsyncMock(return_value={"access_token": "new_access", "refresh_token": "new_refresh"}) + + # Token had no prior expiry (e.g. GitHub OAuth Apps). + record = _make_token_record() + record.expires_at = None + + with patch("mcpgateway.services.oauth_manager.OAuthManager", return_value=mock_oauth_manager): + result = await service._refresh_access_token(record) + + assert result == "new_access" + assert record.expires_at is None + + @pytest.mark.asyncio async def test_refresh_success_with_resource_list(service, mock_db): gw = MagicMock(oauth_config={"token_url": "https://token", "client_id": "cid", "resource": ["https://api.example.com", "https://other.com"]}, url="https://gw.com") @@ -507,6 +589,20 @@ async def test_cleanup_expired_tokens_exception(service, mock_db): mock_db.rollback.assert_called_once() +@pytest.mark.asyncio +async def test_cleanup_expired_tokens_targets_null_expires_at(service, mock_db): + mock_db.execute.return_value.rowcount = 3 + await service.cleanup_expired_tokens(max_age_days=30) + + delete_stmt = mock_db.execute.call_args.args[0] + rendered = str(delete_stmt.compile(compile_kwargs={"literal_binds": True})).lower() + assert "expires_at is null" in rendered + # NULL-expires_at rows must be aged out by updated_at (re-auth advances it), + # not created_at (which would delete recently re-authorized tokens). + assert "updated_at" in rendered + assert "created_at" not in rendered + + # ---------- token_type validation in get_user_token ---------- diff --git a/tests/unit/mcpgateway/test_oauth_manager.py b/tests/unit/mcpgateway/test_oauth_manager.py index 2353690973..aed4eba17f 100644 --- a/tests/unit/mcpgateway/test_oauth_manager.py +++ b/tests/unit/mcpgateway/test_oauth_manager.py @@ -533,6 +533,50 @@ async def test_complete_authorization_code_flow_success(self): mock_extract_user.assert_called_once_with(token_response, credentials) mock_store_tokens.assert_called_once() + @pytest.mark.asyncio + async def test_complete_authorization_code_flow_omitted_expires_in_passes_none(self): + """Provider responses without ``expires_in`` flow through as ``expires_in=None`` to store_tokens.""" + # Standard + import base64 + import hashlib + import hmac + import json + from unittest.mock import patch + + with patch("mcpgateway.services.oauth_manager.get_settings") as mock_get_settings: + mock_settings = Mock() + mock_settings.auth_encryption_secret = SecretStr("test-secret-key") + mock_get_settings.return_value = mock_settings + + mock_token_storage = Mock() + manager = OAuthManager(token_storage=mock_token_storage) + + gateway_id = "gateway123" + code = "auth_code_123" + # Standard + from datetime import datetime, timezone + + state_data = {"gateway_id": gateway_id, "app_user_email": "test@example.com", "nonce": "state-no-exp", "timestamp": datetime.now(timezone.utc).isoformat()} + state_bytes = json.dumps(state_data, separators=(",", ":")).encode() + secret_key = b"test-secret-key" + signature = hmac.new(secret_key, state_bytes, hashlib.sha256).digest() + state = base64.urlsafe_b64encode(state_bytes + signature).decode() + + credentials = {"client_id": "test_client"} + # Provider response intentionally omits expires_in (e.g. GitHub OAuth Apps). + token_response = {"access_token": "access123", "refresh_token": "refresh123"} + + await manager._store_authorization_state(gateway_id, state, code_verifier="cv-no-exp", app_user_email="test@example.com") + + with patch.object(manager, "_exchange_code_for_tokens", return_value=token_response): + with patch.object(manager, "_extract_user_id", return_value="user123"): + with patch.object(mock_token_storage, "store_tokens", new_callable=AsyncMock) as mock_store_tokens: + mock_store_tokens.return_value = Mock(expires_at=None) + await manager.complete_authorization_code_flow(gateway_id, code, state, credentials) + + mock_store_tokens.assert_called_once() + assert mock_store_tokens.call_args.kwargs["expires_in"] is None + @pytest.mark.asyncio async def test_complete_authorization_code_flow_invalid_state(self): """Test authorization code flow completion with invalid state."""