|
1 | | -"""Probe MCP server for OAuth and raise 401 with WWW-Authenticate when required.""" |
| 1 | +"""Probe MCP servers for OAuth and raise 401 with WWW-Authenticate when required. |
2 | 2 |
|
| 3 | +Used by endpoints that call MCP-backed services so clients receive a proper |
| 4 | +401 with WWW-Authenticate when an MCP server requires OAuth. |
| 5 | +""" |
| 6 | + |
| 7 | +import asyncio |
3 | 8 | from typing import Optional |
| 9 | + |
4 | 10 | import aiohttp |
5 | 11 | from fastapi import HTTPException |
6 | 12 |
|
7 | 13 | from models.responses import UnauthorizedResponse |
8 | 14 |
|
| 15 | +from configuration import AppConfig |
| 16 | +from utils.mcp_headers import McpHeaders |
| 17 | +import constants |
| 18 | + |
9 | 19 | from log import get_logger |
10 | 20 |
|
11 | 21 | logger = get_logger(__name__) |
12 | 22 |
|
13 | 23 |
|
14 | | -async def probe_mcp_oauth_and_raise_401( |
| 24 | +async def check_mcp_auth(configuration: AppConfig, mcp_headers: McpHeaders) -> None: |
| 25 | + """Probe each configured MCP server that expects OAuth or has auth headers. |
| 26 | +
|
| 27 | + For every MCP server that has an Authorization header in mcp_headers or |
| 28 | + has OAuth in its resolved_authorization_headers, performs a probe request. |
| 29 | + If the server indicates OAuth is required, raises 401 with |
| 30 | + WWW-Authenticate (or 401 without header on probe failure). |
| 31 | +
|
| 32 | + Parameters: |
| 33 | + configuration: Application config containing mcp_servers. |
| 34 | + mcp_headers: Per-server headers; keys are MCP server names. |
| 35 | +
|
| 36 | + Returns: |
| 37 | + None when no server requires OAuth or probe does not trigger 401. |
| 38 | +
|
| 39 | + Raises: |
| 40 | + HTTPException: 401 when an MCP server requires OAuth (from probe_mcp). |
| 41 | + """ |
| 42 | + probes = [] |
| 43 | + for mcp_server in configuration.mcp_servers: |
| 44 | + headers = mcp_headers.get(mcp_server.name, {}) |
| 45 | + authorization = headers.get("Authorization", None) |
| 46 | + if ( |
| 47 | + authorization |
| 48 | + or constants.MCP_AUTH_OAUTH |
| 49 | + in mcp_server.resolved_authorization_headers.values() |
| 50 | + ): |
| 51 | + probes.append(probe_mcp(mcp_server.url, authorization=authorization)) |
| 52 | + if probes: |
| 53 | + await asyncio.gather(*probes) |
| 54 | + |
| 55 | + |
| 56 | +async def probe_mcp( |
15 | 57 | url: str, |
16 | | - chain_from: Optional[BaseException] = None, |
| 58 | + authorization: Optional[str] = None, |
17 | 59 | ) -> None: |
18 | 60 | """Probe MCP endpoint and raise 401 so the client can perform OAuth. |
19 | 61 |
|
20 | | - Performs an async GET to the given URL to obtain a WWW-Authenticate header, |
21 | | - then raises HTTPException with status 401 and that header. If the probe |
22 | | - fails (connection error, timeout), raises 401 without the header. |
| 62 | + Performs an async GET to the given URL. If the response is 401 with |
| 63 | + WWW-Authenticate, raises HTTPException with that header. If the response |
| 64 | + is 401 without the header, or the probe fails (connection error, timeout), |
| 65 | + raises 401 without WWW-Authenticate. |
23 | 66 |
|
24 | | - Args: |
| 67 | + Parameters: |
25 | 68 | url: MCP server URL to probe. |
26 | | - chain_from: Exception to chain the HTTPException from when |
27 | | - the probe succeeds (e.g. the original AuthenticationError). |
| 69 | + authorization: Optional Authorization header value for the probe request. |
28 | 70 |
|
29 | 71 | Returns: |
30 | | - None. Always raises an HTTPException. |
| 72 | + None when the server responds with a status other than 401 (OAuth not |
| 73 | + required). Otherwise does not return; raises HTTPException. |
31 | 74 |
|
32 | 75 | Raises: |
33 | | - HTTPException: 401 with WWW-Authenticate when the probe succeeds, or |
34 | | - 401 without the header when the probe fails. |
| 76 | + HTTPException: 401 with WWW-Authenticate when the server returns 401 |
| 77 | + and includes that header; 401 without the header when the server |
| 78 | + returns 401 without it or when the probe fails (timeout/connection). |
35 | 79 | """ |
36 | 80 | cause = f"MCP server at {url} requires OAuth" |
37 | 81 | error_response = UnauthorizedResponse(cause=cause) |
| 82 | + headers: Optional[dict[str, str]] = ( |
| 83 | + {"authorization": authorization} if authorization is not None else None |
| 84 | + ) |
38 | 85 | try: |
39 | 86 | timeout = aiohttp.ClientTimeout(total=10) |
40 | 87 | async with aiohttp.ClientSession(timeout=timeout) as session: |
41 | | - async with session.get(url) as resp: |
| 88 | + async with session.get(url, headers=headers) as resp: |
| 89 | + if resp.status != 401: |
| 90 | + return |
42 | 91 | www_auth = resp.headers.get("WWW-Authenticate") |
43 | 92 | if www_auth is None: |
44 | 93 | logger.warning("No WWW-Authenticate header received from %s", url) |
45 | | - raise HTTPException(**error_response.model_dump()) from chain_from |
| 94 | + raise HTTPException(**error_response.model_dump()) |
46 | 95 | raise HTTPException( |
47 | 96 | **error_response.model_dump(), |
48 | 97 | headers={"WWW-Authenticate": www_auth}, |
49 | | - ) from chain_from |
| 98 | + ) |
50 | 99 | except (aiohttp.ClientError, TimeoutError) as probe_err: |
51 | 100 | logger.warning("OAuth probe failed for %s: %s", url, probe_err) |
52 | 101 | raise HTTPException(**error_response.model_dump()) from probe_err |
0 commit comments