Skip to content

Commit b3e9962

Browse files
google-genai-botcopybara-github
authored andcommitted
feat(auth): Support additional HTTP headers in MCP tools
PiperOrigin-RevId: 893287579
1 parent 1104523 commit b3e9962

File tree

4 files changed

+74
-0
lines changed

4 files changed

+74
-0
lines changed

src/google/adk/tools/mcp_tool/mcp_tool.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,9 @@ async def _get_headers(
479479
f" {credential.http.credentials.token}"
480480
)
481481
}
482+
if credential.http.additional_headers:
483+
headers = headers or {}
484+
headers.update(credential.http.additional_headers)
482485
elif credential.api_key:
483486
if (
484487
not self._credentials_manager

src/google/adk/tools/mcp_tool/mcp_toolset.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,10 @@ def _get_auth_headers(self) -> Optional[Dict[str, str]]:
240240
f"{credential.http.scheme} {credential.http.credentials.token}"
241241
)
242242
}
243+
244+
if credential.http.additional_headers:
245+
headers = headers or {}
246+
headers.update(credential.http.additional_headers)
243247
elif credential.api_key:
244248
# For API key, use the auth scheme to determine header name
245249
if self._auth_config.auth_scheme:

tests/unittests/tools/mcp_tool/test_mcp_tool.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import inspect
1616
from unittest.mock import AsyncMock
17+
from unittest.mock import create_autospec
1718
from unittest.mock import Mock
1819
from unittest.mock import patch
1920

@@ -415,6 +416,44 @@ async def test_get_headers_http_basic(self):
415416
expected_encoded = base64.b64encode(b"user:pass").decode()
416417
assert headers == {"Authorization": f"Basic {expected_encoded}"}
417418

419+
@pytest.mark.asyncio
420+
@pytest.mark.parametrize(
421+
"token, expected_headers",
422+
[
423+
(
424+
"some-token",
425+
{
426+
"Authorization": "some-scheme some-token",
427+
"X-Custom-Header": "custom-value",
428+
},
429+
),
430+
(
431+
None,
432+
{"X-Custom-Header": "custom-value"},
433+
),
434+
],
435+
)
436+
async def test_get_headers_http_adds_additional_headers(
437+
self, token, expected_headers
438+
):
439+
tool = MCPTool(
440+
mcp_tool=self.mock_mcp_tool,
441+
mcp_session_manager=self.mock_session_manager,
442+
)
443+
http_auth = HttpAuth(
444+
scheme="some-scheme",
445+
credentials=HttpCredentials(token=token),
446+
additional_headers={"X-Custom-Header": "custom-value"},
447+
)
448+
credential = AuthCredential(
449+
auth_type=AuthCredentialTypes.HTTP, http=http_auth
450+
)
451+
452+
tool_context = create_autospec(ToolContext, instance=True)
453+
headers = await tool._get_headers(tool_context, credential)
454+
455+
assert headers == expected_headers
456+
418457
@pytest.mark.asyncio
419458
async def test_get_headers_api_key_with_valid_header_scheme(self):
420459
"""Test header generation for API Key credentials with header-based auth scheme."""

tests/unittests/tools/mcp_tool/test_mcp_toolset.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,13 @@
2323
from unittest.mock import Mock
2424
from unittest.mock import patch
2525

26+
from fastapi.openapi.models import OAuth2
2627
from google.adk.agents.readonly_context import ReadonlyContext
2728
from google.adk.auth.auth_credential import AuthCredential
29+
from google.adk.auth.auth_credential import AuthCredentialTypes
30+
from google.adk.auth.auth_credential import HttpAuth
31+
from google.adk.auth.auth_credential import HttpCredentials
32+
from google.adk.auth.auth_tool import AuthConfig
2833
from google.adk.tools.load_mcp_resource_tool import LoadMcpResourceTool
2934
from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager
3035
from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
@@ -646,3 +651,26 @@ async def mock_sampling_handler(messages, params=None, context=None):
646651
assert called["value"] is True
647652
assert result["role"] == "assistant"
648653
assert result["content"]["text"] == "sampling response"
654+
655+
@pytest.mark.asyncio
656+
async def test_get_auth_headers_includes_additional_headers(self):
657+
credential = AuthCredential(
658+
auth_type=AuthCredentialTypes.HTTP,
659+
http=HttpAuth(
660+
scheme="bearer",
661+
credentials=HttpCredentials(token="token"),
662+
additional_headers={"X-API-Key": "secret"},
663+
),
664+
)
665+
auth_config = AuthConfig(
666+
auth_scheme=OAuth2(flows={}),
667+
raw_auth_credential=credential,
668+
)
669+
auth_config.exchanged_auth_credential = credential
670+
toolset = McpToolset(connection_params=self.mock_stdio_params)
671+
toolset._auth_config = auth_config
672+
673+
headers = toolset._get_auth_headers()
674+
675+
assert headers["Authorization"] == "Bearer token"
676+
assert headers["X-API-Key"] == "secret"

0 commit comments

Comments
 (0)