Skip to content

Commit ad686e4

Browse files
authored
Merge pull request #1278 from jrobertboos/lcore-1420
LCORE-1420: Fixing MCP Authorization
2 parents 8c003d8 + ae48ed0 commit ad686e4

10 files changed

Lines changed: 102 additions & 222 deletions

File tree

src/app/endpoints/query.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
validate_and_retrieve_conversation,
4242
)
4343
from utils.mcp_headers import McpHeaders, mcp_headers_dependency
44+
from utils.mcp_oauth_probe import check_mcp_auth
4445
from utils.query import (
4546
consume_query_tokens,
4647
handle_known_apistatus_errors,
@@ -122,6 +123,8 @@ async def query_endpoint_handler(
122123
"""
123124
check_configuration_loaded(configuration)
124125

126+
await check_mcp_auth(configuration, mcp_headers)
127+
125128
started_at = datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
126129
user_id, _, _skip_userid_check, token = auth
127130
# Check token availability

src/app/endpoints/streaming_query.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
validate_and_retrieve_conversation,
6363
)
6464
from utils.mcp_headers import McpHeaders, mcp_headers_dependency
65+
from utils.mcp_oauth_probe import check_mcp_auth
6566
from utils.query import (
6667
consume_query_tokens,
6768
extract_provider_and_model_from_model_id,
@@ -152,6 +153,8 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
152153
"""
153154
check_configuration_loaded(configuration)
154155

156+
await check_mcp_auth(configuration, mcp_headers)
157+
155158
user_id, _user_name, _skip_userid_check, token = auth
156159
started_at = datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
157160

src/app/endpoints/tools.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
from typing import Annotated, Any
44

55
from fastapi import APIRouter, Depends, HTTPException, Request
6-
from llama_stack_client import APIConnectionError, BadRequestError, AuthenticationError
7-
from llama_stack.core.datatypes import AuthenticationRequiredError
6+
from llama_stack_client import APIConnectionError, BadRequestError
87

98
from authentication import get_auth_dependency
109
from authentication.interface import AuthTuple
@@ -21,7 +20,7 @@
2120
)
2221
from utils.endpoints import check_configuration_loaded
2322
from utils.mcp_headers import McpHeaders, mcp_headers_dependency
24-
from utils.mcp_oauth_probe import probe_mcp_oauth_and_raise_401
23+
from utils.mcp_oauth_probe import check_mcp_auth
2524
from utils.tool_formatter import format_tools_list
2625
from log import get_logger
2726

@@ -124,6 +123,8 @@ async def tools_endpoint_handler( # pylint: disable=too-many-locals,too-many-st
124123

125124
check_configuration_loaded(configuration)
126125

126+
await check_mcp_auth(configuration, mcp_headers)
127+
127128
toolgroups_response = []
128129
try:
129130
client = AsyncLlamaStackClientHolder().get_client()
@@ -146,6 +147,7 @@ async def tools_endpoint_handler( # pylint: disable=too-many-locals,too-many-st
146147
# Get tools for each toolgroup
147148
headers = mcp_headers.get(toolgroup.identifier, {})
148149
authorization = headers.pop("Authorization", None)
150+
149151
tools_response = await client.tools.list(
150152
toolgroup_id=toolgroup.identifier,
151153
extra_headers=headers,
@@ -154,13 +156,6 @@ async def tools_endpoint_handler( # pylint: disable=too-many-locals,too-many-st
154156
except BadRequestError:
155157
logger.error("Toolgroup %s is not found", toolgroup.identifier)
156158
continue
157-
except (AuthenticationError, AuthenticationRequiredError) as e:
158-
if toolgroup.mcp_endpoint:
159-
await probe_mcp_oauth_and_raise_401(
160-
toolgroup.mcp_endpoint.uri, chain_from=e
161-
)
162-
error_response = UnauthorizedResponse(cause=str(e))
163-
raise HTTPException(**error_response.model_dump()) from e
164159
except APIConnectionError as e:
165160
logger.error("Unable to connect to Llama Stack: %s", e)
166161
response = ServiceUnavailableResponse(

src/utils/mcp_oauth_probe.py

Lines changed: 64 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,101 @@
1-
"""Probe MCP server for OAuth and raise 401 with WWW-Authenticate when required."""
1+
"""Probe MCP servers for OAuth and raise 401 with WWW-Authenticate when required.
22
3+
Used by endpoints that call MCP-backed services so clients receive a proper
4+
401 with WWW-Authenticate when an MCP server requires OAuth.
5+
"""
6+
7+
import asyncio
38
from typing import Optional
9+
410
import aiohttp
511
from fastapi import HTTPException
612

713
from models.responses import UnauthorizedResponse
814

15+
from configuration import AppConfig
16+
from utils.mcp_headers import McpHeaders
17+
import constants
18+
919
from log import get_logger
1020

1121
logger = get_logger(__name__)
1222

1323

14-
async def probe_mcp_oauth_and_raise_401(
24+
async def check_mcp_auth(configuration: AppConfig, mcp_headers: McpHeaders) -> None:
25+
"""Probe each configured MCP server that expects OAuth or has auth headers.
26+
27+
For every MCP server that has an Authorization header in mcp_headers or
28+
has OAuth in its resolved_authorization_headers, performs a probe request.
29+
If the server indicates OAuth is required, raises 401 with
30+
WWW-Authenticate (or 401 without that header when the server omits it or the probe fails).
31+
32+
Parameters:
33+
configuration: Application config containing mcp_servers.
34+
mcp_headers: Per-server headers; keys are MCP server names.
35+
36+
Returns:
37+
None when no server needs probing or no probed server responds with 401.
38+
39+
Raises:
40+
HTTPException: 401 when an MCP server requires OAuth or its probe fails (raised by probe_mcp).
41+
"""
42+
probes = []
43+
for mcp_server in configuration.mcp_servers:
44+
headers = mcp_headers.get(mcp_server.name, {})
45+
authorization = headers.get("Authorization", None)
46+
if (
47+
authorization
48+
or constants.MCP_AUTH_OAUTH
49+
in mcp_server.resolved_authorization_headers.values()
50+
):
51+
probes.append(probe_mcp(mcp_server.url, authorization=authorization))
52+
if probes:
53+
await asyncio.gather(*probes)
54+
55+
56+
async def probe_mcp(
1557
url: str,
16-
chain_from: Optional[BaseException] = None,
58+
authorization: Optional[str] = None,
1759
) -> None:
1860
"""Probe MCP endpoint and raise 401 so the client can perform OAuth.
1961
20-
Performs an async GET to the given URL to obtain a WWW-Authenticate header,
21-
then raises HTTPException with status 401 and that header. If the probe
22-
fails (connection error, timeout), raises 401 without the header.
62+
Performs an async GET to the given URL. If the response is 401 with
63+
WWW-Authenticate, raises HTTPException with that header. If the response
64+
is 401 without the header, or the probe fails (connection error, timeout),
65+
raises 401 without WWW-Authenticate.
2366
24-
Args:
67+
Parameters:
2568
url: MCP server URL to probe.
26-
chain_from: Exception to chain the HTTPException from when
27-
the probe succeeds (e.g. the original AuthenticationError).
69+
authorization: Optional Authorization header value for the probe request.
2870
2971
Returns:
30-
None. Always raises an HTTPException.
72+
None when the server responds with a status other than 401 (OAuth not
73+
required). Otherwise does not return; raises HTTPException.
3174
3275
Raises:
33-
HTTPException: 401 with WWW-Authenticate when the probe succeeds, or
34-
401 without the header when the probe fails.
76+
HTTPException: 401 with WWW-Authenticate when the server returns 401
77+
and includes that header; 401 without the header when the server
78+
returns 401 without it or when the probe fails (timeout/connection).
3579
"""
3680
cause = f"MCP server at {url} requires OAuth"
3781
error_response = UnauthorizedResponse(cause=cause)
82+
headers: Optional[dict[str, str]] = (
83+
{"authorization": authorization} if authorization is not None else None
84+
)
3885
try:
3986
timeout = aiohttp.ClientTimeout(total=10)
4087
async with aiohttp.ClientSession(timeout=timeout) as session:
41-
async with session.get(url) as resp:
88+
async with session.get(url, headers=headers) as resp:
89+
if resp.status != 401:
90+
return
4291
www_auth = resp.headers.get("WWW-Authenticate")
4392
if www_auth is None:
4493
logger.warning("No WWW-Authenticate header received from %s", url)
45-
raise HTTPException(**error_response.model_dump()) from chain_from
94+
raise HTTPException(**error_response.model_dump())
4695
raise HTTPException(
4796
**error_response.model_dump(),
4897
headers={"WWW-Authenticate": www_auth},
49-
) from chain_from
98+
)
5099
except (aiohttp.ClientError, TimeoutError) as probe_err:
51100
logger.warning("OAuth probe failed for %s: %s", url, probe_err)
52101
raise HTTPException(**error_response.model_dump()) from probe_err

src/utils/responses.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
ServiceUnavailableResponse,
4545
)
4646
from utils.mcp_headers import McpHeaders, extract_propagated_headers
47-
from utils.mcp_oauth_probe import probe_mcp_oauth_and_raise_401
4847
from utils.prompts import get_system_prompt, get_topic_summary_system_prompt
4948
from utils.query import (
5049
extract_provider_and_model_from_model_id,
@@ -464,16 +463,6 @@ def _get_token_value(original: str, header: str) -> Optional[str]:
464463
if mcp_server.authorization_headers and len(headers) != len(
465464
mcp_server.authorization_headers
466465
):
467-
# If OAuth was required and no headers passed, probe endpoint and forward
468-
# 401 with WWW-Authenticate so the client can perform OAuth
469-
uses_oauth = (
470-
constants.MCP_AUTH_OAUTH
471-
in mcp_server.resolved_authorization_headers.values()
472-
)
473-
if uses_oauth and (
474-
mcp_headers is None or not mcp_headers.get(mcp_server.name)
475-
):
476-
await probe_mcp_oauth_and_raise_401(mcp_server.url)
477466
logger.warning(
478467
"Skipping MCP server %s: required %d auth headers but only resolved %d",
479468
mcp_server.name,

tests/e2e/features/mcp.feature

Lines changed: 5 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ Feature: MCP tests
5656
"""
5757
And The headers of the response contains the following header "www-authenticate"
5858

59-
@skip # will be fixed in LCORE-1368
6059
Scenario: Check if tools endpoint succeeds when MCP auth token is passed
6160
Given The system is in default state
6261
And I set the "MCP-HEADERS" header to
@@ -65,42 +64,9 @@ Feature: MCP tests
6564
"""
6665
When I access REST API endpoint "tools" using HTTP GET method
6766
Then The status code of the response is 200
68-
And The body of the response is the following
69-
"""
70-
{
71-
"tools": [
72-
{
73-
"identifier": "",
74-
"description": "Insert documents into memory",
75-
"parameters": [],
76-
"provider_id": "",
77-
"toolgroup_id": "builtin::rag",
78-
"server_source": "builtin",
79-
"type": ""
80-
},
81-
{
82-
"identifier": "",
83-
"description": "Search for information in a database.",
84-
"parameters": [],
85-
"provider_id": "",
86-
"toolgroup_id": "builtin::rag",
87-
"server_source": "builtin",
88-
"type": ""
89-
},
90-
{
91-
"identifier": "",
92-
"description": "Mock tool for E2E",
93-
"parameters": [],
94-
"provider_id": "",
95-
"toolgroup_id": "mcp-oauth",
96-
"server_source": "http://localhost:3001",
97-
"type": ""
98-
}
99-
]
100-
}
101-
"""
67+
And The body of the response contains mcp-oauth
10268

103-
@skip # will be fixed in LCORE-1366
69+
@skip-in-library-mode # will be fixed in LCORE-1428
10470
Scenario: Check if query endpoint succeeds when MCP auth token is passed
10571
Given The system is in default state
10672
And I set the "MCP-HEADERS" header to
@@ -115,10 +81,10 @@ Feature: MCP tests
11581
Then The status code of the response is 200
11682
And The response should contain following fragments
11783
| Fragments in LLM response |
118-
| hello |
84+
| Hello |
11985
And The token metrics should have increased
12086

121-
@skip # will be fixed in LCORE-1366
87+
@skip-in-library-mode # will be fixed in LCORE-1428
12288
Scenario: Check if streaming_query endpoint succeeds when MCP auth token is passed
12389
Given The system is in default state
12490
And I set the "MCP-HEADERS" header to
@@ -134,10 +100,9 @@ Feature: MCP tests
134100
Then The status code of the response is 200
135101
And The streamed response should contain following fragments
136102
| Fragments in LLM response |
137-
| hello |
103+
| Hello |
138104
And The token metrics should have increased
139105

140-
@skip # will be fixed in LCORE-1368
141106
Scenario: Check if tools endpoint reports error when MCP invalid auth token is passed
142107
Given The system is in default state
143108
And I set the "MCP-HEADERS" header to
@@ -180,7 +145,6 @@ Feature: MCP tests
180145
"""
181146
And The headers of the response contains the following header "www-authenticate"
182147

183-
@skip # will be fixed in LCORE-1366
184148
Scenario: Check if streaming_query endpoint reports error when MCP invalid auth token is passed
185149
Given The system is in default state
186150
And I set the "MCP-HEADERS" header to

tests/e2e/mock_mcp_server/server.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ def do_GET(self) -> None: # pylint: disable=invalid-name
4747
"""Handle GET requests."""
4848
if self.path == "/health":
4949
self._json_response({"status": "ok"})
50+
elif self._parse_auth() is not None:
51+
self._json_response({"status": "authorized"})
5052
else:
5153
self._require_oauth()
5254

0 commit comments

Comments
 (0)