Skip to content

Commit e3ed845

Browse files
committed
add auth validation
1 parent 1e034a4 commit e3ed845

5 files changed

Lines changed: 49 additions & 20 deletions

File tree

src/app/endpoints/query_v2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from authorization.middleware import authorize
2323
from configuration import AppConfig, configuration
2424
from constants import DEFAULT_RAG_TOOL
25-
from models.config import Action
25+
from models.config import Action, ModelContextProtocolServer
2626
from models.requests import QueryRequest
2727
from models.responses import (
2828
ForbiddenResponse,

src/models/config.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
# pylint: disable=too-many-lines
44

5+
import logging
56
from pathlib import Path
67
from typing import Optional, Any, Pattern
78
from enum import Enum
@@ -32,6 +33,8 @@
3233
from utils import checks
3334
from utils.mcp_auth_headers import resolve_authorization_headers
3435

36+
logger = logging.getLogger(__name__)
37+
3538

3639
class ConfigurationBase(BaseModel):
3740
"""Base class for all configuration models that rejects unknown fields."""
@@ -1641,6 +1644,45 @@ class Configuration(ConfigurationBase):
16411644
description="Quota handlers configuration",
16421645
)
16431646

1647+
@model_validator(mode="after")
1648+
def validate_mcp_auth_headers(self) -> Self:
1649+
"""
1650+
Validate MCP server authorization headers against authentication module.
1651+
1652+
Removes any MCP server with authorization_headers="kubernetes" when the
1653+
authentication module is not "k8s". This prevents sending wrong credential
1654+
types to MCP servers.
1655+
1656+
Returns:
1657+
Self: The model instance after validation.
1658+
"""
1659+
# Get authentication module value (pyright: ignore attribute access on Field)
1660+
auth_module = getattr(self.authentication, "module", None)
1661+
1662+
# Filter out misconfigured MCP servers
1663+
valid_mcp_servers = []
1664+
for mcp_server in self.mcp_servers:
1665+
is_valid = True
1666+
if mcp_server.authorization_headers:
1667+
for value in mcp_server.authorization_headers.values():
1668+
if value.strip() == "kubernetes" and auth_module != "k8s":
1669+
logger.warning(
1670+
"Removing MCP server '%s': has authorization_headers with "
1671+
"value 'kubernetes' but authentication module is '%s' "
1672+
"(not 'k8s'). Either change authentication.module to 'k8s' "
1673+
"or update the MCP server's authorization_headers to use a "
1674+
"file path or 'client'.",
1675+
mcp_server.name,
1676+
auth_module,
1677+
)
1678+
is_valid = False
1679+
break
1680+
if is_valid:
1681+
valid_mcp_servers.append(mcp_server)
1682+
1683+
self.mcp_servers = valid_mcp_servers
1684+
return self
1685+
16441686
def dump(self, filename: str | Path = "configuration.json") -> None:
16451687
"""
16461688
Write the current Configuration model to a JSON file.

tests/unit/app/endpoints/test_query_v2.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -196,22 +196,6 @@ def test_get_mcp_tools_skips_server_with_missing_auth() -> None:
196196
# All servers should be skipped
197197
assert len(tools) == 0
198198

199-
# With token but no mcp_headers
200-
tools = get_mcp_tools(servers, token="k8s-token", mcp_headers=None)
201-
# First server should work, others skipped
202-
assert len(tools) == 1
203-
assert tools[0]["server_label"] == "missing-k8s-auth"
204-
205-
# With mcp_headers but missing one for partial-auth
206-
mcp_headers = {
207-
"missing-client-auth": {"X-Token": "client-token"},
208-
"partial-auth": {"X-Custom": "client-custom"}, # Missing Authorization
209-
}
210-
tools = get_mcp_tools(servers, token=None, mcp_headers=mcp_headers)
211-
# Only missing-client-auth should work
212-
assert len(tools) == 1
213-
assert tools[0]["server_label"] == "missing-client-auth"
214-
215199

216200
def test_get_mcp_tools_includes_server_without_auth() -> None:
217201
"""Test that servers without auth config are always included."""

tests/unit/app/endpoints/test_streaming_query_v2.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ async def test_retrieve_response_builds_rag_and_mcp_tools(
6767
),
6868
]
6969
mocker.patch("app.endpoints.streaming_query_v2.configuration", mock_cfg)
70+
mocker.patch("app.endpoints.query_v2.configuration", mock_cfg)
7071

7172
qr = QueryRequest(query="hello")
7273
await retrieve_response(mock_client, "model-z", qr, token="tok")

tests/unit/models/config/test_model_context_protocol_server.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99
from pydantic import ValidationError
1010

1111
from models.config import ( # type: ignore[import-not-found]
12-
ModelContextProtocolServer,
13-
LlamaStackConfiguration,
14-
UserDataCollection,
12+
AuthenticationConfiguration,
1513
Configuration,
14+
LlamaStackConfiguration,
15+
ModelContextProtocolServer,
1616
ServiceConfiguration,
17+
UserDataCollection,
1718
)
1819

1920

@@ -230,6 +231,7 @@ def test_configuration_mcp_servers_with_mixed_auth_headers(tmp_path: Path) -> No
230231
feedback_enabled=False, feedback_storage=None
231232
),
232233
mcp_servers=mcp_servers,
234+
authentication=AuthenticationConfiguration(module="k8s"),
233235
customization=None,
234236
)
235237
assert cfg is not None

0 commit comments

Comments
 (0)