|
17 | 17 | import sys |
18 | 18 | import unittest |
19 | 19 | from unittest.mock import AsyncMock |
| 20 | +from unittest.mock import MagicMock |
20 | 21 | from unittest.mock import Mock |
21 | 22 | from unittest.mock import patch |
22 | 23 |
|
|
28 | 29 | from google.adk.tools.mcp_tool.mcp_session_manager import StreamableHTTPConnectionParams |
29 | 30 | from google.adk.tools.mcp_tool.mcp_tool import MCPTool |
30 | 31 | from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset |
| 32 | +from google.adk.tools.mcp_tool.mcp_toolset import McpToolset |
31 | 33 | from mcp import StdioServerParameters |
32 | 34 | import pytest |
33 | 35 |
|
@@ -302,3 +304,52 @@ async def test_get_tools_retry_decorator(self): |
302 | 304 |
|
303 | 305 | # Check that the method has the retry decorator |
304 | 306 | assert hasattr(toolset.get_tools, "__wrapped__") |
| 307 | + |
| 308 | + @pytest.mark.asyncio |
| 309 | + async def test_mcp_toolset_with_prefix(self): |
| 310 | + """Test that McpToolset correctly applies the tool_name_prefix.""" |
| 311 | + # Mock the connection parameters |
| 312 | + mock_connection_params = MagicMock() |
| 313 | + mock_connection_params.timeout = None |
| 314 | + |
| 315 | + # Mock the MCPSessionManager and its create_session method |
| 316 | + mock_session_manager = MagicMock() |
| 317 | + mock_session = MagicMock() |
| 318 | + |
| 319 | + # Mock the list_tools response from the MCP server |
| 320 | + mock_tool1 = MagicMock() |
| 321 | + mock_tool1.name = "tool1" |
| 322 | + mock_tool1.description = "tool 1 desc" |
| 323 | + mock_tool2 = MagicMock() |
| 324 | + mock_tool2.name = "tool2" |
| 325 | + mock_tool2.description = "tool 2 desc" |
| 326 | + list_tools_result = MagicMock() |
| 327 | + list_tools_result.tools = [mock_tool1, mock_tool2] |
| 328 | + mock_session.list_tools = AsyncMock(return_value=list_tools_result) |
| 329 | + mock_session_manager.create_session = AsyncMock(return_value=mock_session) |
| 330 | + |
| 331 | + # Create an instance of McpToolset with a prefix |
| 332 | + toolset = McpToolset( |
| 333 | + connection_params=mock_connection_params, |
| 334 | + tool_name_prefix="my_prefix", |
| 335 | + ) |
| 336 | + |
| 337 | + # Replace the internal session manager with our mock |
| 338 | + toolset._mcp_session_manager = mock_session_manager |
| 339 | + |
| 340 | + # Get the tools from the toolset |
| 341 | + tools = await toolset.get_tools() |
| 342 | + |
| 343 | + # The get_tools method in McpToolset returns MCPTool objects, which are |
| 344 | + # instances of BaseTool. The prefixing is handled by the BaseToolset, |
| 345 | + # so we need to call get_tools_with_prefix to get the prefixed tools. |
| 346 | + prefixed_tools = await toolset.get_tools_with_prefix() |
| 347 | + |
| 348 | + # Assert that the tools are prefixed correctly |
| 349 | + assert len(prefixed_tools) == 2 |
| 350 | + assert prefixed_tools[0].name == "my_prefix_tool1" |
| 351 | + assert prefixed_tools[1].name == "my_prefix_tool2" |
| 352 | + |
| 353 | + # Assert that the original tools are not modified |
| 354 | + assert tools[0].name == "tool1" |
| 355 | + assert tools[1].name == "tool2" |
0 commit comments