Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions lightspeed-stack.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ llama_stack:
# Alternative for "as library use"
# use_as_library_client: true
# library_client_config_path: <path-to-llama-stack-run.yaml-file>
url: http://llama-stack:8321
url: http://localhost:8321
api_key: xyzzy
user_data_collection:
feedback_enabled: true
Expand All @@ -35,4 +35,10 @@ authentication:
# OKP Solr for supplementary RAG
solr:
enabled: false
offline: true
offline: true

mcp_servers:
- name: "ggg-mcp-server"
url: "http://localhost:3001"
authorization_headers:
Authorization: "oauth"
23 changes: 13 additions & 10 deletions src/app/endpoints/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
extract_vector_store_ids_from_tools,
get_topic_summary,
prepare_responses_params,
responses_params_to_request_body,
)
from utils.shields import (
append_turn_to_conversation,
Expand Down Expand Up @@ -190,6 +191,17 @@ async def query_endpoint_handler(
vector_store_ids = extract_vector_store_ids_from_tools(responses_params.tools)
rag_id_mapping = configuration.rag_id_mapping

# Get topic summary for new conversation before main response so it is not
# cancelled by MCP session cleanup (MCPSessionManager.close_all) that runs
# after retrieve_response completes.
if not user_conversation and query_request.generate_topic_summary:
logger.debug("Generating topic summary for new conversation")
topic_summary = await get_topic_summary(
query_request.query, client, responses_params.model
)
else:
topic_summary = None

# Retrieve response using Responses API
turn_summary = await retrieve_response(
client,
Expand All @@ -207,15 +219,6 @@ async def query_endpoint_handler(
doc_ids_from_chunks + turn_summary.referenced_documents
)

# Get topic summary for new conversation
if not user_conversation and query_request.generate_topic_summary:
logger.debug("Generating topic summary for new conversation")
topic_summary = await get_topic_summary(
query_request.query, client, responses_params.model
)
else:
topic_summary = None

logger.info("Consuming tokens")
consume_query_tokens(
user_id=user_id,
Expand Down Expand Up @@ -301,7 +304,7 @@ async def retrieve_response( # pylint: disable=too-many-locals
)
return TurnSummary(llm_response=violation_message)
response = await client.responses.create(
**responses_params.model_dump(exclude_none=True)
**responses_params_to_request_body(responses_params),
)
response = cast(OpenAIResponseObject, response)

Expand Down
27 changes: 18 additions & 9 deletions src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
from utils.quota import check_tokens_available, get_available_quotas
from utils.responses import (
build_mcp_tool_call_from_arguments_done,
responses_params_to_request_body,
build_tool_call_summary,
build_tool_result_from_mcp_output_item_done,
deduplicate_referenced_documents,
Expand Down Expand Up @@ -303,7 +304,7 @@ async def retrieve_response_generator(
)
# Retrieve response stream (may raise exceptions)
response = await context.client.responses.create(
**responses_params.model_dump(exclude_none=True)
**responses_params_to_request_body(responses_params),
)
# Store pre-RAG documents for later merging
turn_summary.pre_rag_documents = doc_ids_from_chunks
Expand Down Expand Up @@ -431,7 +432,7 @@ async def _on_interrupt() -> None:
return guard


async def generate_response(
async def generate_response( # pylint: disable=too-many-statements
generator: AsyncIterator[str],
context: ResponseGeneratorContext,
responses_params: ResponsesApiParams,
Expand Down Expand Up @@ -506,17 +507,25 @@ async def generate_response(

# Post-stream side effects: only run when streaming finished successfully

# Get topic summary for new conversations if needed
# Get topic summary for new conversations if needed. Guard against
# CancelledError from MCP session cleanup (MCPSessionManager.close_all)
# so we still yield stream_end_event and complete the ASGI response.
topic_summary = None
if not context.query_request.conversation_id:
should_generate = context.query_request.generate_topic_summary
if should_generate:
logger.debug("Generating topic summary for new conversation")
topic_summary = await get_topic_summary(
context.query_request.query,
context.client,
responses_params.model,
)
try:
logger.debug("Generating topic summary for new conversation")
topic_summary = await get_topic_summary(
context.query_request.query,
context.client,
responses_params.model,
)
except asyncio.CancelledError:
logger.debug(
"Topic summary cancelled (e.g. MCP cleanup); completing without it"
)
topic_summary = None

# Consume tokens
logger.info("Consuming tokens")
Expand Down
4 changes: 1 addition & 3 deletions src/app/endpoints/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,7 @@ async def tools_endpoint_handler( # pylint: disable=too-many-locals,too-many-st
continue
except (AuthenticationError, AuthenticationRequiredError) as e:
if toolgroup.mcp_endpoint:
await probe_mcp_oauth_and_raise_401(
toolgroup.mcp_endpoint.uri, chain_from=e
)
await probe_mcp_oauth_and_raise_401(toolgroup.mcp_endpoint.uri)
error_response = UnauthorizedResponse(cause=str(e))
raise HTTPException(**error_response.model_dump()) from e
except APIConnectionError as e:
Expand Down
31 changes: 18 additions & 13 deletions src/utils/mcp_oauth_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,40 +13,45 @@

async def probe_mcp_oauth_and_raise_401(
url: str,
chain_from: Optional[BaseException] = None,
authorization: Optional[str] = None,
) -> None:
"""Probe MCP endpoint and raise 401 so the client can perform OAuth.
"""Probe MCP endpoint and raise 401 only when the server responds with 401.

Performs an async GET to the given URL to obtain a WWW-Authenticate header,
then raises HTTPException with status 401 and that header. If the probe
fails (connection error, timeout), raises 401 without the header.
Performs a GET to the given URL with the optional Authorization header.
If the response status is 401, raises HTTPException with status 401 and
WWW-Authenticate header when present. Otherwise returns without raising.

Args:
url: MCP server URL to probe.
authorization: Optional Authorization header value (e.g. "Bearer <token>").
chain_from: Exception to chain the HTTPException from when
the probe succeeds (e.g. the original AuthenticationError).
the server returns 401 (e.g. the original AuthenticationError).

Returns:
None. Always raises an HTTPException.
None. Raises only when the server responds with 401.

Raises:
HTTPException: 401 with WWW-Authenticate when the probe succeeds, or
401 without the header when the probe fails.
HTTPException: 401 with WWW-Authenticate when the server returns 401.
"""
cause = f"MCP server at {url} requires OAuth"
error_response = UnauthorizedResponse(cause=cause)
headers: Optional[dict[str, str]] = (
{"Authorization": authorization} if authorization is not None else None
)
try:
timeout = aiohttp.ClientTimeout(total=10)
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(url) as resp:
async with session.get(url, headers=headers) as resp:
if resp.status != 401:
return
www_auth = resp.headers.get("WWW-Authenticate")
if www_auth is None:
logger.warning("No WWW-Authenticate header received from %s", url)
raise HTTPException(**error_response.model_dump()) from chain_from
raise HTTPException(**error_response.model_dump())
raise HTTPException(
**error_response.model_dump(),
headers={"WWW-Authenticate": www_auth},
) from chain_from
)
except (aiohttp.ClientError, TimeoutError) as probe_err:
logger.warning("OAuth probe failed for %s: %s", url, probe_err)
raise HTTPException(**error_response.model_dump()) from probe_err
# Only raise on 401; connection/timeout are not 401, so do not raise
47 changes: 37 additions & 10 deletions src/utils/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,33 @@
logger = get_logger(__name__)


def responses_params_to_request_body(params: ResponsesApiParams) -> dict[str, Any]:
    """Serialize ResponsesApiParams into kwargs for the Responses API call.

    The params are dumped with ``exclude_none=True``. When ``params.tools`` is
    set, the ``"tools"`` entry of the dump is rebuilt tool by tool so that each
    tool's ``authorization`` attribute (when present and not None) is written
    back into its serialized form. Per the original author's note, the
    llama_stack_api declares that field with ``Field(exclude=True)``, so a
    plain ``model_dump()`` would omit it.

    Parameters:
        params: The Responses API parameters.

    Returns:
        Dict suitable for client.responses.create(**result).
    """
    request_body = params.model_dump(exclude_none=True)
    configured_tools = getattr(params, "tools", None)
    if configured_tools is None:
        # Nothing to patch: the plain dump is already the full request body.
        return request_body

    def _dump_with_authorization(tool: Any) -> dict[str, Any]:
        """Dump one tool and re-attach its authorization value if it has one."""
        dumped = tool.model_dump(exclude_none=True)
        token = getattr(tool, "authorization", None)
        if token is not None:
            dumped["authorization"] = token
        return dumped

    # Replace the nested dump of the tools with the per-tool dumps that carry
    # the authorization value.
    request_body["tools"] = [_dump_with_authorization(t) for t in configured_tools]
    return request_body


async def get_topic_summary(
question: str, client: AsyncLlamaStackClient, model_id: str
) -> str:
Expand Down Expand Up @@ -391,20 +418,20 @@ def _get_token_value(original: str, header: str) -> Optional[str]:
if h_value is not None:
headers[name] = h_value

uses_oauth = (
constants.MCP_AUTH_OAUTH
in mcp_server.resolved_authorization_headers.values()
)

if uses_oauth:
await probe_mcp_oauth_and_raise_401(
mcp_server.url, authorization=headers.get("Authorization", None)
)

# Skip server if auth headers were configured but not all could be resolved
if mcp_server.authorization_headers and len(headers) != len(
mcp_server.authorization_headers
):
# If OAuth was required and no headers passed, probe endpoint and forward
# 401 with WWW-Authenticate so the client can perform OAuth
uses_oauth = (
constants.MCP_AUTH_OAUTH
in mcp_server.resolved_authorization_headers.values()
)
if uses_oauth and (
mcp_headers is None or not mcp_headers.get(mcp_server.name)
):
await probe_mcp_oauth_and_raise_401(mcp_server.url)
logger.warning(
"Skipping MCP server %s: required %d auth headers but only resolved %d",
mcp_server.name,
Expand Down
Loading
Loading