Skip to content

Commit 5f32b04

Browse files
committed
Send x-llamastack-provider-data header with MCP auth to llama-stack
1 parent 3b135fb commit 5f32b04

3 files changed

Lines changed: 144 additions & 0 deletions

File tree

src/utils/responses.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,37 @@ async def prepare_tools(
198198
return toolgroups
199199

200200

201+
def _build_provider_data_headers(
202+
tools: Optional[list[dict[str, Any]]],
203+
) -> Optional[dict[str, str]]:
204+
"""Build extra HTTP headers containing MCP provider data for Llama Stack.
205+
206+
Extracts per-server auth headers from MCP tool definitions and encodes
207+
them as a JSON ``x-llamastack-provider-data`` header that Llama Stack
208+
uses to authenticate with downstream MCP servers.
209+
210+
Args:
211+
tools: Prepared tool definitions (may include MCP and non-MCP tools).
212+
213+
Returns:
214+
Dict with a single ``x-llamastack-provider-data`` key, or None when
215+
no MCP tools carry headers.
216+
"""
217+
if not tools:
218+
return None
219+
220+
mcp_headers: McpHeaders = {
221+
tool["server_url"]: tool["headers"]
222+
for tool in tools
223+
if tool.get("type") == "mcp" and tool.get("headers") and tool.get("server_url")
224+
}
225+
226+
if not mcp_headers:
227+
return None
228+
229+
return {"x-llamastack-provider-data": json.dumps({"mcp_headers": mcp_headers})}
230+
231+
201232
async def prepare_responses_params( # pylint: disable=too-many-arguments,too-many-locals,too-many-positional-arguments
202233
client: AsyncLlamaStackClient,
203234
query_request: QueryRequest,
@@ -281,6 +312,9 @@ async def prepare_responses_params( # pylint: disable=too-many-arguments,too-ma
281312
llama_stack_conv_id,
282313
)
283314

315+
# Build x-llamastack-provider-data header from MCP tool headers
316+
extra_headers = _build_provider_data_headers(tools)
317+
284318
return ResponsesApiParams(
285319
input=input_text,
286320
model=llama_stack_model_id,
@@ -289,6 +323,7 @@ async def prepare_responses_params( # pylint: disable=too-many-arguments,too-ma
289323
conversation=llama_stack_conv_id,
290324
stream=stream,
291325
store=store,
326+
extra_headers=extra_headers,
292327
)
293328

294329

src/utils/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,10 @@ class ResponsesApiParams(BaseModel):
123123
conversation: str = Field(description="The conversation ID in llama-stack format")
124124
stream: bool = Field(description="Whether to stream the response")
125125
store: bool = Field(description="Whether to store the response")
126+
extra_headers: Optional[dict[str, str]] = Field(
127+
default=None,
128+
description="Extra HTTP headers to send with the request (e.g. x-llamastack-provider-data)",
129+
)
126130

127131

128132
class ToolCallSummary(BaseModel):

tests/unit/utils/test_responses.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -874,6 +874,111 @@ async def test_prepare_responses_params_api_status_error_on_models(
874874
await prepare_responses_params(mock_client, query_request, None, "token")
875875
assert exc_info.value.status_code == 500
876876

877+
@pytest.mark.asyncio
878+
async def test_prepare_responses_params_includes_mcp_provider_data_headers(
879+
self, mocker: MockerFixture
880+
) -> None:
881+
"""Test that extra_headers with x-llamastack-provider-data is set when MCP tools have headers."""
882+
mock_client = mocker.AsyncMock()
883+
mock_model = mocker.Mock()
884+
mock_model.id = "provider1/model1"
885+
mock_model.custom_metadata = {"model_type": "llm", "provider_id": "provider1"}
886+
mock_client.models.list = mocker.AsyncMock(return_value=[mock_model])
887+
888+
mock_conversation = mocker.Mock()
889+
mock_conversation.id = "new_conv_id"
890+
mock_client.conversations.create = mocker.AsyncMock(
891+
return_value=mock_conversation
892+
)
893+
894+
query_request = QueryRequest(query="test") # pyright: ignore[reportCallIssue]
895+
896+
# Simulate MCP tools with headers (as returned by prepare_tools/get_mcp_tools)
897+
mcp_tools_with_headers = [
898+
{
899+
"type": "mcp",
900+
"server_label": "mcp::aap-controller",
901+
"server_url": "http://aap.foo.redhat.com:8004/sse",
902+
"require_approval": "never",
903+
"headers": {"X-Authorization": "client-token"},
904+
},
905+
{
906+
"type": "mcp",
907+
"server_label": "mcp::aap-lightspeed",
908+
"server_url": "http://aap.foo.redhat.com:8005/sse",
909+
"require_approval": "never",
910+
"headers": {"X-Authorization": "client-token-2"},
911+
},
912+
]
913+
914+
mocker.patch("utils.responses.configuration", mocker.Mock())
915+
mocker.patch(
916+
"utils.responses.select_model_and_provider_id",
917+
return_value=("provider1/model1", "model1", "provider1"),
918+
)
919+
mocker.patch("utils.responses.evaluate_model_hints", return_value=(None, None))
920+
mocker.patch("utils.responses.get_system_prompt", return_value="System prompt")
921+
mocker.patch(
922+
"utils.responses.prepare_tools", return_value=mcp_tools_with_headers
923+
)
924+
mocker.patch("utils.responses.prepare_input", return_value="test")
925+
926+
result = await prepare_responses_params(
927+
mock_client, query_request, None, "token"
928+
)
929+
930+
# The result should contain extra_headers with x-llamastack-provider-data
931+
dumped = result.model_dump()
932+
assert (
933+
dumped["extra_headers"] is not None
934+
), "extra_headers should not be None when MCP tools have headers"
935+
assert "x-llamastack-provider-data" in dumped["extra_headers"]
936+
937+
provider_data = json.loads(
938+
dumped["extra_headers"]["x-llamastack-provider-data"]
939+
)
940+
assert "mcp_headers" in provider_data
941+
assert provider_data["mcp_headers"] == {
942+
"http://aap.foo.redhat.com:8004/sse": {"X-Authorization": "client-token"},
943+
"http://aap.foo.redhat.com:8005/sse": {"X-Authorization": "client-token-2"},
944+
}
945+
946+
@pytest.mark.asyncio
947+
async def test_prepare_responses_params_no_extra_headers_without_mcp_tools(
948+
self, mocker: MockerFixture
949+
) -> None:
950+
"""Test that extra_headers is None when no MCP tools have headers."""
951+
mock_client = mocker.AsyncMock()
952+
mock_model = mocker.Mock()
953+
mock_model.id = "provider1/model1"
954+
mock_model.custom_metadata = {"model_type": "llm", "provider_id": "provider1"}
955+
mock_client.models.list = mocker.AsyncMock(return_value=[mock_model])
956+
957+
mock_conversation = mocker.Mock()
958+
mock_conversation.id = "new_conv_id"
959+
mock_client.conversations.create = mocker.AsyncMock(
960+
return_value=mock_conversation
961+
)
962+
963+
query_request = QueryRequest(query="test") # pyright: ignore[reportCallIssue]
964+
965+
mocker.patch("utils.responses.configuration", mocker.Mock())
966+
mocker.patch(
967+
"utils.responses.select_model_and_provider_id",
968+
return_value=("provider1/model1", "model1", "provider1"),
969+
)
970+
mocker.patch("utils.responses.evaluate_model_hints", return_value=(None, None))
971+
mocker.patch("utils.responses.get_system_prompt", return_value="System prompt")
972+
mocker.patch("utils.responses.prepare_tools", return_value=None)
973+
mocker.patch("utils.responses.prepare_input", return_value="test")
974+
975+
result = await prepare_responses_params(
976+
mock_client, query_request, None, "token"
977+
)
978+
979+
dumped = result.model_dump()
980+
assert dumped.get("extra_headers") is None
981+
877982
@pytest.mark.asyncio
878983
async def test_prepare_responses_params_api_status_error_on_conversation(
879984
self, mocker: MockerFixture

0 commit comments

Comments
 (0)