-
Notifications
You must be signed in to change notification settings - Fork 94
LCORE-1420: Fixing MCP Authorization #1278
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,8 +3,7 @@ | |
| from typing import Annotated, Any | ||
|
|
||
| from fastapi import APIRouter, Depends, HTTPException, Request | ||
| from llama_stack_client import APIConnectionError, BadRequestError, AuthenticationError | ||
| from llama_stack.core.datatypes import AuthenticationRequiredError | ||
| from llama_stack_client import APIConnectionError, BadRequestError | ||
|
|
||
| from authentication import get_auth_dependency | ||
| from authentication.interface import AuthTuple | ||
|
|
@@ -21,7 +20,7 @@ | |
| ) | ||
| from utils.endpoints import check_configuration_loaded | ||
| from utils.mcp_headers import McpHeaders, mcp_headers_dependency | ||
| from utils.mcp_oauth_probe import probe_mcp_oauth_and_raise_401 | ||
| from utils.mcp_oauth_probe import check_mcp_auth | ||
| from utils.tool_formatter import format_tools_list | ||
| from log import get_logger | ||
|
|
||
|
|
@@ -124,6 +123,8 @@ async def tools_endpoint_handler( # pylint: disable=too-many-locals,too-many-st | |
|
|
||
| check_configuration_loaded(configuration) | ||
|
|
||
| await check_mcp_auth(configuration, mcp_headers) | ||
|
|
||
| toolgroups_response = [] | ||
| try: | ||
| client = AsyncLlamaStackClientHolder().get_client() | ||
|
|
@@ -146,6 +147,7 @@ async def tools_endpoint_handler( # pylint: disable=too-many-locals,too-many-st | |
| # Get tools for each toolgroup | ||
| headers = mcp_headers.get(toolgroup.identifier, {}) | ||
| authorization = headers.pop("Authorization", None) | ||
|
|
||
|
Comment on lines
148
to
+150
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Avoid mutating request-scoped MCP headers in-place. Line 149 uses ♻️ Proposed fix- headers = mcp_headers.get(toolgroup.identifier, {})
- authorization = headers.pop("Authorization", None)
+ original_headers = mcp_headers.get(toolgroup.identifier, {})
+ authorization = next(
+ (v for k, v in original_headers.items() if k.lower() == "authorization"),
+ None,
+ )
+ headers = {
+ k: v for k, v in original_headers.items() if k.lower() != "authorization"
+ }As per coding guidelines "Avoid in-place parameter modification anti-patterns; return new data structures instead." 🤖 Prompt for AI Agents |
||
| tools_response = await client.tools.list( | ||
| toolgroup_id=toolgroup.identifier, | ||
| extra_headers=headers, | ||
|
|
@@ -154,13 +156,6 @@ async def tools_endpoint_handler( # pylint: disable=too-many-locals,too-many-st | |
| except BadRequestError: | ||
| logger.error("Toolgroup %s is not found", toolgroup.identifier) | ||
| continue | ||
| except (AuthenticationError, AuthenticationRequiredError) as e: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here we handled auth from llamastack itself as well, not just from MCP Oath. Now the check_mcp_auth only covers the MCP Oauth scenario. So if client.tools.list() now gives AuthenticationError for any problem that is not MCP-related, it now puts forward a 500 error instead of 401. A simple fix could be: re-add catch for AuthenticationError (to return 401 again) -- the deleted lines 157, 162, and 163:
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The only reason I got rid of the llama stack authentication error catching in
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually I might be able to add an exception handler in the
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
But should I add this if its basically an impossible to reach branch?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed that currently only MCP servers trigger auth in llamastack. OTOH, it might be worth it as defense-in-depth -- if llamastack adds auth for non-MCP reason we would get 500 errors again. It's a low risk though, so up to you 🤷
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I ended up not implementing this error catching mainly to reduce size of this PR as well as not add any (as of right now) unneeded complexity. If this becomes an issue we will address it at that time :). However with the toolgroup depreciation large changes will likely begin to emerge that will likely make any work done with this exception catching irrelevant so for the time being I don't think its necessary. |
||
| if toolgroup.mcp_endpoint: | ||
| await probe_mcp_oauth_and_raise_401( | ||
| toolgroup.mcp_endpoint.uri, chain_from=e | ||
| ) | ||
| error_response = UnauthorizedResponse(cause=str(e)) | ||
| raise HTTPException(**error_response.model_dump()) from e | ||
| except APIConnectionError as e: | ||
| logger.error("Unable to connect to Llama Stack: %s", e) | ||
| response = ServiceUnavailableResponse( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,52 +1,101 @@ | ||
| """Probe MCP server for OAuth and raise 401 with WWW-Authenticate when required.""" | ||
| """Probe MCP servers for OAuth and raise 401 with WWW-Authenticate when required. | ||
|
|
||
| Used by endpoints that call MCP-backed services so clients receive a proper | ||
| 401 with WWW-Authenticate when an MCP server requires OAuth. | ||
| """ | ||
|
|
||
| import asyncio | ||
| from typing import Optional | ||
|
|
||
| import aiohttp | ||
| from fastapi import HTTPException | ||
|
|
||
| from models.responses import UnauthorizedResponse | ||
|
|
||
| from configuration import AppConfig | ||
| from utils.mcp_headers import McpHeaders | ||
| import constants | ||
|
|
||
| from log import get_logger | ||
|
|
||
| logger = get_logger(__name__) | ||
|
|
||
|
|
||
| async def probe_mcp_oauth_and_raise_401( | ||
| async def check_mcp_auth(configuration: AppConfig, mcp_headers: McpHeaders) -> None: | ||
| """Probe each configured MCP server that expects OAuth or has auth headers. | ||
|
|
||
| For every MCP server that has an Authorization header in mcp_headers or | ||
| has OAuth in its resolved_authorization_headers, performs a probe request. | ||
| If the server indicates OAuth is required, raises 401 with | ||
| WWW-Authenticate (or 401 without header on probe failure). | ||
|
|
||
| Parameters: | ||
| configuration: Application config containing mcp_servers. | ||
| mcp_headers: Per-server headers; keys are MCP server names. | ||
|
|
||
| Returns: | ||
| None when no server requires OAuth or probe does not trigger 401. | ||
|
|
||
| Raises: | ||
| HTTPException: 401 when an MCP server requires OAuth (from probe_mcp). | ||
| """ | ||
| probes = [] | ||
| for mcp_server in configuration.mcp_servers: | ||
|
jrobertboos marked this conversation as resolved.
|
||
| headers = mcp_headers.get(mcp_server.name, {}) | ||
| authorization = headers.get("Authorization", None) | ||
|
Comment on lines
+44
to
+45
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Handle Line 43 only checks ♻️ Proposed fix- authorization = headers.get("Authorization", None)
+ authorization = next(
+ (value for key, value in headers.items() if key.lower() == "authorization"),
+ None,
+ )🤖 Prompt for AI Agents |
||
| if ( | ||
| authorization | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. will pass with empty header.... maybe |
||
| or constants.MCP_AUTH_OAUTH | ||
| in mcp_server.resolved_authorization_headers.values() | ||
| ): | ||
| probes.append(probe_mcp(mcp_server.url, authorization=authorization)) | ||
| if probes: | ||
| await asyncio.gather(*probes) | ||
|
|
||
|
|
||
| async def probe_mcp( | ||
| url: str, | ||
| chain_from: Optional[BaseException] = None, | ||
| authorization: Optional[str] = None, | ||
| ) -> None: | ||
| """Probe MCP endpoint and raise 401 so the client can perform OAuth. | ||
|
|
||
| Performs an async GET to the given URL to obtain a WWW-Authenticate header, | ||
| then raises HTTPException with status 401 and that header. If the probe | ||
| fails (connection error, timeout), raises 401 without the header. | ||
| Performs an async GET to the given URL. If the response is 401 with | ||
| WWW-Authenticate, raises HTTPException with that header. If the response | ||
| is 401 without the header, or the probe fails (connection error, timeout), | ||
| raises 401 without WWW-Authenticate. | ||
|
|
||
| Args: | ||
| Parameters: | ||
| url: MCP server URL to probe. | ||
| chain_from: Exception to chain the HTTPException from when | ||
| the probe succeeds (e.g. the original AuthenticationError). | ||
| authorization: Optional Authorization header value for the probe request. | ||
|
|
||
| Returns: | ||
| None. Always raises an HTTPException. | ||
| None when the server responds with a status other than 401 (OAuth not | ||
| required). Otherwise does not return; raises HTTPException. | ||
|
|
||
| Raises: | ||
| HTTPException: 401 with WWW-Authenticate when the probe succeeds, or | ||
| 401 without the header when the probe fails. | ||
| HTTPException: 401 with WWW-Authenticate when the server returns 401 | ||
| and includes that header; 401 without the header when the server | ||
| returns 401 without it or when the probe fails (timeout/connection). | ||
| """ | ||
| cause = f"MCP server at {url} requires OAuth" | ||
| error_response = UnauthorizedResponse(cause=cause) | ||
| headers: Optional[dict[str, str]] = ( | ||
| {"authorization": authorization} if authorization is not None else None | ||
| ) | ||
| try: | ||
| timeout = aiohttp.ClientTimeout(total=10) | ||
| async with aiohttp.ClientSession(timeout=timeout) as session: | ||
| async with session.get(url) as resp: | ||
| async with session.get(url, headers=headers) as resp: | ||
| if resp.status != 401: | ||
| return | ||
| www_auth = resp.headers.get("WWW-Authenticate") | ||
| if www_auth is None: | ||
| logger.warning("No WWW-Authenticate header received from %s", url) | ||
| raise HTTPException(**error_response.model_dump()) from chain_from | ||
| raise HTTPException(**error_response.model_dump()) | ||
| raise HTTPException( | ||
| **error_response.model_dump(), | ||
| headers={"WWW-Authenticate": www_auth}, | ||
| ) from chain_from | ||
| ) | ||
| except (aiohttp.ClientError, TimeoutError) as probe_err: | ||
| logger.warning("OAuth probe failed for %s: %s", url, probe_err) | ||
| raise HTTPException(**error_response.model_dump()) from probe_err | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
Repository: lightspeed-core/lightspeed-stack
Length of output: 932
🏁 Script executed:
Repository: lightspeed-core/lightspeed-stack
Length of output: 3859
🏁 Script executed:
Repository: lightspeed-core/lightspeed-stack
Length of output: 3403
🏁 Script executed:
Repository: lightspeed-core/lightspeed-stack
Length of output: 8827
🏁 Script executed:
Repository: lightspeed-core/lightspeed-stack
Length of output: 1031
🏁 Script executed:
Repository: lightspeed-core/lightspeed-stack
Length of output: 1623
Add CancelledError handling to non-streaming query path to prevent 500 errors during request cancellation.
The non-streaming query endpoint lacks explicit cancellation handling for
get_topic_summary()(line 222) andresponses.create()(withinretrieve_response()). If the client cancels the request during either call, the unhandledCancelledErrorpropagates to the ASGI layer, resulting in a 500 error. The streaming endpoint handles this correctly with an explicitexcept asyncio.CancelledError:handler (seestreaming_query.py:490), but the non-streaming path needs equivalent protection to prevent 500 responses when authenticated MCP requests are interrupted.🤖 Prompt for AI Agents