Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/app/endpoints/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
validate_and_retrieve_conversation,
)
from utils.mcp_headers import McpHeaders, mcp_headers_dependency
from utils.mcp_oauth_probe import check_mcp_auth
from utils.query import (
consume_query_tokens,
handle_known_apistatus_errors,
Expand Down Expand Up @@ -122,6 +123,8 @@ async def query_endpoint_handler(
"""
check_configuration_loaded(configuration)

await check_mcp_auth(configuration, mcp_headers)

Comment on lines +126 to +127

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Inspect cancellation handling vs. response/topic-summary call sites.
rg -n "responses\.create|get_topic_summary|CancelledError" src/app/endpoints/query.py src/app/endpoints/streaming_query.py

Repository: lightspeed-core/lightspeed-stack

Length of output: 932


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Read query.py around the critical lines
echo "=== query.py around line 126-127 (auth check) ==="
sed -n '120,135p' src/app/endpoints/query.py

echo ""
echo "=== query.py around line 222 (get_topic_summary call) ==="
sed -n '200,240p' src/app/endpoints/query.py

echo ""
echo "=== query.py around line 318 (responses.create call) ==="
sed -n '300,330p' src/app/endpoints/query.py

Repository: lightspeed-core/lightspeed-stack

Length of output: 3859


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Get the function definition and overall structure
echo "=== query.py function signature and structure (lines 1-150) ==="
sed -n '80,150p' src/app/endpoints/query.py

echo ""
echo "=== Check for outer try-except blocks in query handler ==="
rg -n "try:|except|CancelledError" src/app/endpoints/query.py | head -20

Repository: lightspeed-core/lightspeed-stack

Length of output: 3403


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# See lines 200-350 to understand exception handling scope
echo "=== query.py lines 200-350 (full context) ==="
sed -n '200,350p' src/app/endpoints/query.py | cat -n

echo ""
echo "=== streaming_query.py exception handling around line 490 ==="
sed -n '480,520p' src/app/endpoints/streaming_query.py | cat -n

Repository: lightspeed-core/lightspeed-stack

Length of output: 8827


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Check get_topic_summary implementation
echo "=== get_topic_summary implementation ==="
rg -A 15 "^async def get_topic_summary" src/

Repository: lightspeed-core/lightspeed-stack

Length of output: 1031


🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Get the full get_topic_summary implementation
echo "=== get_topic_summary full implementation ==="
rg -A 40 "^async def get_topic_summary" src/utils/responses.py

Repository: lightspeed-core/lightspeed-stack

Length of output: 1623


Add CancelledError handling to non-streaming query path to prevent 500 errors during request cancellation.

The non-streaming query endpoint lacks explicit cancellation handling for get_topic_summary() (line 222) and responses.create() (within retrieve_response()). If the client cancels the request during either call, the unhandled CancelledError propagates to the ASGI layer, resulting in a 500 error. The streaming endpoint handles this correctly with an explicit except asyncio.CancelledError: handler (see streaming_query.py:490), but the non-streaming path needs equivalent protection to prevent 500 responses when authenticated MCP requests are interrupted.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/app/endpoints/query.py` around lines 126 - 127, The non-streaming query
endpoint must catch asyncio.CancelledError around the blocking calls to
get_topic_summary(...) and the call chain that invokes responses.create(...)
inside retrieve_response(...) to avoid 500s on client cancellation; update the
handler in query.py to wrap those await calls (or the higher-level
retrieve_response call) with an except asyncio.CancelledError: block that logs
or silently returns an appropriate response/cleanup and re-raises or returns
gracefully (mirroring the streaming handler pattern), ensuring you reference
get_topic_summary, retrieve_response, and responses.create to locate the spots
to add the cancellation handling.

started_at = datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%dT%H:%M:%SZ")
user_id, _, _skip_userid_check, token = auth
# Check token availability
Expand Down
3 changes: 3 additions & 0 deletions src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
validate_and_retrieve_conversation,
)
from utils.mcp_headers import McpHeaders, mcp_headers_dependency
from utils.mcp_oauth_probe import check_mcp_auth
from utils.query import (
consume_query_tokens,
extract_provider_and_model_from_model_id,
Expand Down Expand Up @@ -151,6 +152,8 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
"""
check_configuration_loaded(configuration)

await check_mcp_auth(configuration, mcp_headers)

user_id, _user_name, _skip_userid_check, token = auth
started_at = datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%dT%H:%M:%SZ")

Expand Down
15 changes: 5 additions & 10 deletions src/app/endpoints/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
from typing import Annotated, Any

from fastapi import APIRouter, Depends, HTTPException, Request
from llama_stack_client import APIConnectionError, BadRequestError, AuthenticationError
from llama_stack.core.datatypes import AuthenticationRequiredError
from llama_stack_client import APIConnectionError, BadRequestError

from authentication import get_auth_dependency
from authentication.interface import AuthTuple
Expand All @@ -21,7 +20,7 @@
)
from utils.endpoints import check_configuration_loaded
from utils.mcp_headers import McpHeaders, mcp_headers_dependency
from utils.mcp_oauth_probe import probe_mcp_oauth_and_raise_401
from utils.mcp_oauth_probe import check_mcp_auth
from utils.tool_formatter import format_tools_list
from log import get_logger

Expand Down Expand Up @@ -124,6 +123,8 @@ async def tools_endpoint_handler( # pylint: disable=too-many-locals,too-many-st

check_configuration_loaded(configuration)

await check_mcp_auth(configuration, mcp_headers)

toolgroups_response = []
try:
client = AsyncLlamaStackClientHolder().get_client()
Expand All @@ -146,6 +147,7 @@ async def tools_endpoint_handler( # pylint: disable=too-many-locals,too-many-st
# Get tools for each toolgroup
headers = mcp_headers.get(toolgroup.identifier, {})
authorization = headers.pop("Authorization", None)

Comment on lines 148 to +150

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Avoid mutating request-scoped MCP headers in-place.

Line 149 uses pop, which mutates the dict sourced from mcp_headers. Use a local copy/extraction so downstream code sees stable input.

♻️ Proposed fix
-            headers = mcp_headers.get(toolgroup.identifier, {})
-            authorization = headers.pop("Authorization", None)
+            original_headers = mcp_headers.get(toolgroup.identifier, {})
+            authorization = next(
+                (v for k, v in original_headers.items() if k.lower() == "authorization"),
+                None,
+            )
+            headers = {
+                k: v for k, v in original_headers.items() if k.lower() != "authorization"
+            }

As per coding guidelines "Avoid in-place parameter modification anti-patterns; return new data structures instead."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/app/endpoints/tools.py` around lines 148 - 150, The code mutates the
request-scoped mcp_headers by calling headers =
mcp_headers.get(toolgroup.identifier, {}) followed by
headers.pop("Authorization", None); change this to operate on a local copy:
retrieve the dict via mcp_headers.get(toolgroup.identifier) (or {}), create a
shallow copy (e.g., dict(headers) or headers.copy()) into a new variable (e.g.,
local_headers), then extract the Authorization value from that copy into
authorization without modifying the original mcp_headers; update subsequent uses
to reference local_headers instead of headers (look for occurrences around
toolgroup.identifier, headers, and authorization in src/app/endpoints/tools.py).

tools_response = await client.tools.list(
toolgroup_id=toolgroup.identifier,
extra_headers=headers,
Expand All @@ -154,13 +156,6 @@ async def tools_endpoint_handler( # pylint: disable=too-many-locals,too-many-st
except BadRequestError:
logger.error("Toolgroup %s is not found", toolgroup.identifier)
continue
except (AuthenticationError, AuthenticationRequiredError) as e:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we handled auth from llamastack itself as well, not just from MCP Oath. Now the check_mcp_auth only covers the MCP Oauth scenario. So if client.tools.list() now gives AuthenticationError for any problem that is not MCP-related, it now puts forward a 500 error instead of 401.

A simple fix could be: re-add catch for AuthenticationError (to return 401 again) -- the deleted lines 157, 162, and 163:

except AuthenticationError as e:
error_response = UnauthorizedResponse(cause=str(e))
raise HTTPException(**error_response.model_dump()) from e

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only reason I got rid of the llama stack authentication error catching in /tools is because it results in an asymmetrical implementation between tools and the query endpoints. I can not think of a scenario where client.tools.list() would return a 401 that has not already been caught (because only mcp servers are able to be authd against in llamastack)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I might be able to add an exception handler in the handle_known_apistatus_errors() and extend that to be used in /tools wdyt?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I might be able to add an exception handler in the handle_known_apistatus_errors() and extend that to be used in /tools wdyt?

But should I add this if its basically an impossible to reach branch?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed that currently only MCP servers trigger auth in llamastack. OTOH, it might be worth it as defense-in-depth -- if llamastack adds auth for non-MCP reason we would get 500 errors again. It's a low risk though, so up to you 🤷

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I ended up not implementing this error catching mainly to reduce size of this PR as well as not add any (as of right now) unneeded complexity. If this becomes an issue we will address it at that time :). However with the toolgroup depreciation large changes will likely begin to emerge that will likely make any work done with this exception catching irrelevant so for the time being I don't think its necessary.

if toolgroup.mcp_endpoint:
await probe_mcp_oauth_and_raise_401(
toolgroup.mcp_endpoint.uri, chain_from=e
)
error_response = UnauthorizedResponse(cause=str(e))
raise HTTPException(**error_response.model_dump()) from e
except APIConnectionError as e:
logger.error("Unable to connect to Llama Stack: %s", e)
response = ServiceUnavailableResponse(
Expand Down
79 changes: 64 additions & 15 deletions src/utils/mcp_oauth_probe.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,101 @@
"""Probe MCP server for OAuth and raise 401 with WWW-Authenticate when required."""
"""Probe MCP servers for OAuth and raise 401 with WWW-Authenticate when required.

Used by endpoints that call MCP-backed services so clients receive a proper
401 with WWW-Authenticate when an MCP server requires OAuth.
"""

import asyncio
from typing import Optional

import aiohttp
from fastapi import HTTPException

from models.responses import UnauthorizedResponse

from configuration import AppConfig
from utils.mcp_headers import McpHeaders
import constants

from log import get_logger

logger = get_logger(__name__)


async def probe_mcp_oauth_and_raise_401(
async def check_mcp_auth(configuration: AppConfig, mcp_headers: McpHeaders) -> None:
"""Probe each configured MCP server that expects OAuth or has auth headers.

For every MCP server that has an Authorization header in mcp_headers or
has OAuth in its resolved_authorization_headers, performs a probe request.
If the server indicates OAuth is required, raises 401 with
WWW-Authenticate (or 401 without header on probe failure).

Parameters:
configuration: Application config containing mcp_servers.
mcp_headers: Per-server headers; keys are MCP server names.

Returns:
None when no server requires OAuth or probe does not trigger 401.

Raises:
HTTPException: 401 when an MCP server requires OAuth (from probe_mcp).
"""
probes = []
for mcp_server in configuration.mcp_servers:
Comment thread
jrobertboos marked this conversation as resolved.
headers = mcp_headers.get(mcp_server.name, {})
authorization = headers.get("Authorization", None)
Comment on lines +44 to +45

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Handle Authorization lookup case-insensitively.

Line 43 only checks "Authorization" with exact casing. If MCP-HEADERS sends "authorization", the token is ignored and OAuth probing can incorrectly return 401.

♻️ Proposed fix
-        authorization = headers.get("Authorization", None)
+        authorization = next(
+            (value for key, value in headers.items() if key.lower() == "authorization"),
+            None,
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/utils/mcp_oauth_probe.py` around lines 42 - 43, The code currently
retrieves the Authorization token using headers.get("Authorization") which is
case-sensitive; update the lookup in mcp_oauth_probe.py to be case-insensitive
by normalizing header keys (e.g., build a temporary dict with lowercased keys
from headers and then use lowercased "authorization") or check both
"Authorization" and "authorization" before falling back to None so the token is
found regardless of header casing for mcp_headers, headers, authorization, and
mcp_server.name.

if (
authorization

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will pass with empty header.... maybe authorization is not None

or constants.MCP_AUTH_OAUTH
in mcp_server.resolved_authorization_headers.values()
):
probes.append(probe_mcp(mcp_server.url, authorization=authorization))
if probes:
await asyncio.gather(*probes)


async def probe_mcp(
url: str,
chain_from: Optional[BaseException] = None,
authorization: Optional[str] = None,
) -> None:
"""Probe MCP endpoint and raise 401 so the client can perform OAuth.

Performs an async GET to the given URL to obtain a WWW-Authenticate header,
then raises HTTPException with status 401 and that header. If the probe
fails (connection error, timeout), raises 401 without the header.
Performs an async GET to the given URL. If the response is 401 with
WWW-Authenticate, raises HTTPException with that header. If the response
is 401 without the header, or the probe fails (connection error, timeout),
raises 401 without WWW-Authenticate.

Args:
Parameters:
url: MCP server URL to probe.
chain_from: Exception to chain the HTTPException from when
the probe succeeds (e.g. the original AuthenticationError).
authorization: Optional Authorization header value for the probe request.

Returns:
None. Always raises an HTTPException.
None when the server responds with a status other than 401 (OAuth not
required). Otherwise does not return; raises HTTPException.

Raises:
HTTPException: 401 with WWW-Authenticate when the probe succeeds, or
401 without the header when the probe fails.
HTTPException: 401 with WWW-Authenticate when the server returns 401
and includes that header; 401 without the header when the server
returns 401 without it or when the probe fails (timeout/connection).
"""
cause = f"MCP server at {url} requires OAuth"
error_response = UnauthorizedResponse(cause=cause)
headers: Optional[dict[str, str]] = (
{"authorization": authorization} if authorization is not None else None
)
try:
timeout = aiohttp.ClientTimeout(total=10)
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(url) as resp:
async with session.get(url, headers=headers) as resp:
if resp.status != 401:
return
www_auth = resp.headers.get("WWW-Authenticate")
if www_auth is None:
logger.warning("No WWW-Authenticate header received from %s", url)
raise HTTPException(**error_response.model_dump()) from chain_from
raise HTTPException(**error_response.model_dump())
raise HTTPException(
**error_response.model_dump(),
headers={"WWW-Authenticate": www_auth},
) from chain_from
)
except (aiohttp.ClientError, TimeoutError) as probe_err:
logger.warning("OAuth probe failed for %s: %s", url, probe_err)
raise HTTPException(**error_response.model_dump()) from probe_err
11 changes: 0 additions & 11 deletions src/utils/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
ServiceUnavailableResponse,
)
from utils.mcp_headers import McpHeaders, extract_propagated_headers
from utils.mcp_oauth_probe import probe_mcp_oauth_and_raise_401
from utils.prompts import get_system_prompt, get_topic_summary_system_prompt
from utils.query import (
extract_provider_and_model_from_model_id,
Expand Down Expand Up @@ -464,16 +463,6 @@ def _get_token_value(original: str, header: str) -> Optional[str]:
if mcp_server.authorization_headers and len(headers) != len(
mcp_server.authorization_headers
):
# If OAuth was required and no headers passed, probe endpoint and forward
# 401 with WWW-Authenticate so the client can perform OAuth
uses_oauth = (
constants.MCP_AUTH_OAUTH
in mcp_server.resolved_authorization_headers.values()
)
if uses_oauth and (
mcp_headers is None or not mcp_headers.get(mcp_server.name)
):
await probe_mcp_oauth_and_raise_401(mcp_server.url)
logger.warning(
"Skipping MCP server %s: required %d auth headers but only resolved %d",
mcp_server.name,
Expand Down
46 changes: 5 additions & 41 deletions tests/e2e/features/mcp.feature
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ Feature: MCP tests
"""
And The headers of the response contains the following header "www-authenticate"

@skip # will be fixed in LCORE-1368
Scenario: Check if tools endpoint succeeds when MCP auth token is passed
Given The system is in default state
And I set the "MCP-HEADERS" header to
Expand All @@ -65,42 +64,9 @@ Feature: MCP tests
"""
When I access REST API endpoint "tools" using HTTP GET method
Then The status code of the response is 200
And The body of the response is the following
"""
{
"tools": [
{
"identifier": "",
"description": "Insert documents into memory",
"parameters": [],
"provider_id": "",
"toolgroup_id": "builtin::rag",
"server_source": "builtin",
"type": ""
},
{
"identifier": "",
"description": "Search for information in a database.",
"parameters": [],
"provider_id": "",
"toolgroup_id": "builtin::rag",
"server_source": "builtin",
"type": ""
},
{
"identifier": "",
"description": "Mock tool for E2E",
"parameters": [],
"provider_id": "",
"toolgroup_id": "mcp-oauth",
"server_source": "http://localhost:3001",
"type": ""
}
]
}
"""
And The body of the response contains mcp-oauth

@skip # will be fixed in LCORE-1366
@skip-in-library-mode # will be fixed in LCORE-1428
Scenario: Check if query endpoint succeeds when MCP auth token is passed
Given The system is in default state
And I set the "MCP-HEADERS" header to
Expand All @@ -115,10 +81,10 @@ Feature: MCP tests
Then The status code of the response is 200
And The response should contain following fragments
| Fragments in LLM response |
| hello |
| Hello |
And The token metrics should have increased

@skip # will be fixed in LCORE-1366
@skip-in-library-mode # will be fixed in LCORE-1428
Scenario: Check if streaming_query endpoint succeeds when MCP auth token is passed
Given The system is in default state
And I set the "MCP-HEADERS" header to
Expand All @@ -134,10 +100,9 @@ Feature: MCP tests
Then The status code of the response is 200
And The streamed response should contain following fragments
| Fragments in LLM response |
| hello |
| Hello |
And The token metrics should have increased

@skip # will be fixed in LCORE-1368
Scenario: Check if tools endpoint reports error when MCP invalid auth token is passed
Given The system is in default state
And I set the "MCP-HEADERS" header to
Expand Down Expand Up @@ -180,7 +145,6 @@ Feature: MCP tests
"""
And The headers of the response contains the following header "www-authenticate"

@skip # will be fixed in LCORE-1366
Scenario: Check if streaming_query endpoint reports error when MCP invalid auth token is passed
Given The system is in default state
And I set the "MCP-HEADERS" header to
Expand Down
2 changes: 2 additions & 0 deletions tests/e2e/mock_mcp_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def do_GET(self) -> None: # pylint: disable=invalid-name
"""Handle GET requests."""
if self.path == "/health":
self._json_response({"status": "ok"})
elif self._parse_auth() is not None:
self._json_response({"status": "authorized"})
else:
self._require_oauth()

Expand Down
Loading
Loading