Skip to content

Commit 6754760

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Support AgentRegistry association
PiperOrigin-RevId: 890595289
1 parent 4ffe8fb commit 6754760

File tree

4 files changed

+307
-8
lines changed

4 files changed

+307
-8
lines changed

src/google/adk/integrations/agent_registry/agent_registry.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,70 @@
3737
from a2a.types import TransportProtocol as A2ATransport
3838
from google.adk.agents.readonly_context import ReadonlyContext
3939
from google.adk.agents.remote_a2a_agent import RemoteA2aAgent
40+
from google.adk.telemetry.tracing import GCP_MCP_SERVER_DESTINATION_ID
41+
from google.adk.tools.base_tool import BaseTool
42+
from google.adk.tools.mcp_tool.mcp_session_manager import SseConnectionParams
43+
from google.adk.tools.mcp_tool.mcp_session_manager import StdioConnectionParams
4044
from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams
4145
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
4246
import google.auth
4347
import google.auth.transport.requests
4448
import httpx
49+
from mcp import StdioServerParameters
50+
from typing_extensions import override
4551

4652
logger = logging.getLogger("google_adk." + __name__)
4753

4854
AGENT_REGISTRY_BASE_URL = "https://agentregistry.googleapis.com/v1alpha"
4955

5056

57+
# An MCPToolset for a single registered MCP server. Adds the special
58+
# gcp.mcp.server.destination.id custom_metadata key on each returned tool. This special key is
59+
# added to execute_tool spans in google.adk.telemetry.tracing
60+
class AgentRegistrySingleMcpToolset(McpToolset):
61+
62+
def __init__(
63+
self,
64+
*,
65+
destination_resource_id: str | None,
66+
connection_params: (
67+
StdioServerParameters
68+
| StdioConnectionParams
69+
| SseConnectionParams
70+
| StreamableHTTPConnectionParams
71+
),
72+
tool_name_prefix: str | None = None,
73+
header_provider: (
74+
Callable[[ReadonlyContext], Dict[str, str]] | None
75+
) = None,
76+
):
77+
super().__init__(
78+
connection_params=connection_params,
79+
tool_name_prefix=tool_name_prefix,
80+
header_provider=header_provider,
81+
)
82+
self.destination_resource_id = destination_resource_id
83+
84+
@override
85+
async def get_tools(
86+
self, readonly_context: ReadonlyContext | None = None
87+
) -> List[BaseTool]:
88+
tools = await super().get_tools(readonly_context)
89+
90+
# Noop if there is no destination_resource_id
91+
if self.destination_resource_id is None:
92+
return tools
93+
94+
for tool in tools:
95+
if not tool.custom_metadata:
96+
tool.custom_metadata = {}
97+
98+
tool.custom_metadata[GCP_MCP_SERVER_DESTINATION_ID] = (
99+
self.destination_resource_id
100+
)
101+
return tools
102+
103+
51104
class _ProtocolType(str, Enum):
52105
"""Supported agent protocol types."""
53106

@@ -196,6 +249,9 @@ def get_mcp_toolset(self, mcp_server_name: str) -> McpToolset:
196249
"""Constructs an McpToolset instance from a registered MCP Server."""
197250
server_details = self.get_mcp_server(mcp_server_name)
198251
name = self._clean_name(server_details.get("displayName", mcp_server_name))
252+
mcp_server_id = server_details.get("mcpServerId")
253+
if not isinstance(mcp_server_id, str):
254+
mcp_server_id = None
199255

