|
5 | 5 |
|
6 | 6 | from fastapi import Request |
7 | 7 |
|
| 8 | +import constants |
8 | 9 | from models.config import ModelContextProtocolServer |
9 | 10 | from utils import mcp_headers |
10 | | -from utils.mcp_headers import extract_propagated_headers |
| 11 | +from utils.mcp_headers import ( |
| 12 | + build_server_headers, |
| 13 | + extract_propagated_headers, |
| 14 | + find_unresolved_auth_headers, |
| 15 | +) |
11 | 16 |
|
12 | 17 |
|
13 | 18 | def test_extract_mcp_headers_empty_headers(mocker: MockerFixture) -> None: |
@@ -289,3 +294,163 @@ def test_no_headers_field_configured(self) -> None: |
289 | 294 | request_headers = {"x-rh-identity": "identity-value"} |
290 | 295 | result = extract_propagated_headers(server, request_headers) |
291 | 296 | assert not result |
| 297 | + |
| 298 | + |
| 299 | +class TestFindUnresolvedAuthHeaders: |
| 300 | + """Test cases for find_unresolved_auth_headers function.""" |
| 301 | + |
| 302 | + def test_all_configured_headers_present(self) -> None: |
| 303 | + """Test that an empty list is returned when all configured headers are resolved.""" |
| 304 | + configured = {"Authorization": "kubernetes", "X-Api-Key": "/var/secrets/key"} |
| 305 | + resolved = {"Authorization": "Bearer tok", "X-Api-Key": "secret"} |
| 306 | + assert not find_unresolved_auth_headers(configured, resolved) |
| 307 | + |
| 308 | + def test_missing_header_is_returned(self) -> None: |
| 309 | + """Test that a configured header absent from resolved is returned.""" |
| 310 | + configured = {"Authorization": "kubernetes"} |
| 311 | + resolved: dict[str, str] = {} |
| 312 | + assert find_unresolved_auth_headers(configured, resolved) == ["Authorization"] |
| 313 | + |
| 314 | + def test_partially_resolved_returns_missing(self) -> None: |
| 315 | + """Test that only unresolved headers are returned when some are resolved.""" |
| 316 | + configured = {"Authorization": "kubernetes", "X-Api-Key": "/var/secrets/key"} |
| 317 | + resolved = {"Authorization": "Bearer tok"} |
| 318 | + assert find_unresolved_auth_headers(configured, resolved) == ["X-Api-Key"] |
| 319 | + |
| 320 | + def test_comparison_is_case_insensitive(self) -> None: |
| 321 | + """Test that header name matching is case-insensitive.""" |
| 322 | + configured = {"Authorization": "kubernetes"} |
| 323 | + resolved = {"authorization": "Bearer tok"} |
| 324 | + assert not find_unresolved_auth_headers(configured, resolved) |
| 325 | + |
| 326 | + def test_empty_configured_returns_empty(self) -> None: |
| 327 | + """Test that an empty configured dict returns an empty list.""" |
| 328 | + assert not find_unresolved_auth_headers({}, {"Authorization": "Bearer tok"}) |
| 329 | + |
| 330 | + def test_empty_resolved_returns_all_configured(self) -> None: |
| 331 | + """Test that all configured headers are returned when resolved is empty.""" |
| 332 | + configured = {"Authorization": "kubernetes", "X-Api-Key": "/path"} |
| 333 | + result = find_unresolved_auth_headers(configured, {}) |
| 334 | + assert sorted(result) == ["Authorization", "X-Api-Key"] |
| 335 | + |
| 336 | + |
| 337 | +class TestBuildServerHeaders: |
| 338 | + """Test cases for build_server_headers function.""" |
| 339 | + |
| 340 | + def _make_server( |
| 341 | + self, |
| 342 | + resolved_auth: dict[str, str] | None = None, |
| 343 | + headers: list[str] | None = None, |
| 344 | + ) -> ModelContextProtocolServer: |
| 345 | + """Create a ModelContextProtocolServer with given auth and allowlist headers.""" |
| 346 | + server = ModelContextProtocolServer( |
| 347 | + name="test-server", |
| 348 | + url="http://test:8080", |
| 349 | + provider_id="xyzzy", |
| 350 | + headers=headers or [], |
| 351 | + ) |
| 352 | + object.__setattr__( |
| 353 | + server, "_resolved_authorization_headers", resolved_auth or {} |
| 354 | + ) |
| 355 | + return server |
| 356 | + |
| 357 | + def test_static_resolved_header_is_added(self) -> None: |
| 358 | + """Test that a statically resolved header value is included in the result.""" |
| 359 | + server = self._make_server(resolved_auth={"Authorization": "static-token"}) |
| 360 | + result = build_server_headers(server, {}, None, None) |
| 361 | + assert result == {"Authorization": "static-token"} |
| 362 | + |
| 363 | + def test_kubernetes_token_resolves_to_bearer(self) -> None: |
| 364 | + """Test that a kubernetes keyword resolves to a Bearer token.""" |
| 365 | + server = self._make_server( |
| 366 | + resolved_auth={"Authorization": constants.MCP_AUTH_KUBERNETES} |
| 367 | + ) |
| 368 | + result = build_server_headers(server, {}, None, token="my-k8s-token") |
| 369 | + assert result == {"Authorization": "Bearer my-k8s-token"} |
| 370 | + |
| 371 | + def test_kubernetes_without_token_is_skipped(self) -> None: |
| 372 | + """Test that a kubernetes keyword with no token produces no header.""" |
| 373 | + server = self._make_server( |
| 374 | + resolved_auth={"Authorization": constants.MCP_AUTH_KUBERNETES} |
| 375 | + ) |
| 376 | + result = build_server_headers(server, {}, None, token=None) |
| 377 | + assert not result |
| 378 | + |
| 379 | + def test_client_keyword_is_skipped(self) -> None: |
| 380 | + """Test that a client keyword is skipped (value comes from client_headers).""" |
| 381 | + server = self._make_server( |
| 382 | + resolved_auth={"Authorization": constants.MCP_AUTH_CLIENT} |
| 383 | + ) |
| 384 | + result = build_server_headers(server, {}, None, None) |
| 385 | + assert not result |
| 386 | + |
| 387 | + def test_oauth_keyword_is_skipped(self) -> None: |
| 388 | + """Test that an oauth keyword is skipped (value comes from client_headers).""" |
| 389 | + server = self._make_server( |
| 390 | + resolved_auth={"Authorization": constants.MCP_AUTH_OAUTH} |
| 391 | + ) |
| 392 | + result = build_server_headers(server, {}, None, None) |
| 393 | + assert not result |
| 394 | + |
| 395 | + def test_client_headers_take_priority_over_resolved(self) -> None: |
| 396 | + """Test that a client-supplied header is not overwritten by a resolved value.""" |
| 397 | + server = self._make_server(resolved_auth={"Authorization": "static-token"}) |
| 398 | + result = build_server_headers( |
| 399 | + server, {"Authorization": "client-token"}, None, None |
| 400 | + ) |
| 401 | + assert result == {"Authorization": "client-token"} |
| 402 | + |
| 403 | + def test_client_headers_priority_is_case_insensitive(self) -> None: |
| 404 | + """Test that case-insensitive comparison prevents overwriting client headers.""" |
| 405 | + server = self._make_server(resolved_auth={"authorization": "static-token"}) |
| 406 | + result = build_server_headers( |
| 407 | + server, {"Authorization": "client-token"}, None, None |
| 408 | + ) |
| 409 | + assert result == {"Authorization": "client-token"} |
| 410 | + |
| 411 | + def test_propagated_request_headers_are_added(self) -> None: |
| 412 | + """Test that allowlisted request headers are propagated.""" |
| 413 | + server = self._make_server(headers=["x-rh-identity"]) |
| 414 | + result = build_server_headers( |
| 415 | + server, {}, {"x-rh-identity": "my-identity"}, None |
| 416 | + ) |
| 417 | + assert result == {"x-rh-identity": "my-identity"} |
| 418 | + |
| 419 | + def test_existing_header_blocks_propagation(self) -> None: |
| 420 | + """Test that a propagated header does not overwrite an already-set header.""" |
| 421 | + server = self._make_server(headers=["x-rh-identity"]) |
| 422 | + result = build_server_headers( |
| 423 | + server, |
| 424 | + {"x-rh-identity": "client-identity"}, |
| 425 | + {"x-rh-identity": "request-identity"}, |
| 426 | + None, |
| 427 | + ) |
| 428 | + assert result == {"x-rh-identity": "client-identity"} |
| 429 | + |
| 430 | + def test_no_headers_no_config_returns_empty(self) -> None: |
| 431 | + """Test that a server with no applicable headers returns an empty dict.""" |
| 432 | + server = self._make_server() |
| 433 | + result = build_server_headers(server, {}, None, None) |
| 434 | + assert not result |
| 435 | + |
| 436 | + def test_multiple_sources_are_merged(self) -> None: |
| 437 | + """Test that all header sources are combined into one dictionary.""" |
| 438 | + server = self._make_server( |
| 439 | + resolved_auth={ |
| 440 | + "Authorization": constants.MCP_AUTH_KUBERNETES, |
| 441 | + "X-Api-Key": "static-key", |
| 442 | + }, |
| 443 | + headers=["x-request-id"], |
| 444 | + ) |
| 445 | + result = build_server_headers( |
| 446 | + server, |
| 447 | + {"X-Client-Header": "client-value"}, |
| 448 | + {"x-request-id": "req-123"}, |
| 449 | + token="k8s-token", |
| 450 | + ) |
| 451 | + assert result == { |
| 452 | + "X-Client-Header": "client-value", |
| 453 | + "Authorization": "Bearer k8s-token", |
| 454 | + "X-Api-Key": "static-key", |
| 455 | + "x-request-id": "req-123", |
| 456 | + } |
0 commit comments