Skip to content

Commit 7913a3b

Browse files
wukathcopybara-github
authored andcommitted
feat: Add auth scheme/credential support to MCP toolsets in Agent Registry
This is also so that we can add Agent Identity support when it's ready. Co-authored-by: Kathy Wu <wukathy@google.com> PiperOrigin-RevId: 894180206
1 parent f641b1a commit 7913a3b

File tree

2 files changed

+68
-21
lines changed

2 files changed

+68
-21
lines changed

src/google/adk/integrations/agent_registry/agent_registry.py

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,8 @@
2525
from typing import Dict
2626
from typing import List
2727
from typing import Mapping
28-
from typing import Optional
2928
from typing import Sequence
3029
from typing import TypedDict
31-
from typing import Union
3230
from urllib.parse import parse_qs
3331
from urllib.parse import urlparse
3432

@@ -39,6 +37,8 @@
3937
from a2a.types import TransportProtocol as A2ATransport
4038
from google.adk.agents.readonly_context import ReadonlyContext
4139
from google.adk.agents.remote_a2a_agent import RemoteA2aAgent
40+
from google.adk.auth.auth_credential import AuthCredential
41+
from google.adk.auth.auth_schemes import AuthScheme
4242
from google.adk.telemetry.tracing import GCP_MCP_SERVER_DESTINATION_ID
4343
from google.adk.tools.base_tool import BaseTool
4444
from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
@@ -75,11 +75,15 @@ def __init__(
7575
header_provider: (
7676
Callable[[ReadonlyContext], Dict[str, str]] | None
7777
) = None,
78+
auth_scheme: AuthScheme | None = None,
79+
auth_credential: AuthCredential | None = None,
7880
):
7981
super().__init__(
8082
connection_params=connection_params,
8183
tool_name_prefix=tool_name_prefix,
8284
header_provider=header_provider,
85+
auth_scheme=auth_scheme,
86+
auth_credential=auth_credential,
8387
)
8488
self.destination_resource_id = destination_resource_id
8589

@@ -143,11 +147,11 @@ class AgentRegistry:
143147