200256
endpoint_uri = self._get_connection_uri(
201257
server_details, protocol_binding=A2ATransport.jsonrpc
@@ -210,7 +266,8 @@ def get_mcp_toolset(self, mcp_server_name: str) -> McpToolset:
210266
connection_params = StreamableHTTPConnectionParams(
211267
url=endpoint_uri, headers=self._get_auth_headers()
212268
)
213-
return McpToolset(
269+
return AgentRegistrySingleMcpToolset(
270+
destination_resource_id=mcp_server_id,
214271
connection_params=connection_params,
215272
tool_name_prefix=name,
216273
header_provider=self._header_provider,

src/google/adk/telemetry/tracing.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,11 @@
8383

8484
USER_CONTENT_ELIDED = '<elided>'
8585

86+
# Used to associate a span with a destination resource for AppHub. Tools with
87+
# this key in their BaseTool.custom_metadata will have the mapping added as a
88+
# span attribute
89+
GCP_MCP_SERVER_DESTINATION_ID = 'gcp.mcp.server.destination.id'
90+
8691
# Needed to avoid circular imports
8792
if TYPE_CHECKING:
8893
from ..agents.base_agent import BaseAgent
@@ -190,6 +195,14 @@ def trace_tool_call(
190195
else:
191196
span.set_attribute(ERROR_TYPE, type(error).__name__)
192197

198+
# Special case for client side association with a remote tool call
199+
if (
200+
tool.custom_metadata
201+
and GCP_MCP_SERVER_DESTINATION_ID in tool.custom_metadata
202+
):
203+
destination_id = tool.custom_metadata[GCP_MCP_SERVER_DESTINATION_ID]
204+
span.set_attribute(GCP_MCP_SERVER_DESTINATION_ID, destination_id)
205+
193206
# Setting empty llm request and response (as UI expect these) while not
194207
# applicable for tool_response.
195208
span.set_attribute('gcp.vertex.agent.llm_request', '{}')

tests/unittests/integrations/agent_registry/test_agent_registry.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,21 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
16+
from unittest.mock import AsyncMock
1517
from unittest.mock import MagicMock
1618
from unittest.mock import patch
1719

1820
from a2a.types import TransportProtocol as A2ATransport
1921
from google.adk.agents.remote_a2a_agent import RemoteA2aAgent
2022
from google.adk.integrations.agent_registry import _ProtocolType
2123
from google.adk.integrations.agent_registry import AgentRegistry
24+
from google.adk.telemetry.tracing import GCP_MCP_SERVER_DESTINATION_ID
2225
from google.adk.tools.mcp_tool.mcp_toolset import McpToolset
2326
import httpx
27+
from mcp import ClientSession
28+
from mcp.types import ListToolsResult
29+
from mcp.types import Tool
2430
import pytest
2531

2632

@@ -31,6 +37,130 @@ def registry(self):
3137
with patch("google.auth.default", return_value=(MagicMock(), "project-id")):
3238
return AgentRegistry(project_id="test-project", location="global")
3339

40+
@pytest.mark.asyncio
41+
@patch("httpx.Client")
42+
@patch(
43+
"google.adk.tools.mcp_tool.mcp_session_manager.MCPSessionManager.create_session",
44+
new_callable=AsyncMock,
45+
)
46+
async def test_get_mcp_toolset_adds_destination_id(
47+
self, mock_create_session, mock_httpx, registry
48+
):
49+
"""Test that tools from get_mcp_toolset have the destination ID."""
50+
# Arrange
51+
mcp_server_name = "test-mcp-server"
52+
mock_api_response = MagicMock()
53+
mock_api_response.json.return_value = {
54+
"displayName": "TestPrefix",
55+
"mcpServerId": (
56+
"urn:mcp:googleapis.com:projects:1234:locations:global:bigquery"
57+
),
58+
"interfaces": [{
59+
"url": "https://mcp.com",
60+
"protocolBinding": A2ATransport.jsonrpc,
61+
}],
62+
}
63+
mock_httpx.return_value.__enter__.return_value.get.return_value = (
64+
mock_api_response
65+
)
66+
67+
registry._credentials.token = "token"
68+
registry._credentials.refresh = MagicMock()
69+
70+
mock_session = AsyncMock(spec=ClientSession)
71+
mock_create_session.return_value = mock_session
72+
73+
# Mock the tools returned by list_tools
74+
mock_session.list_tools.return_value = ListToolsResult(
75+
tools=[
76+
Tool(
77+
name="tool1",
78+
description="d1",
79+
inputs={},
80+
outputs={},
81+
inputSchema={},
82+
),
83+
Tool(
84+
name="tool2",
85+
description="d2",
86+
inputs={},
87+
outputs={},
88+
inputSchema={},
89+
),
90+
]
91+
)
92+
93+
# Act
94+
toolset = registry.get_mcp_toolset(mcp_server_name)
95+
tools = await toolset.get_tools()
96+
97+
# Assert
98+
assert isinstance(toolset, McpToolset)
99+
mock_session.list_tools.assert_called_once_with()
100+
assert len(tools) == 2
101+
for tool in tools:
102+
assert tool.custom_metadata is not None
103+
assert (
104+
tool.custom_metadata.get(GCP_MCP_SERVER_DESTINATION_ID)
105+
== "urn:mcp:googleapis.com:projects:1234:locations:global:bigquery"
106+
)
107+
108+
@pytest.mark.asyncio
109+
@patch("httpx.Client")
110+
@patch(
111+
"google.adk.tools.mcp_tool.mcp_session_manager.MCPSessionManager.create_session",
112+
new_callable=AsyncMock,
113+
)
114+
async def test_get_mcp_toolset_handles_missing_destination_id(
115+
self, mock_create_session, mock_httpx, registry
116+
):
117+
"""Test get_mcp_toolset when the destination ID is missing."""
118+
# Arrange
119+
mcp_server_name = "test-mcp-server"
120+
mock_api_response = MagicMock()
121+
mock_api_response.json.return_value = {
122+
"displayName": "TestPrefix",
123+
# "mcpServerId" is intentionally omitted
124+
"interfaces": [{
125+
"url": "https://mcp.com",
126+
"protocolBinding": A2ATransport.jsonrpc,
127+
}],
128+
}
129+
mock_httpx.return_value.__enter__.return_value.get.return_value = (
130+
mock_api_response
131+
)
132+
133+
registry._credentials.token = "token"
134+
registry._credentials.refresh = MagicMock()
135+
136+
mock_session = AsyncMock(spec=ClientSession)
137+
mock_create_session.return_value = mock_session
138+
139+
# Mock the tools returned by list_tools
140+
mock_session.list_tools.return_value = ListToolsResult(
141+
tools=[
142+
Tool(
143+
name="tool1",
144+
description="d1",
145+
inputs={},
146+
outputs={},
147+
inputSchema={},
148+
),
149+
]
150+
)
151+
152+
# Act
153+
toolset = registry.get_mcp_toolset(mcp_server_name)
154+
tools = await toolset.get_tools()
155+
156+
# Assert
157+
assert isinstance(toolset, McpToolset)
158+
mock_session.list_tools.assert_called_once_with()
159+
assert len(tools) == 1
160+
for tool in tools:
161+
# The custom_metadata shouldn't have been added
162+
assert tool.custom_metadata is None
163+
34164
def test_init_raises_value_error_if_params_missing(self):
35165
with pytest.raises(
36166
ValueError, match="project_id and location must be provided"

0 commit comments

Comments
 (0)