Skip to content

Commit c13c8c4

Browse files
authored
Merge pull request #1143 from tisnik/lcore-1173
LCORE-1173: named type for MCP headers complicated structure
2 parents 5d1416c + e52b9b3 commit c13c8c4

5 files changed

Lines changed: 19 additions & 20 deletions

File tree

src/app/endpoints/a2a.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from configuration import configuration
4545
from models.config import Action
4646
from models.requests import QueryRequest
47-
from utils.mcp_headers import mcp_headers_dependency
47+
from utils.mcp_headers import mcp_headers_dependency, McpHeaders
4848
from utils.responses import (
4949
extract_text_from_response_output_item,
5050
prepare_responses_params,
@@ -183,17 +183,15 @@ class A2AAgentExecutor(AgentExecutor):
183183
routing queries to the LLM backend using the Responses API.
184184
"""
185185

186-
def __init__(
187-
self, auth_token: str, mcp_headers: Optional[dict[str, dict[str, str]]] = None
188-
):
186+
def __init__(self, auth_token: str, mcp_headers: Optional[McpHeaders] = None):
189187
"""Initialize the A2A agent executor.
190188
191189
Args:
192190
auth_token: Authentication token for the request
193191
mcp_headers: MCP headers for context propagation
194192
"""
195193
self.auth_token: str = auth_token
196-
self.mcp_headers: dict[str, dict[str, str]] = mcp_headers or {}
194+
self.mcp_headers: McpHeaders = mcp_headers or {}
197195

198196
async def execute(
199197
self,
@@ -648,9 +646,7 @@ async def get_agent_card( # pylint: disable=unused-argument
648646
raise
649647

650648

651-
async def _create_a2a_app(
652-
auth_token: str, mcp_headers: dict[str, dict[str, str]]
653-
) -> Any:
649+
async def _create_a2a_app(auth_token: str, mcp_headers: McpHeaders) -> Any:
654650
"""Create an A2A Starlette application instance with auth context.
655651
656652
Args:
@@ -681,7 +677,7 @@ async def _create_a2a_app(
681677
async def handle_a2a_jsonrpc( # pylint: disable=too-many-locals,too-many-statements
682678
request: Request,
683679
auth: Annotated[AuthTuple, Depends(auth_dependency)],
684-
mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency),
680+
mcp_headers: McpHeaders = Depends(mcp_headers_dependency),
685681
) -> Response | StreamingResponse:
686682
"""
687683
Handle A2A JSON-RPC requests following the A2A protocol specification.

src/app/endpoints/query.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
check_configuration_loaded,
4141
validate_and_retrieve_conversation,
4242
)
43-
from utils.mcp_headers import mcp_headers_dependency
43+
from utils.mcp_headers import mcp_headers_dependency, McpHeaders
4444
from utils.query import (
4545
consume_query_tokens,
4646
handle_known_apistatus_errors,
@@ -93,7 +93,7 @@ async def query_endpoint_handler(
9393
request: Request,
9494
query_request: QueryRequest,
9595
auth: Annotated[AuthTuple, Depends(get_auth_dependency())],
96-
mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency),
96+
mcp_headers: McpHeaders = Depends(mcp_headers_dependency),
9797
) -> QueryResponse:
9898
"""
9999
Handle request to the /query endpoint using Responses API.

src/app/endpoints/streaming_query.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
check_configuration_loaded,
5858
validate_and_retrieve_conversation,
5959
)
60-
from utils.mcp_headers import mcp_headers_dependency
60+
from utils.mcp_headers import mcp_headers_dependency, McpHeaders
6161
from utils.query import (
6262
consume_query_tokens,
6363
extract_provider_and_model_from_model_id,
@@ -118,7 +118,7 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
118118
request: Request,
119119
query_request: QueryRequest,
120120
auth: Annotated[AuthTuple, Depends(get_auth_dependency())],
121-
mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency),
121+
mcp_headers: McpHeaders = Depends(mcp_headers_dependency),
122122
) -> StreamingResponse:
123123
"""
124124
Handle request to the /streaming_query endpoint using Responses API.

src/utils/mcp_headers.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010

1111
logger = logging.getLogger("app.endpoints.dependencies")
1212

13+
type McpHeaders = dict[str, dict[str, str]]
1314

14-
async def mcp_headers_dependency(request: Request) -> dict[str, dict[str, str]]:
15+
16+
async def mcp_headers_dependency(request: Request) -> McpHeaders:
1517
"""Get the MCP headers dependency to passed to mcp servers.
1618
1719
mcp headers is a json dictionary or mcp url paths and their respective headers
@@ -25,7 +27,7 @@ async def mcp_headers_dependency(request: Request) -> dict[str, dict[str, str]]:
2527
return extract_mcp_headers(request)
2628

2729

28-
def extract_mcp_headers(request: Request) -> dict[str, dict[str, str]]:
30+
def extract_mcp_headers(request: Request) -> McpHeaders:
2931
"""Extract mcp headers from MCP-HEADERS header.
3032
3133
If the header is missing, contains invalid JSON, or the decoded
@@ -56,8 +58,8 @@ def extract_mcp_headers(request: Request) -> dict[str, dict[str, str]]:
5658

5759

5860
def handle_mcp_headers_with_toolgroups(
59-
mcp_headers: dict[str, dict[str, str]], config: AppConfig
60-
) -> dict[str, dict[str, str]]:
61+
mcp_headers: McpHeaders, config: AppConfig
62+
) -> McpHeaders:
6163
"""Process MCP headers by converting toolgroup names to URLs.
6264
6365
This function takes MCP headers where keys can be either valid URLs or

src/utils/responses.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
prepare_input,
3838
select_model_and_provider_id,
3939
)
40+
from utils.mcp_headers import McpHeaders
4041
from utils.suid import to_llama_stack_conversation_id
4142
from utils.token_counter import TokenCounter
4243
from utils.types import (
@@ -141,7 +142,7 @@ async def prepare_tools(
141142
query_request: QueryRequest,
142143
token: str,
143144
config: AppConfig,
144-
mcp_headers: Optional[dict[str, dict[str, str]]] = None,
145+
mcp_headers: Optional[McpHeaders] = None,
145146
) -> Optional[list[dict[str, Any]]]:
146147
"""Prepare tools for Responses API including RAG and MCP tools.
147148
@@ -202,7 +203,7 @@ async def prepare_responses_params( # pylint: disable=too-many-arguments,too-ma
202203
query_request: QueryRequest,
203204
user_conversation: Optional[UserConversation],
204205
token: str,
205-
mcp_headers: Optional[dict[str, dict[str, str]]] = None,
206+
mcp_headers: Optional[McpHeaders] = None,
206207
stream: bool = False,
207208
store: bool = True,
208209
) -> ResponsesApiParams:
@@ -315,7 +316,7 @@ def get_rag_tools(vector_store_ids: list[str]) -> Optional[list[dict[str, Any]]]
315316
def get_mcp_tools(
316317
mcp_servers: list[ModelContextProtocolServer],
317318
token: str | None = None,
318-
mcp_headers: dict[str, dict[str, str]] | None = None,
319+
mcp_headers: Optional[McpHeaders] = None,
319320
) -> list[dict[str, Any]]:
320321
"""Convert MCP servers to tools format for Responses API.
321322

0 commit comments

Comments
 (0)