diff --git a/mcpgateway/transports/streamablehttp_transport.py b/mcpgateway/transports/streamablehttp_transport.py index 3af6b07020..1c11715ce6 100644 --- a/mcpgateway/transports/streamablehttp_transport.py +++ b/mcpgateway/transports/streamablehttp_transport.py @@ -3642,12 +3642,51 @@ async def _auth_jwt(self, *, token: str) -> bool: else: _record_mcp_auth_cache_event("team_membership_cache_hit") + # Detect API token auth and update last_used timestamp. + # API tokens carry user.auth_provider == "api_token" in their JWT payload. + # Legacy tokens (created before auth_provider was added) are detected via DB lookup. + _nested_user = user_payload.get("user", {}) + _auth_provider = _nested_user.get("auth_provider") if isinstance(_nested_user, dict) else None + _is_api_token = _auth_provider == "api_token" + + if not _is_api_token and jti and _auth_provider is None: + # Legacy API token fallback: only when auth_provider is absent (pre-auth_provider tokens). + # Tokens with an explicit auth_provider (email, oauth, saml, etc.) are never legacy API tokens. + try: + # First-Party + from mcpgateway.auth import _is_api_token_jti_sync # pylint: disable=import-outside-toplevel + + _is_api_token = await asyncio.to_thread(_is_api_token_jti_sync, jti) + except Exception: + pass # Best-effort detection; default to "jwt" auth_method + + resolved_auth_method = "api_token" if _is_api_token else "jwt" + + # Update last_used timestamp for API tokens (rate-limited internally) + if _is_api_token and jti: + try: + # First-Party + from mcpgateway.auth import _update_api_token_last_used_sync # pylint: disable=import-outside-toplevel + + await asyncio.to_thread(_update_api_token_last_used_sync, jti) + except Exception: + logger.debug("Failed to update API token last_used in MCP auth for jti=...%s", jti[-8:] if jti else "") + + # Propagate auth_method and jti into ASGI scope state so that + # TokenUsageMiddleware can recognise API token requests and log usage. + state = self.scope.setdefault("state", {}) + state["auth_method"] = resolved_auth_method + if _is_api_token and jti: + state["jti"] = jti + if user_email: + state["user_email"] = user_email + auth_user_ctx: dict[str, Any] = { "email": user_email, "teams": final_teams, "is_authenticated": True, "is_admin": is_admin, - "auth_method": "jwt", + "auth_method": resolved_auth_method, "permission_is_admin": db_user_is_admin or is_admin, "token_use": token_use, # propagated for downstream RBAC (check_any_team) } @@ -3667,7 +3706,7 @@ async def _auth_jwt(self, *, token: str) -> bool: final_teams, user_email=user_email, is_admin=bool(db_user_is_admin or is_admin), - auth_method="jwt", + auth_method=resolved_auth_method, team_name=trace_team_name, ) except HTTPException: diff --git a/tests/playwright/security/test_mcp_transport_auth_matrix.py b/tests/playwright/security/test_mcp_transport_auth_matrix.py index bad7b1d3ce..9713f8eebd 100644 --- a/tests/playwright/security/test_mcp_transport_auth_matrix.py +++ b/tests/playwright/security/test_mcp_transport_auth_matrix.py @@ -9,6 +9,7 @@ # Standard from contextlib import suppress +import time from urllib.parse import urlparse import uuid @@ -170,3 +171,104 @@ def test_websocket_auth_handshake_behavior(self): assert isinstance(response, str) assert "Parse error" in response or "jsonrpc" in response + + +@pytest.fixture +def api_token_info(admin_api: APIRequestContext) -> dict: + """Create an API token and return its access_token and token_id, then clean up.""" + resp = admin_api.post("/tokens", data={"name": f"last-used-test-{uuid.uuid4().hex[:8]}", "expires_in_days": 1}) + if resp.status == 404: + pytest.skip("/tokens endpoint unavailable") + assert resp.status in (200, 201), f"Failed to create token: {resp.status} {resp.text()}" + payload = resp.json() + token_obj = payload.get("token", payload) + info = { + "access_token": payload["access_token"], + "token_id": token_obj.get("id") or token_obj.get("token_id"), + } + yield info + with suppress(Exception): + admin_api.delete(f"/tokens/{info['token_id']}") + + +class TestApiTokenLastUsedViaMCP: + """Verify API token last_used is updated when accessing virtual servers via MCP Streamable HTTP.""" + + def test_mcp_streamable_http_updates_last_used(self, admin_api: APIRequestContext, playwright: Playwright, public_server_id: str, api_token_info: dict): + """Accessing /servers/{id}/mcp with an API token should update last_used.""" + token_id = api_token_info["token_id"] + + # 1. Check initial last_used (should be None for new token) + detail = admin_api.get(f"/tokens/{token_id}") + if detail.status == 404: + pytest.skip("Token detail endpoint unavailable") + initial_last_used = detail.json().get("last_used") + + # 2. Make MCP Streamable HTTP request with the API token + token_ctx = _api_context(playwright, api_token_info["access_token"]) + try: + mcp_resp = token_ctx.post( + f"/servers/{public_server_id}/mcp", + data={"jsonrpc": "2.0", "id": "1", "method": "initialize", "params": {"protocolVersion": "2025-03-26", "capabilities": {}, "clientInfo": {"name": "e2e-test", "version": "1.0.0"}}}, + headers={"Content-Type": "application/json", "Accept": "application/json, text/event-stream"}, + ) + finally: + token_ctx.dispose() + + if mcp_resp.status == 404: + pytest.skip("Streamable HTTP endpoint unavailable") + assert mcp_resp.status != 401, f"API token auth rejected: {mcp_resp.text()}" + + # 3. Verify last_used was updated + time.sleep(1) # Allow propagation across multi-gateway setup + detail2 = admin_api.get(f"/tokens/{token_id}") + updated_last_used = detail2.json().get("last_used") + + assert updated_last_used is not None, f"last_used not updated after MCP access. Initial: {initial_last_used}, After: {updated_last_used}" + + def test_mcp_requests_accumulate_in_token_usage_stats(self, admin_api: APIRequestContext, playwright: Playwright, public_server_id: str, api_token_info: dict): + """Multiple MCP requests with an API token should be accurately reflected in usage statistics. + + Uses a fresh per-test token (api_token_info fixture) so stats are isolated — no other + requests can contribute to this token's usage counters. + """ + token_id = api_token_info["token_id"] + num_requests = 5 + + # 1. Verify fresh token has zero usage + usage_resp = admin_api.get(f"/tokens/{token_id}/usage") + if usage_resp.status == 404: + pytest.skip("Token usage endpoint unavailable") + baseline = usage_resp.json() + assert baseline.get("total_requests", 0) == 0, "Fresh token should have zero usage" + + # 2. Make exactly N MCP requests + token_ctx = _api_context(playwright, api_token_info["access_token"]) + try: + for i in range(num_requests): + token_ctx.post( + f"/servers/{public_server_id}/mcp", + data={"jsonrpc": "2.0", "id": str(i + 1), "method": "ping", "params": {}}, + headers={"Content-Type": "application/json", "Accept": "application/json, text/event-stream"}, + ) + finally: + token_ctx.dispose() + + # 3. Allow async logging to complete + time.sleep(1) + + # 4. Verify usage stats with strict equality (isolated per-test token) + usage_resp2 = admin_api.get(f"/tokens/{token_id}/usage") + assert usage_resp2.status == 200 + stats = usage_resp2.json() + + assert stats["total_requests"] == num_requests, f"Expected exactly {num_requests} total requests, got {stats['total_requests']}" + assert stats["successful_requests"] == num_requests, f"Expected exactly {num_requests} successful requests, got {stats['successful_requests']}" + assert stats["blocked_requests"] == 0, f"Expected 0 blocked requests, got {stats['blocked_requests']}" + assert stats["success_rate"] == 1.0, f"Expected 100% success rate, got {stats['success_rate']}" + assert stats["average_response_time_ms"] > 0, f"Expected positive average response time, got {stats['average_response_time_ms']}ms" + + # Top endpoints should include the MCP server path + endpoint_paths = [ep[0] if isinstance(ep, (list, tuple)) else ep for ep in stats.get("top_endpoints", [])] + has_mcp_endpoint = any("/mcp" in path for path in endpoint_paths) + assert has_mcp_endpoint, f"Expected /mcp in top_endpoints, got {endpoint_paths}" diff --git a/tests/unit/mcpgateway/transports/test_streamablehttp_transport.py b/tests/unit/mcpgateway/transports/test_streamablehttp_transport.py index 1ac828e234..b5b5b72360 100644 --- a/tests/unit/mcpgateway/transports/test_streamablehttp_transport.py +++ b/tests/unit/mcpgateway/transports/test_streamablehttp_transport.py @@ -13365,6 +13365,277 @@ async def test_auth_jwt_sets_trace_context_for_session_token(monkeypatch): clear_trace_context() +# --------------------------------------------------------------------------- +# _auth_jwt — API token last_used tracking +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_auth_jwt_updates_last_used_for_api_token(monkeypatch): + """_auth_jwt should call _update_api_token_last_used_sync for API tokens.""" + jwt_payload = { + "sub": "user@example.com", + "jti": "jti-api-token-123", + "is_admin": False, + "token_use": "api", + "user": {"auth_provider": "api_token"}, + "scopes": {"permissions": ["tools.read"]}, + } + + monkeypatch.setattr(tr.settings, "auth_cache_enabled", False) + monkeypatch.setattr(tr.settings, "auth_cache_batch_queries", False) + monkeypatch.setattr(tr.settings, "require_user_in_db", False) + + mock_update = MagicMock() + user_record = SimpleNamespace(is_admin=False, is_active=True) + + scope = {"type": "http", "path": "/servers/srv-1/mcp", "headers": []} + handler = tr._StreamableHttpAuthHandler(scope, AsyncMock(), AsyncMock()) + + with ( + patch("mcpgateway.transports.streamablehttp_transport.verify_credentials", AsyncMock(return_value=jwt_payload)), + patch("mcpgateway.auth._check_token_revoked_sync", return_value=False), + patch("mcpgateway.auth._get_user_by_email_sync", return_value=user_record), + patch("mcpgateway.auth.normalize_token_teams", return_value=None), + patch("mcpgateway.auth._update_api_token_last_used_sync", mock_update), + ): + result = await handler._auth_jwt(token="fake-api-token") + + assert result is True + mock_update.assert_called_once_with("jti-api-token-123") + + +@pytest.mark.asyncio +async def test_auth_jwt_sets_api_token_auth_method_in_context_and_scope(monkeypatch): + """_auth_jwt should set auth_method='api_token' in user_context_var and ASGI scope state.""" + jwt_payload = { + "sub": "user@example.com", + "jti": "jti-api-token-456", + "is_admin": True, + "token_use": "api", + "user": {"auth_provider": "api_token"}, + } + + monkeypatch.setattr(tr.settings, "auth_cache_enabled", False) + monkeypatch.setattr(tr.settings, "auth_cache_batch_queries", False) + monkeypatch.setattr(tr.settings, "require_user_in_db", False) + + user_record = SimpleNamespace(is_admin=True, is_active=True) + scope = {"type": "http", "path": "/mcp", "headers": []} + handler = tr._StreamableHttpAuthHandler(scope, AsyncMock(), AsyncMock()) + + with ( + patch("mcpgateway.transports.streamablehttp_transport.verify_credentials", AsyncMock(return_value=jwt_payload)), + patch("mcpgateway.auth._check_token_revoked_sync", return_value=False), + patch("mcpgateway.auth._get_user_by_email_sync", return_value=user_record), + patch("mcpgateway.auth.normalize_token_teams", return_value=None), + patch("mcpgateway.auth._update_api_token_last_used_sync", MagicMock()), + ): + result = await handler._auth_jwt(token="fake-api-token") + + assert result is True + + # Verify user_context_var + ctx = user_context_var.get() + assert ctx["auth_method"] == "api_token" + assert ctx["email"] == "user@example.com" + + # Verify ASGI scope state propagation + assert scope["state"]["auth_method"] == "api_token" + assert scope["state"]["jti"] == "jti-api-token-456" + assert scope["state"]["user_email"] == "user@example.com" + + +@pytest.mark.asyncio +async def test_auth_jwt_does_not_update_last_used_for_session_token(monkeypatch): + """_auth_jwt should NOT call _update_api_token_last_used_sync for session tokens.""" + jwt_payload = { + "sub": "user@example.com", + "jti": "jti-session-789", + "is_admin": False, + "token_use": "session", + "user": {"auth_provider": "email"}, + } + + monkeypatch.setattr(tr.settings, "auth_cache_enabled", False) + monkeypatch.setattr(tr.settings, "auth_cache_batch_queries", False) + monkeypatch.setattr(tr.settings, "require_user_in_db", False) + + mock_update = MagicMock() + user_record = SimpleNamespace(is_admin=False, is_active=True) + scope = {"type": "http", "path": "/mcp", "headers": []} + handler = tr._StreamableHttpAuthHandler(scope, AsyncMock(), AsyncMock()) + + with ( + patch("mcpgateway.transports.streamablehttp_transport.verify_credentials", AsyncMock(return_value=jwt_payload)), + patch("mcpgateway.auth._check_token_revoked_sync", return_value=False), + patch("mcpgateway.auth._get_user_by_email_sync", return_value=user_record), + patch("mcpgateway.auth.resolve_session_teams", AsyncMock(return_value=["team-1"])), + ): + result = await handler._auth_jwt(token="fake-session-token") + + assert result is True + mock_update.assert_not_called() + + # Verify auth_method is "jwt", not "api_token" + ctx = user_context_var.get() + assert ctx["auth_method"] == "jwt" + assert scope["state"]["auth_method"] == "jwt" + assert "jti" not in scope.get("state", {}) + + +@pytest.mark.asyncio +async def test_auth_jwt_continues_when_last_used_update_fails(monkeypatch): + """_auth_jwt should continue authentication even if last_used update raises an exception.""" + jwt_payload = { + "sub": "user@example.com", + "jti": "jti-api-token-fail", + "is_admin": True, + "token_use": "api", + "user": {"auth_provider": "api_token"}, + } + + monkeypatch.setattr(tr.settings, "auth_cache_enabled", False) + monkeypatch.setattr(tr.settings, "auth_cache_batch_queries", False) + monkeypatch.setattr(tr.settings, "require_user_in_db", False) + + user_record = SimpleNamespace(is_admin=True, is_active=True) + scope = {"type": "http", "path": "/mcp", "headers": []} + handler = tr._StreamableHttpAuthHandler(scope, AsyncMock(), AsyncMock()) + + def raise_db_error(jti): + raise RuntimeError("Database connection failed") + + with ( + patch("mcpgateway.transports.streamablehttp_transport.verify_credentials", AsyncMock(return_value=jwt_payload)), + patch("mcpgateway.auth._check_token_revoked_sync", return_value=False), + patch("mcpgateway.auth._get_user_by_email_sync", return_value=user_record), + patch("mcpgateway.auth.normalize_token_teams", return_value=None), + patch("mcpgateway.auth._update_api_token_last_used_sync", raise_db_error), + ): + result = await handler._auth_jwt(token="fake-api-token") + + # Authentication should succeed despite last_used update failure + assert result is True + ctx = user_context_var.get() + assert ctx["auth_method"] == "api_token" + assert ctx["is_authenticated"] is True + + +@pytest.mark.asyncio +async def test_auth_jwt_detects_legacy_api_token_via_db_lookup(monkeypatch): + """_auth_jwt should detect legacy API tokens (no auth_provider) via DB JTI lookup.""" + jwt_payload = { + "sub": "legacy@example.com", + "jti": "jti-legacy-token", + "is_admin": True, + "token_use": "api", + # No "user" key — legacy token format + } + + monkeypatch.setattr(tr.settings, "auth_cache_enabled", False) + monkeypatch.setattr(tr.settings, "auth_cache_batch_queries", False) + monkeypatch.setattr(tr.settings, "require_user_in_db", False) + + mock_update = MagicMock() + user_record = SimpleNamespace(is_admin=True, is_active=True) + scope = {"type": "http", "path": "/mcp", "headers": []} + handler = tr._StreamableHttpAuthHandler(scope, AsyncMock(), AsyncMock()) + + with ( + patch("mcpgateway.transports.streamablehttp_transport.verify_credentials", AsyncMock(return_value=jwt_payload)), + patch("mcpgateway.auth._check_token_revoked_sync", return_value=False), + patch("mcpgateway.auth._get_user_by_email_sync", return_value=user_record), + patch("mcpgateway.auth.normalize_token_teams", return_value=None), + patch("mcpgateway.auth._is_api_token_jti_sync", return_value=True), + patch("mcpgateway.auth._update_api_token_last_used_sync", mock_update), + ): + result = await handler._auth_jwt(token="fake-legacy-token") + + assert result is True + mock_update.assert_called_once_with("jti-legacy-token") + + ctx = user_context_var.get() + assert ctx["auth_method"] == "api_token" + assert scope["state"]["auth_method"] == "api_token" + assert scope["state"]["jti"] == "jti-legacy-token" + + +@pytest.mark.asyncio +async def test_auth_jwt_legacy_detection_failure_defaults_to_jwt(monkeypatch): + """_auth_jwt should default to auth_method='jwt' when legacy DB lookup fails.""" + jwt_payload = { + "sub": "user@example.com", + "jti": "jti-unknown-token", + "is_admin": True, + "token_use": "api", + # No "user" key — triggers legacy fallback + } + + monkeypatch.setattr(tr.settings, "auth_cache_enabled", False) + monkeypatch.setattr(tr.settings, "auth_cache_batch_queries", False) + monkeypatch.setattr(tr.settings, "require_user_in_db", False) + + user_record = SimpleNamespace(is_admin=True, is_active=True) + scope = {"type": "http", "path": "/mcp", "headers": []} + handler = tr._StreamableHttpAuthHandler(scope, AsyncMock(), AsyncMock()) + + def raise_error(jti): + raise RuntimeError("DB unavailable") + + with ( + patch("mcpgateway.transports.streamablehttp_transport.verify_credentials", AsyncMock(return_value=jwt_payload)), + patch("mcpgateway.auth._check_token_revoked_sync", return_value=False), + patch("mcpgateway.auth._get_user_by_email_sync", return_value=user_record), + patch("mcpgateway.auth.normalize_token_teams", return_value=None), + patch("mcpgateway.auth._is_api_token_jti_sync", raise_error), + ): + result = await handler._auth_jwt(token="fake-token") + + assert result is True + + # Should default to "jwt" auth_method when legacy detection fails + ctx = user_context_var.get() + assert ctx["auth_method"] == "jwt" + assert scope["state"]["auth_method"] == "jwt" + # jti should NOT be in scope state since it's not detected as API token + assert "jti" not in scope.get("state", {}) + + +@pytest.mark.asyncio +async def test_auth_jwt_non_api_token_no_jti_in_scope_state(monkeypatch): + """_auth_jwt should not set jti in scope state for non-API tokens (e.g. OAuth, SAML).""" + jwt_payload = { + "sub": "oauth@example.com", + "jti": "jti-oauth-token", + "is_admin": False, + "token_use": "session", + "user": {"auth_provider": "oauth"}, + } + + monkeypatch.setattr(tr.settings, "auth_cache_enabled", False) + monkeypatch.setattr(tr.settings, "auth_cache_batch_queries", False) + monkeypatch.setattr(tr.settings, "require_user_in_db", False) + + user_record = SimpleNamespace(is_admin=False, is_active=True) + scope = {"type": "http", "path": "/mcp", "headers": []} + handler = tr._StreamableHttpAuthHandler(scope, AsyncMock(), AsyncMock()) + + with ( + patch("mcpgateway.transports.streamablehttp_transport.verify_credentials", AsyncMock(return_value=jwt_payload)), + patch("mcpgateway.auth._check_token_revoked_sync", return_value=False), + patch("mcpgateway.auth._get_user_by_email_sync", return_value=user_record), + patch("mcpgateway.auth.resolve_session_teams", AsyncMock(return_value=["team-oauth"])), + ): + result = await handler._auth_jwt(token="fake-oauth-token") + + assert result is True + ctx = user_context_var.get() + assert ctx["auth_method"] == "jwt" + assert "jti" not in scope.get("state", {}) + assert scope["state"]["user_email"] == "oauth@example.com" + + def test_maybe_open_initialize_span_returns_none_for_non_initialize(): """Non-initialize JSON-RPC payloads should not create transport spans.""" body = tr.orjson.dumps({"jsonrpc": "2.0", "id": 1, "method": "tools/list", "params": {}})