144148
def __init__(
145149
self,
146-
project_id: Optional[str] = None,
147-
location: Optional[str] = None,
148-
header_provider: Optional[
149-
Callable[[ReadonlyContext], Dict[str, str]]
150-
] = None,
150+
project_id: str | None = None,
151+
location: str | None = None,
152+
header_provider: (
153+
Callable[[ReadonlyContext], Dict[str, str]] | None
154+
) = None,
151155
):
152156
"""Initializes the AgentRegistry client.
153157
@@ -190,7 +194,7 @@ def _get_auth_headers(self) -> Dict[str, str]:
190194
) from e
191195

192196
def _make_request(
193-
self, path: str, params: Optional[Dict[str, Any]] = None
197+
self, path: str, params: Dict[str, Any] | None = None
194198
) -> Dict[str, Any]:
195199
"""Helper function to make GET requests to the Agent Registry API."""
196200
if path.startswith("projects/"):
@@ -217,9 +221,9 @@ def _make_request(
217221
def _get_connection_uri(
218222
self,
219223
resource_details: Mapping[str, Any],
220-
protocol_type: Optional[_ProtocolType] = None,
221-
protocol_binding: Optional[A2ATransport] = None,
222-
) -> Optional[str]:
224+
protocol_type: _ProtocolType | None = None,
225+
protocol_binding: A2ATransport | None = None,
226+
) -> str | None:
223227
"""Extracts the first matching URI based on type and binding filters."""
224228
protocols = list(resource_details.get("protocols", []))
225229
if "interfaces" in resource_details:
@@ -249,9 +253,9 @@ def _clean_name(self, name: str) -> str:
249253

250254
def list_mcp_servers(
251255
self,
252-
filter_str: Optional[str] = None,
253-
page_size: Optional[int] = None,
254-
page_token: Optional[str] = None,
256+
filter_str: str | None = None,
257+
page_size: int | None = None,
258+
page_token: str | None = None,
255259
) -> Dict[str, Any]:
256260
"""Fetches a list of MCP Servers."""
257261
params = {}
@@ -267,7 +271,12 @@ def get_mcp_server(self, name: str) -> Dict[str, Any]:
267271
"""Retrieves details of a specific MCP Server."""
268272
return self._make_request(name)
269273

270-
def get_mcp_toolset(self, mcp_server_name: str) -> McpToolset:
274+
def get_mcp_toolset(
275+
self,
276+
mcp_server_name: str,
277+
auth_scheme: AuthScheme | None = None,
278+
auth_credential: AuthCredential | None = None,
279+
) -> McpToolset:
271280
"""Constructs an McpToolset instance from a registered MCP Server."""
272281
server_details = self.get_mcp_server(mcp_server_name)
273282
name = self._clean_name(server_details.get("displayName", mcp_server_name))
@@ -293,15 +302,17 @@ def get_mcp_toolset(self, mcp_server_name: str) -> McpToolset:
293302
connection_params=connection_params,
294303
tool_name_prefix=name,
295304
header_provider=self._header_provider,
305+
auth_scheme=auth_scheme,
306+
auth_credential=auth_credential,
296307
)
297308

298309
# --- Endpoint Methods ---
299310

300311
def list_endpoints(
301312
self,
302-
filter_str: Optional[str] = None,
303-
page_size: Optional[int] = None,
304-
page_token: Optional[str] = None,
313+
filter_str: str | None = None,
314+
page_size: int | None = None,
315+
page_token: str | None = None,
305316
) -> Dict[str, Any]:
306317
"""Fetches a list of Endpoints."""
307318
params = {}
@@ -349,9 +360,9 @@ def get_model_name(self, endpoint_name: str) -> str:
349360

350361
def list_agents(
351362
self,
352-
filter_str: Optional[str] = None,
353-
page_size: Optional[int] = None,
354-
page_token: Optional[str] = None,
363+
filter_str: str | None = None,
364+
page_size: int | None = None,
365+
page_token: str | None = None,
355366
) -> Dict[str, Any]:
356367
"""Fetches a list of registered A2A Agents."""
357368
params = {}

tests/unittests/integrations/agent_registry/test_agent_registry.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,10 @@
1818
from unittest.mock import patch
1919

2020
from a2a.types import TransportProtocol as A2ATransport
21+
from fastapi.openapi.models import OAuth2
2122
from google.adk.agents.remote_a2a_agent import RemoteA2aAgent
23+
from google.adk.auth.auth_credential import AuthCredential
24+
from google.adk.auth.auth_credential import OAuth2Auth
2225
from google.adk.integrations.agent_registry import _ProtocolType
2326
from google.adk.integrations.agent_registry import AgentRegistry
2427
from google.adk.telemetry.tracing import GCP_MCP_SERVER_DESTINATION_ID
@@ -325,6 +328,39 @@ def test_get_mcp_toolset(self, mock_httpx, registry):
325328
assert isinstance(toolset, McpToolset)
326329
assert toolset.tool_name_prefix == "TestPrefix"
327330

331+
@patch("httpx.Client")
332+
def test_get_mcp_toolset_with_auth(self, mock_httpx, registry):
333+
mock_response = MagicMock()
334+
mock_response.json.return_value = {
335+
"displayName": "TestPrefix",
336+
"interfaces": [{
337+
"url": "https://mcp.com",
338+
"protocolBinding": A2ATransport.jsonrpc,
339+
}],
340+
}
341+
mock_response.raise_for_status = MagicMock()
342+
mock_httpx.return_value.__enter__.return_value.get.return_value = (
343+
mock_response
344+
)
345+
346+
registry._credentials.token = "token"
347+
registry._credentials.refresh = MagicMock()
348+
349+
auth_scheme = OAuth2(flows={})
350+
auth_credential = AuthCredential(
351+
auth_type="oauth2",
352+
oauth2=OAuth2Auth(client_id="test_id", client_secret="test_secret"),
353+
)
354+
355+
toolset = registry.get_mcp_toolset(
356+
"test-mcp", auth_scheme=auth_scheme, auth_credential=auth_credential
357+
)
358+
assert isinstance(toolset, McpToolset)
359+
auth_config = toolset.get_auth_config()
360+
assert auth_config is not None
361+
assert auth_config.auth_scheme == auth_scheme
362+
assert auth_config.raw_auth_credential == auth_credential
363+
328364
@patch("httpx.Client")
329365
def test_get_remote_a2a_agent(self, mock_httpx, registry):
330366
mock_response = MagicMock()

0 commit comments

Comments
 (0)