Skip to content

Commit b3ae1c4

Browse files
committed
addressed comments
1 parent a1770d1 commit b3ae1c4

3 files changed

Lines changed: 268 additions & 44 deletions

File tree

src/utils/mcp_headers.py

Lines changed: 86 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,86 @@ def extract_propagated_headers(
125125
return propagated
126126

127127

128+
def find_unresolved_auth_headers(
129+
configured: Mapping[str, str],
130+
resolved: Mapping[str, str],
131+
) -> list[str]:
132+
"""Return configured auth header names that are absent from the resolved headers.
133+
134+
Comparison is case-insensitive so that ``Authorization`` and ``authorization``
135+
are treated as the same header name.
136+
137+
Args:
138+
configured: The server's ``authorization_headers`` configuration mapping
139+
(header name → secret path or keyword).
140+
resolved: The fully resolved headers that will be sent to the MCP server.
141+
142+
Returns:
143+
List of header names from ``configured`` that could not be resolved, i.e.
144+
are not present as a key in ``resolved``. An empty list means all headers
145+
were resolved successfully.
146+
"""
147+
resolved_lower = {k.lower() for k in resolved}
148+
return [h for h in configured if h.lower() not in resolved_lower]
149+
150+
151+
def build_server_headers(
152+
mcp_server: ModelContextProtocolServer,
153+
client_headers: dict[str, str],
154+
request_headers: Optional[Mapping[str, str]],
155+
token: Optional[str],
156+
) -> dict[str, str]:
157+
"""Build the complete set of headers for a single MCP server.
158+
159+
Merges client-supplied headers, resolved authorization headers, and propagated
160+
request headers in priority order (highest first):
161+
162+
1. Client-supplied headers (already present in ``client_headers``).
163+
2. Statically resolved authorization headers from configuration.
164+
3. Kubernetes Bearer token for headers configured with the ``kubernetes`` keyword.
165+
``client`` and ``oauth`` keywords are skipped — those values are already
166+
provided by the client in source 1.
167+
4. Headers propagated from the incoming request via the server's allowlist.
168+
169+
Args:
170+
mcp_server: MCP server configuration.
171+
client_headers: Headers already supplied by the client for this server.
172+
request_headers: Headers from the incoming HTTP request, or ``None``.
173+
token: Optional Kubernetes service-account token.
174+
175+
Returns:
176+
Merged headers dictionary for the server. May be empty if no headers apply.
177+
"""
178+
server_headers: dict[str, str] = dict(client_headers)
179+
existing_lower = {k.lower() for k in server_headers}
180+
181+
for (
182+
header_name,
183+
resolved_value,
184+
) in mcp_server.resolved_authorization_headers.items():
185+
if header_name.lower() in existing_lower:
186+
continue
187+
match resolved_value:
188+
case constants.MCP_AUTH_KUBERNETES:
189+
if token:
190+
server_headers[header_name] = f"Bearer {token}"
191+
existing_lower.add(header_name.lower())
192+
case constants.MCP_AUTH_CLIENT | constants.MCP_AUTH_OAUTH:
193+
pass # client-provided; already included via the initial client_headers copy
194+
case _:
195+
server_headers[header_name] = resolved_value
196+
existing_lower.add(header_name.lower())
197+
198+
if mcp_server.headers and request_headers is not None:
199+
propagated = extract_propagated_headers(mcp_server, request_headers)
200+
for h_name, h_value in propagated.items():
201+
if h_name.lower() not in existing_lower:
202+
server_headers[h_name] = h_value
203+
existing_lower.add(h_name.lower())
204+
205+
return server_headers
206+
207+
128208
def build_mcp_headers(
129209
config: AppConfig,
130210
mcp_headers: McpHeaders,
@@ -162,34 +242,12 @@ def build_mcp_headers(
162242
complete: McpHeaders = {}
163243

164244
for mcp_server in config.mcp_servers:
165-
server_headers: dict[str, str] = dict(mcp_headers.get(mcp_server.name, {}))
166-
existing_lower = {k.lower() for k in server_headers}
167-
168-
for (
169-
header_name,
170-
resolved_value,
171-
) in mcp_server.resolved_authorization_headers.items():
172-
if header_name.lower() in existing_lower:
173-
continue
174-
match resolved_value:
175-
case constants.MCP_AUTH_KUBERNETES:
176-
if token:
177-
server_headers[header_name] = f"Bearer {token}"
178-
existing_lower.add(header_name.lower())
179-
case constants.MCP_AUTH_CLIENT | constants.MCP_AUTH_OAUTH:
180-
pass # client-provided; already included via the initial mcp_headers copy
181-
case _:
182-
server_headers[header_name] = resolved_value
183-
existing_lower.add(header_name.lower())
184-
185-
# Propagate allowlisted headers from the incoming request.
186-
if mcp_server.headers and request_headers is not None:
187-
propagated = extract_propagated_headers(mcp_server, request_headers)
188-
for h_name, h_value in propagated.items():
189-
if h_name.lower() not in existing_lower:
190-
server_headers[h_name] = h_value
191-
existing_lower.add(h_name.lower())
192-
245+
server_headers = build_server_headers(
246+
mcp_server,
247+
dict(mcp_headers.get(mcp_server.name, {})),
248+
request_headers,
249+
token,
250+
)
193251
if server_headers:
194252
complete[mcp_server.name] = server_headers
195253

src/utils/responses.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,11 @@
5050
NotFoundResponse,
5151
ServiceUnavailableResponse,
5252
)
53-
from utils.mcp_headers import McpHeaders, build_mcp_headers
53+
from utils.mcp_headers import (
54+
McpHeaders,
55+
build_mcp_headers,
56+
find_unresolved_auth_headers,
57+
)
5458
from utils.prompts import get_system_prompt, get_topic_summary_system_prompt
5559
from utils.query import (
5660
extract_provider_and_model_from_model_id,
@@ -471,20 +475,17 @@ async def get_mcp_tools(
471475
headers: dict[str, str] = dict(complete_headers.get(mcp_server.name, {}))
472476

473477
# Skip server if any configured auth header could not be resolved.
474-
if mcp_server.authorization_headers:
475-
unresolved = [
476-
h
477-
for h in mcp_server.authorization_headers
478-
if not any(k.lower() == h.lower() for k in headers)
479-
]
480-
if unresolved:
481-
logger.warning(
482-
"Skipping MCP server %s: required %d auth headers but only resolved %d",
483-
mcp_server.name,
484-
len(mcp_server.authorization_headers),
485-
len(mcp_server.authorization_headers) - len(unresolved),
486-
)
487-
continue
478+
unresolved = find_unresolved_auth_headers(
479+
mcp_server.authorization_headers, headers
480+
)
481+
if unresolved:
482+
logger.warning(
483+
"Skipping MCP server %s: required %d auth headers but only resolved %d",
484+
mcp_server.name,
485+
len(mcp_server.authorization_headers),
486+
len(mcp_server.authorization_headers) - len(unresolved),
487+
)
488+
continue
488489

489490
authorization = headers.pop("Authorization", None)
490491
tools.append(

tests/unit/utils/test_mcp_headers.py

Lines changed: 166 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,14 @@
55

66
from fastapi import Request
77

8+
import constants
89
from models.config import ModelContextProtocolServer
910
from utils import mcp_headers
10-
from utils.mcp_headers import extract_propagated_headers
11+
from utils.mcp_headers import (
12+
build_server_headers,
13+
extract_propagated_headers,
14+
find_unresolved_auth_headers,
15+
)
1116

1217

1318
def test_extract_mcp_headers_empty_headers(mocker: MockerFixture) -> None:
@@ -289,3 +294,163 @@ def test_no_headers_field_configured(self) -> None:
289294
request_headers = {"x-rh-identity": "identity-value"}
290295
result = extract_propagated_headers(server, request_headers)
291296
assert not result
297+
298+
299+
class TestFindUnresolvedAuthHeaders:
300+
"""Test cases for find_unresolved_auth_headers function."""
301+
302+
def test_all_configured_headers_present(self) -> None:
303+
"""Test that an empty list is returned when all configured headers are resolved."""
304+
configured = {"Authorization": "kubernetes", "X-Api-Key": "/var/secrets/key"}
305+
resolved = {"Authorization": "Bearer tok", "X-Api-Key": "secret"}
306+
assert not find_unresolved_auth_headers(configured, resolved)
307+
308+
def test_missing_header_is_returned(self) -> None:
309+
"""Test that a configured header absent from resolved is returned."""
310+
configured = {"Authorization": "kubernetes"}
311+
resolved: dict[str, str] = {}
312+
assert find_unresolved_auth_headers(configured, resolved) == ["Authorization"]
313+
314+
def test_partially_resolved_returns_missing(self) -> None:
315+
"""Test that only unresolved headers are returned when some are resolved."""
316+
configured = {"Authorization": "kubernetes", "X-Api-Key": "/var/secrets/key"}
317+
resolved = {"Authorization": "Bearer tok"}
318+
assert find_unresolved_auth_headers(configured, resolved) == ["X-Api-Key"]
319+
320+
def test_comparison_is_case_insensitive(self) -> None:
321+
"""Test that header name matching is case-insensitive."""
322+
configured = {"Authorization": "kubernetes"}
323+
resolved = {"authorization": "Bearer tok"}
324+
assert not find_unresolved_auth_headers(configured, resolved)
325+
326+
def test_empty_configured_returns_empty(self) -> None:
327+
"""Test that an empty configured dict returns an empty list."""
328+
assert not find_unresolved_auth_headers({}, {"Authorization": "Bearer tok"})
329+
330+
def test_empty_resolved_returns_all_configured(self) -> None:
331+
"""Test that all configured headers are returned when resolved is empty."""
332+
configured = {"Authorization": "kubernetes", "X-Api-Key": "/path"}
333+
result = find_unresolved_auth_headers(configured, {})
334+
assert sorted(result) == ["Authorization", "X-Api-Key"]
335+
336+
337+
class TestBuildServerHeaders:
338+
"""Test cases for build_server_headers function."""
339+
340+
def _make_server(
341+
self,
342+
resolved_auth: dict[str, str] | None = None,
343+
headers: list[str] | None = None,
344+
) -> ModelContextProtocolServer:
345+
"""Create a ModelContextProtocolServer with given auth and allowlist headers."""
346+
server = ModelContextProtocolServer(
347+
name="test-server",
348+
url="http://test:8080",
349+
provider_id="xyzzy",
350+
headers=headers or [],
351+
)
352+
object.__setattr__(
353+
server, "_resolved_authorization_headers", resolved_auth or {}
354+
)
355+
return server
356+
357+
def test_static_resolved_header_is_added(self) -> None:
358+
"""Test that a statically resolved header value is included in the result."""
359+
server = self._make_server(resolved_auth={"Authorization": "static-token"})
360+
result = build_server_headers(server, {}, None, None)
361+
assert result == {"Authorization": "static-token"}
362+
363+
def test_kubernetes_token_resolves_to_bearer(self) -> None:
364+
"""Test that a kubernetes keyword resolves to a Bearer token."""
365+
server = self._make_server(
366+
resolved_auth={"Authorization": constants.MCP_AUTH_KUBERNETES}
367+
)
368+
result = build_server_headers(server, {}, None, token="my-k8s-token")
369+
assert result == {"Authorization": "Bearer my-k8s-token"}
370+
371+
def test_kubernetes_without_token_is_skipped(self) -> None:
372+
"""Test that a kubernetes keyword with no token produces no header."""
373+
server = self._make_server(
374+
resolved_auth={"Authorization": constants.MCP_AUTH_KUBERNETES}
375+
)
376+
result = build_server_headers(server, {}, None, token=None)
377+
assert not result
378+
379+
def test_client_keyword_is_skipped(self) -> None:
380+
"""Test that a client keyword is skipped (value comes from client_headers)."""
381+
server = self._make_server(
382+
resolved_auth={"Authorization": constants.MCP_AUTH_CLIENT}
383+
)
384+
result = build_server_headers(server, {}, None, None)
385+
assert not result
386+
387+
def test_oauth_keyword_is_skipped(self) -> None:
388+
"""Test that an oauth keyword is skipped (value comes from client_headers)."""
389+
server = self._make_server(
390+
resolved_auth={"Authorization": constants.MCP_AUTH_OAUTH}
391+
)
392+
result = build_server_headers(server, {}, None, None)
393+
assert not result
394+
395+
def test_client_headers_take_priority_over_resolved(self) -> None:
396+
"""Test that a client-supplied header is not overwritten by a resolved value."""
397+
server = self._make_server(resolved_auth={"Authorization": "static-token"})
398+
result = build_server_headers(
399+
server, {"Authorization": "client-token"}, None, None
400+
)
401+
assert result == {"Authorization": "client-token"}
402+
403+
def test_client_headers_priority_is_case_insensitive(self) -> None:
404+
"""Test that case-insensitive comparison prevents overwriting client headers."""
405+
server = self._make_server(resolved_auth={"authorization": "static-token"})
406+
result = build_server_headers(
407+
server, {"Authorization": "client-token"}, None, None
408+
)
409+
assert result == {"Authorization": "client-token"}
410+
411+
def test_propagated_request_headers_are_added(self) -> None:
412+
"""Test that allowlisted request headers are propagated."""
413+
server = self._make_server(headers=["x-rh-identity"])
414+
result = build_server_headers(
415+
server, {}, {"x-rh-identity": "my-identity"}, None
416+
)
417+
assert result == {"x-rh-identity": "my-identity"}
418+
419+
def test_existing_header_blocks_propagation(self) -> None:
420+
"""Test that a propagated header does not overwrite an already-set header."""
421+
server = self._make_server(headers=["x-rh-identity"])
422+
result = build_server_headers(
423+
server,
424+
{"x-rh-identity": "client-identity"},
425+
{"x-rh-identity": "request-identity"},
426+
None,
427+
)
428+
assert result == {"x-rh-identity": "client-identity"}
429+
430+
def test_no_headers_no_config_returns_empty(self) -> None:
431+
"""Test that a server with no applicable headers returns an empty dict."""
432+
server = self._make_server()
433+
result = build_server_headers(server, {}, None, None)
434+
assert not result
435+
436+
def test_multiple_sources_are_merged(self) -> None:
437+
"""Test that all header sources are combined into one dictionary."""
438+
server = self._make_server(
439+
resolved_auth={
440+
"Authorization": constants.MCP_AUTH_KUBERNETES,
441+
"X-Api-Key": "static-key",
442+
},
443+
headers=["x-request-id"],
444+
)
445+
result = build_server_headers(
446+
server,
447+
{"X-Client-Header": "client-value"},
448+
{"x-request-id": "req-123"},
449+
token="k8s-token",
450+
)
451+
assert result == {
452+
"X-Client-Header": "client-value",
453+
"Authorization": "Bearer k8s-token",
454+
"X-Api-Key": "static-key",
455+
"x-request-id": "req-123",
456+
}

0 commit comments

Comments
 (0)