Skip to content

Commit 268f87b

Browse files
max-svistunov authored and
tisnik committed
Send x-llamastack-provider-data header with MCP auth to llama-stack
1 parent cb91bf3 commit 268f87b

3 files changed

Lines changed: 144 additions & 0 deletions

File tree

src/utils/responses.py

Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -158,6 +158,37 @@ async def prepare_tools(
158158
return toolgroups
159159

160160

161+
def _build_provider_data_headers(
    tools: Optional[list[dict[str, Any]]],
) -> Optional[dict[str, str]]:
    """Assemble the ``x-llamastack-provider-data`` header from MCP tools.

    Walks the prepared tool definitions, keeps the ``headers`` of every MCP
    tool keyed by its ``server_url``, and serializes that mapping as JSON
    under the ``mcp_headers`` key. Llama Stack uses this header to
    authenticate with downstream MCP servers.

    Args:
        tools: Prepared tool definitions (MCP and non-MCP entries may be
            mixed); ``None`` or an empty list is accepted.

    Returns:
        A one-entry dict mapping ``x-llamastack-provider-data`` to the JSON
        payload, or None when no MCP tool supplies both a server URL and
        headers.
    """
    per_server: McpHeaders = {}
    for entry in tools or []:
        # Only MCP tools contribute provider data.
        if entry.get("type") != "mcp":
            continue
        # Entries without headers or without a server URL cannot be keyed.
        if not entry.get("headers") or not entry.get("server_url"):
            continue
        per_server[entry["server_url"]] = entry["headers"]

    if not per_server:
        return None

    payload = json.dumps({"mcp_headers": per_server})
    return {"x-llamastack-provider-data": payload}
190+
191+
161192
async def prepare_responses_params( # pylint: disable=too-many-arguments,too-many-locals,too-many-positional-arguments
162193
client: AsyncLlamaStackClient,
163194
query_request: QueryRequest,
@@ -234,6 +265,9 @@ async def prepare_responses_params( # pylint: disable=too-many-arguments,too-ma
234265
llama_stack_conv_id,
235266
)
236267

268+
# Build x-llamastack-provider-data header from MCP tool headers
269+
extra_headers = _build_provider_data_headers(tools)
270+
237271
return ResponsesApiParams(
238272
input=input_text,
239273
model=model,
@@ -242,6 +276,7 @@ async def prepare_responses_params( # pylint: disable=too-many-arguments,too-ma
242276
conversation=llama_stack_conv_id,
243277
stream=stream,
244278
store=store,
279+
extra_headers=extra_headers,
245280
)
246281

247282

src/utils/types.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -123,6 +123,10 @@ class ResponsesApiParams(BaseModel):
123123
conversation: str = Field(description="The conversation ID in llama-stack format")
124124
stream: bool = Field(description="Whether to stream the response")
125125
store: bool = Field(description="Whether to store the response")
126+
extra_headers: Optional[dict[str, str]] = Field(
127+
default=None,
128+
description="Extra HTTP headers to send with the request (e.g. x-llamastack-provider-data)",
129+
)
126130

127131

128132
class ToolCallSummary(BaseModel):

tests/unit/utils/test_responses.py

Lines changed: 105 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -946,6 +946,111 @@ async def test_prepare_responses_params_api_status_error_on_models(
946946
await prepare_responses_params(mock_client, query_request, None, "token")
947947
assert exc_info.value.status_code == 500
948948

949+
@pytest.mark.asyncio
async def test_prepare_responses_params_includes_mcp_provider_data_headers(
    self, mocker: MockerFixture
) -> None:
    """Check that MCP tool headers end up in x-llamastack-provider-data.

    ``prepare_tools`` is stubbed to return two MCP tools that each carry
    headers; the resulting params must expose an ``extra_headers`` entry
    whose JSON payload maps every server URL to its headers.
    """
    # (server_label, server_url, headers) for each simulated MCP tool, as
    # they would be returned by prepare_tools/get_mcp_tools.
    mcp_servers = [
        (
            "mcp::aap-controller",
            "http://aap.foo.redhat.com:8004/sse",
            {"X-Authorization": "client-token"},
        ),
        (
            "mcp::aap-lightspeed",
            "http://aap.foo.redhat.com:8005/sse",
            {"X-Authorization": "client-token-2"},
        ),
    ]
    tools = [
        {
            "type": "mcp",
            "server_label": label,
            "server_url": url,
            "require_approval": "never",
            "headers": headers,
        }
        for label, url, headers in mcp_servers
    ]
    # Copy the header dicts so the expectation is independent of the input.
    expected_mcp_headers = {url: dict(headers) for _, url, headers in mcp_servers}

    llm = mocker.Mock()
    llm.id = "provider1/model1"
    llm.custom_metadata = {"model_type": "llm", "provider_id": "provider1"}

    conversation = mocker.Mock()
    conversation.id = "new_conv_id"

    client = mocker.AsyncMock()
    client.models.list = mocker.AsyncMock(return_value=[llm])
    client.conversations.create = mocker.AsyncMock(return_value=conversation)

    request = QueryRequest(query="test")  # pyright: ignore[reportCallIssue]

    mocker.patch("utils.responses.configuration", mocker.Mock())
    canned_returns = {
        "utils.responses.select_model_and_provider_id": (
            "provider1/model1",
            "model1",
            "provider1",
        ),
        "utils.responses.evaluate_model_hints": (None, None),
        "utils.responses.get_system_prompt": "System prompt",
        "utils.responses.prepare_tools": tools,
        "utils.responses.prepare_input": "test",
    }
    for target, value in canned_returns.items():
        mocker.patch(target, return_value=value)

    params = await prepare_responses_params(client, request, None, "token")

    # The result should contain extra_headers with x-llamastack-provider-data
    extra_headers = params.model_dump()["extra_headers"]
    assert (
        extra_headers is not None
    ), "extra_headers should not be None when MCP tools have headers"
    assert "x-llamastack-provider-data" in extra_headers

    decoded = json.loads(extra_headers["x-llamastack-provider-data"])
    assert "mcp_headers" in decoded
    assert decoded["mcp_headers"] == expected_mcp_headers
1017+
1018+
@pytest.mark.asyncio
async def test_prepare_responses_params_no_extra_headers_without_mcp_tools(
    self, mocker: MockerFixture
) -> None:
    """Check that extra_headers stays None when prepare_tools yields no tools."""
    llm = mocker.Mock()
    llm.id = "provider1/model1"
    llm.custom_metadata = {"model_type": "llm", "provider_id": "provider1"}

    conversation = mocker.Mock()
    conversation.id = "new_conv_id"

    client = mocker.AsyncMock()
    client.models.list = mocker.AsyncMock(return_value=[llm])
    client.conversations.create = mocker.AsyncMock(return_value=conversation)

    request = QueryRequest(query="test")  # pyright: ignore[reportCallIssue]

    mocker.patch("utils.responses.configuration", mocker.Mock())
    canned_returns = {
        "utils.responses.select_model_and_provider_id": (
            "provider1/model1",
            "model1",
            "provider1",
        ),
        "utils.responses.evaluate_model_hints": (None, None),
        "utils.responses.get_system_prompt": "System prompt",
        # No tools at all, so there is nothing to build a header from.
        "utils.responses.prepare_tools": None,
        "utils.responses.prepare_input": "test",
    }
    for target, value in canned_returns.items():
        mocker.patch(target, return_value=value)

    params = await prepare_responses_params(client, request, None, "token")

    serialized = params.model_dump()
    assert serialized.get("extra_headers") is None
1053+
9491054
@pytest.mark.asyncio
9501055
async def test_prepare_responses_params_api_status_error_on_conversation(
9511056
self, mocker: MockerFixture

0 commit comments

Comments
 (0)