Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 41 additions & 2 deletions mcpgateway/transports/streamablehttp_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -3642,12 +3642,51 @@ async def _auth_jwt(self, *, token: str) -> bool:
else:
_record_mcp_auth_cache_event("team_membership_cache_hit")

# Detect API token auth and update last_used timestamp.
# API tokens carry user.auth_provider == "api_token" in their JWT payload.
# Legacy tokens (created before auth_provider was added) are detected via DB lookup.
_nested_user = user_payload.get("user", {})
_auth_provider = _nested_user.get("auth_provider") if isinstance(_nested_user, dict) else None
_is_api_token = _auth_provider == "api_token"

if not _is_api_token and jti and _auth_provider is None:
# Legacy API token fallback: only when auth_provider is absent (pre-auth_provider tokens).
# Tokens with an explicit auth_provider (email, oauth, saml, etc.) are never legacy API tokens.
try:
# First-Party
from mcpgateway.auth import _is_api_token_jti_sync # pylint: disable=import-outside-toplevel

_is_api_token = await asyncio.to_thread(_is_api_token_jti_sync, jti)
except Exception:
pass # Best-effort detection; default to "jwt" auth_method

resolved_auth_method = "api_token" if _is_api_token else "jwt"

# Update last_used timestamp for API tokens (rate-limited internally)
if _is_api_token and jti:
try:
# First-Party
from mcpgateway.auth import _update_api_token_last_used_sync # pylint: disable=import-outside-toplevel

await asyncio.to_thread(_update_api_token_last_used_sync, jti)
except Exception:
logger.debug("Failed to update API token last_used in MCP auth for jti=...%s", jti[-8:] if jti else "")

# Propagate auth_method and jti into ASGI scope state so that
# TokenUsageMiddleware can recognise API token requests and log usage.
state = self.scope.setdefault("state", {})
state["auth_method"] = resolved_auth_method
if _is_api_token and jti:
state["jti"] = jti
if user_email:
state["user_email"] = user_email

auth_user_ctx: dict[str, Any] = {
"email": user_email,
"teams": final_teams,
"is_authenticated": True,
"is_admin": is_admin,
"auth_method": "jwt",
"auth_method": resolved_auth_method,
"permission_is_admin": db_user_is_admin or is_admin,
"token_use": token_use, # propagated for downstream RBAC (check_any_team)
}
Expand All @@ -3667,7 +3706,7 @@ async def _auth_jwt(self, *, token: str) -> bool:
final_teams,
user_email=user_email,
is_admin=bool(db_user_is_admin or is_admin),
auth_method="jwt",
auth_method=resolved_auth_method,
team_name=trace_team_name,
)
except HTTPException:
Expand Down
102 changes: 102 additions & 0 deletions tests/playwright/security/test_mcp_transport_auth_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

# Standard
from contextlib import suppress
import time
from urllib.parse import urlparse
import uuid

Expand Down Expand Up @@ -170,3 +171,104 @@ def test_websocket_auth_handshake_behavior(self):

assert isinstance(response, str)
assert "Parse error" in response or "jsonrpc" in response


@pytest.fixture
def api_token_info(admin_api: APIRequestContext) -> dict:
"""Create an API token and return its access_token and token_id, then clean up."""
resp = admin_api.post("/tokens", data={"name": f"last-used-test-{uuid.uuid4().hex[:8]}", "expires_in_days": 1})
if resp.status == 404:
pytest.skip("/tokens endpoint unavailable")
assert resp.status in (200, 201), f"Failed to create token: {resp.status} {resp.text()}"
payload = resp.json()
token_obj = payload.get("token", payload)
info = {
"access_token": payload["access_token"],
"token_id": token_obj.get("id") or token_obj.get("token_id"),
}
yield info
with suppress(Exception):
admin_api.delete(f"/tokens/{info['token_id']}")


class TestApiTokenLastUsedViaMCP:
"""Verify API token last_used is updated when accessing virtual servers via MCP Streamable HTTP."""

def test_mcp_streamable_http_updates_last_used(self, admin_api: APIRequestContext, playwright: Playwright, public_server_id: str, api_token_info: dict):
"""Accessing /servers/{id}/mcp with an API token should update last_used."""
token_id = api_token_info["token_id"]

# 1. Check initial last_used (should be None for new token)
detail = admin_api.get(f"/tokens/{token_id}")
if detail.status == 404:
pytest.skip("Token detail endpoint unavailable")
initial_last_used = detail.json().get("last_used")

# 2. Make MCP Streamable HTTP request with the API token
token_ctx = _api_context(playwright, api_token_info["access_token"])
try:
mcp_resp = token_ctx.post(
f"/servers/{public_server_id}/mcp",
data={"jsonrpc": "2.0", "id": "1", "method": "initialize", "params": {"protocolVersion": "2025-03-26", "capabilities": {}, "clientInfo": {"name": "e2e-test", "version": "1.0.0"}}},
headers={"Content-Type": "application/json", "Accept": "application/json, text/event-stream"},
)
finally:
token_ctx.dispose()

if mcp_resp.status == 404:
pytest.skip("Streamable HTTP endpoint unavailable")
assert mcp_resp.status != 401, f"API token auth rejected: {mcp_resp.text()}"

# 3. Verify last_used was updated
time.sleep(1) # Allow propagation across multi-gateway setup
detail2 = admin_api.get(f"/tokens/{token_id}")
updated_last_used = detail2.json().get("last_used")

assert updated_last_used is not None, f"last_used not updated after MCP access. Initial: {initial_last_used}, After: {updated_last_used}"

def test_mcp_requests_accumulate_in_token_usage_stats(self, admin_api: APIRequestContext, playwright: Playwright, public_server_id: str, api_token_info: dict):
"""Multiple MCP requests with an API token should be accurately reflected in usage statistics.

Uses a fresh per-test token (api_token_info fixture) so stats are isolated β€” no other
requests can contribute to this token's usage counters.
"""
token_id = api_token_info["token_id"]
num_requests = 5

# 1. Verify fresh token has zero usage
usage_resp = admin_api.get(f"/tokens/{token_id}/usage")
if usage_resp.status == 404:
pytest.skip("Token usage endpoint unavailable")
baseline = usage_resp.json()
assert baseline.get("total_requests", 0) == 0, "Fresh token should have zero usage"

# 2. Make exactly N MCP requests
token_ctx = _api_context(playwright, api_token_info["access_token"])
try:
for i in range(num_requests):
token_ctx.post(
f"/servers/{public_server_id}/mcp",
data={"jsonrpc": "2.0", "id": str(i + 1), "method": "ping", "params": {}},
headers={"Content-Type": "application/json", "Accept": "application/json, text/event-stream"},
)
finally:
token_ctx.dispose()

# 3. Allow async logging to complete
time.sleep(1)

# 4. Verify usage stats with strict equality (isolated per-test token)
usage_resp2 = admin_api.get(f"/tokens/{token_id}/usage")
assert usage_resp2.status == 200
stats = usage_resp2.json()

assert stats["total_requests"] == num_requests, f"Expected exactly {num_requests} total requests, got {stats['total_requests']}"
assert stats["successful_requests"] == num_requests, f"Expected exactly {num_requests} successful requests, got {stats['successful_requests']}"
assert stats["blocked_requests"] == 0, f"Expected 0 blocked requests, got {stats['blocked_requests']}"
assert stats["success_rate"] == 1.0, f"Expected 100% success rate, got {stats['success_rate']}"
assert stats["average_response_time_ms"] > 0, f"Expected positive average response time, got {stats['average_response_time_ms']}ms"

# Top endpoints should include the MCP server path
endpoint_paths = [ep[0] if isinstance(ep, (list, tuple)) else ep for ep in stats.get("top_endpoints", [])]
has_mcp_endpoint = any("/mcp" in path for path in endpoint_paths)
assert has_mcp_endpoint, f"Expected /mcp in top_endpoints, got {endpoint_paths}"
Loading