diff --git a/site/src/content/docs/user-guide/concepts/tools/mcp-tools.mdx b/site/src/content/docs/user-guide/concepts/tools/mcp-tools.mdx index 7fc05b94a..f02e421ce 100644 --- a/site/src/content/docs/user-guide/concepts/tools/mcp-tools.mdx +++ b/site/src/content/docs/user-guide/concepts/tools/mcp-tools.mdx @@ -503,6 +503,62 @@ Pass an `elicitationCallback` when constructing the client. The callback receive For more information on elicitation, see the [MCP specification](https://modelcontextprotocol.io/specification/draft/client/elicitation). +### Progress Notifications + +MCP servers can report incremental progress during long-running tool calls. Configure a `progress_callback` on the client to receive these updates: + + + + +```python +from mcp import stdio_client, StdioServerParameters +from strands import Agent +from strands.tools.mcp import MCPClient + +async def progress_callback(progress, total, message): + pct = f"{progress}/{total}" if total is not None else str(progress) + label = f" — {message}" if message else "" + print(f"Progress: {pct}{label}") + +client = MCPClient( + lambda: stdio_client( + StdioServerParameters(command="python", args=["/path/to/server.py"]) + ), + progress_callback=progress_callback, +) + +with client: + agent = Agent(tools=client.list_tools_sync()) + agent("Run the long-running task") +``` + +The callback receives three arguments: + +| Argument | Type | Description | +|---|---|---| +| `progress` | `float` | Current progress value reported by the server | +| `total` | `float \| None` | Total value (may be `None` if the server doesn't report it) | +| `message` | `str \| None` | Optional human-readable status message from the server | + +You can also pass a `progress_callback` directly to `call_tool_sync` or `call_tool_async` to override the instance-level callback for a single call: + +```python +result = client.call_tool_sync( + tool_use_id="tool-123", + name="long_running_tool", + arguments={"input": "data"}, + progress_callback=my_one_off_callback, +) +``` + + + + +Progress notifications are not yet supported in the TypeScript SDK. + + + + ## Best Practices - **Tool Descriptions**: Provide clear descriptions for tools to help the agent understand when and how to use them diff --git a/strands-py/src/strands/tools/mcp/mcp_client.py b/strands-py/src/strands/tools/mcp/mcp_client.py index def67cdc9..8eadc6f87 100644 --- a/strands-py/src/strands/tools/mcp/mcp_client.py +++ b/strands-py/src/strands/tools/mcp/mcp_client.py @@ -27,6 +27,7 @@ from mcp import ClientSession, ListToolsResult from mcp.client.session import ElicitationFnT from mcp.shared.exceptions import McpError +from mcp.shared.session import ProgressFnT from mcp.types import ( BlobResourceContents, ElicitationRequiredErrorData, @@ -121,6 +122,7 @@ def __init__( tool_filters: ToolFilters | None = None, prefix: str | None = None, elicitation_callback: ElicitationFnT | None = None, + progress_callback: ProgressFnT | None = None, tasks_config: TasksConfig | None = None, ) -> None: """Initialize a new MCP Server connection. @@ -132,6 +134,9 @@ def __init__( tool_filters: Optional filters to apply to tools. prefix: Optional prefix for tool names. elicitation_callback: Optional callback function to handle elicitation requests from the MCP server. + progress_callback: Optional callback to receive progress notifications during tool execution. + Called with ``(progress, total, message)`` as the server reports progress. The ``total`` + and ``message`` parameters may be ``None`` if the server does not provide them. tasks_config: Configuration for MCP task-augmented execution for long-running tools. If provided (not None), enables task-augmented execution for tools that support it. See TasksConfig for details. This feature is experimental and subject to change. @@ -140,6 +145,7 @@ def __init__( self._tool_filters = tool_filters self._prefix = prefix self._elicitation_callback = elicitation_callback + self._progress_callback = progress_callback mcp_instrumentation() self._session_id = uuid.uuid4() @@ -589,6 +595,7 @@ def _create_call_tool_coroutine( arguments: dict[str, Any] | None, read_timeout_seconds: timedelta | None, meta: dict[str, Any] | None = None, + progress_callback: ProgressFnT | None = None, ) -> Coroutine[Any, Any, MCPCallToolResult]: """Create the appropriate coroutine for calling a tool. @@ -600,14 +607,22 @@ def _create_call_tool_coroutine( arguments: Optional arguments to pass to the tool. read_timeout_seconds: Optional timeout for the tool call. meta: Optional metadata to pass to the tool call per MCP spec (_meta). + progress_callback: Optional callback to receive progress notifications. + If None, falls back to the instance-level callback set at construction time. Returns: A coroutine that will execute the tool call. """ use_task = self._should_use_task(name) + effective_callback = progress_callback if progress_callback is not None else self._progress_callback if use_task: self._log_debug_with_thread("tool=<%s> | using task-augmented execution", name) + if effective_callback is not None: + logger.warning( + "tool=<%s> | progress callbacks are ignored when task-augmented execution is enabled", + name, + ) async def _call_as_task() -> MCPCallToolResult: # When task-augmented execution is used, use the read_timeout_seconds parameter @@ -622,7 +637,7 @@ async def _call_as_task() -> MCPCallToolResult: async def _call_tool_direct() -> MCPCallToolResult: return await cast(ClientSession, self._background_thread_session).call_tool( - name, arguments, read_timeout_seconds, meta=meta + name, arguments, read_timeout_seconds, progress_callback=effective_callback, meta=meta ) return _call_tool_direct() @@ -634,6 +649,7 @@ def call_tool_sync( arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, meta: dict[str, Any] | None = None, + progress_callback: ProgressFnT | None = None, ) -> MCPToolResult: """Synchronously calls a tool on the MCP server. @@ -646,6 +662,8 @@ def call_tool_sync( arguments: Optional arguments to pass to the tool read_timeout_seconds: Optional timeout for the tool call meta: Optional metadata to pass to the tool call per MCP spec (_meta) + progress_callback: Optional callback to receive progress notifications for this + call. Overrides the instance-level callback set at construction time. Returns: MCPToolResult: The result of the tool call @@ -655,7 +673,9 @@ def call_tool_sync( raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) try: - coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds, meta=meta) + coro = self._create_call_tool_coroutine( + name, arguments, read_timeout_seconds, meta=meta, progress_callback=progress_callback + ) call_tool_result: MCPCallToolResult = self._invoke_on_background_thread(coro).result() return self._handle_tool_result(tool_use_id, call_tool_result) except Exception as e: @@ -669,6 +689,7 @@ async def call_tool_async( arguments: dict[str, Any] | None = None, read_timeout_seconds: timedelta | None = None, meta: dict[str, Any] | None = None, + progress_callback: ProgressFnT | None = None, ) -> MCPToolResult: """Asynchronously calls a tool on the MCP server. @@ -681,6 +702,8 @@ async def call_tool_async( arguments: Optional arguments to pass to the tool read_timeout_seconds: Optional timeout for the tool call meta: Optional metadata to pass to the tool call per MCP spec (_meta) + progress_callback: Optional callback to receive progress notifications for this + call. Overrides the instance-level callback set at construction time. Returns: MCPToolResult: The result of the tool call @@ -690,7 +713,9 @@ async def call_tool_async( raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) try: - coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds, meta=meta) + coro = self._create_call_tool_coroutine( + name, arguments, read_timeout_seconds, meta=meta, progress_callback=progress_callback + ) future = self._invoke_on_background_thread(coro) call_tool_result: MCPCallToolResult = await asyncio.wrap_future(future) return self._handle_tool_result(tool_use_id, call_tool_result) @@ -948,7 +973,6 @@ def map_mcp_content_to_tool_result_content( self._log_debug_with_thread("unhandled content type: %s - dropping content", content.__class__.__name__) return None - def _log_debug_with_thread(self, msg: str, *args: Any, **kwargs: Any) -> None: """Logger helper to help differentiate logs coming from MCPClient background thread.""" formatted_msg = msg % args if args else msg diff --git a/strands-py/tests/strands/models/test_gemini.py b/strands-py/tests/strands/models/test_gemini.py index c17423742..ac04b20eb 100644 --- a/strands-py/tests/strands/models/test_gemini.py +++ b/strands-py/tests/strands/models/test_gemini.py @@ -872,9 +872,7 @@ async def test_stream_response_max_tokens(gemini_client, model, messages, agener @pytest.mark.asyncio -async def test_stream_response_safety_block_with_missing_counts( - gemini_client, model, messages, agenerator, alist -): +async def test_stream_response_safety_block_with_missing_counts(gemini_client, model, messages, agenerator, alist): gemini_client.aio.models.generate_content_stream.return_value = agenerator( [ genai.types.GenerateContentResponse( diff --git a/strands-py/tests/strands/tools/mcp/test_mcp_client.py b/strands-py/tests/strands/tools/mcp/test_mcp_client.py index 958ebcf72..b22f2cef8 100644 --- a/strands-py/tests/strands/tools/mcp/test_mcp_client.py +++ b/strands-py/tests/strands/tools/mcp/test_mcp_client.py @@ -1,6 +1,6 @@ import base64 import time -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from mcp import ListToolsResult @@ -124,7 +124,9 @@ def test_call_tool_sync_status(mock_transport, mock_session, is_error, expected_ with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) - mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=None) + mock_session.call_tool.assert_called_once_with( + "test_tool", {"param": "value"}, None, progress_callback=None, meta=None + ) assert result["status"] == expected_status assert result["toolUseId"] == "test-123" @@ -155,7 +157,9 @@ def test_call_tool_sync_with_structured_content(mock_transport, mock_session): with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) - mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=None) + mock_session.call_tool.assert_called_once_with( + "test_tool", {"param": "value"}, None, progress_callback=None, meta=None + ) assert result["status"] == "success" assert result["toolUseId"] == "test-123" @@ -193,10 +197,52 @@ def test_call_tool_sync_forwards_meta(mock_transport, mock_session): tool_use_id="test-123", name="test_tool", arguments={"param": "value"}, meta=meta ) - mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=meta) + mock_session.call_tool.assert_called_once_with( + "test_tool", {"param": "value"}, None, progress_callback=None, meta=meta + ) assert result["status"] == "success" +def test_call_tool_sync_forwards_instance_progress_callback(mock_transport, mock_session): + """Test that call_tool_sync uses the instance-level progress callback when no per-call callback is given.""" + mock_content = MCPTextContent(type="text", text="done") + mock_session.call_tool.return_value = MCPCallToolResult(isError=False, content=[mock_content]) + cb = AsyncMock() + + with MCPClient(mock_transport["transport_callable"], progress_callback=cb) as client: + result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={}) + + mock_session.call_tool.assert_called_once_with("test_tool", {}, None, progress_callback=cb, meta=None) + assert result["status"] == "success" + + +def test_call_tool_sync_per_call_progress_callback_overrides_instance(mock_transport, mock_session): + """Test that a per-call progress callback overrides the instance-level one.""" + mock_content = MCPTextContent(type="text", text="done") + mock_session.call_tool.return_value = MCPCallToolResult(isError=False, content=[mock_content]) + instance_cb = AsyncMock() + per_call_cb = AsyncMock() + + with MCPClient(mock_transport["transport_callable"], progress_callback=instance_cb) as client: + result = client.call_tool_sync( + tool_use_id="test-123", name="test_tool", arguments={}, progress_callback=per_call_cb + ) + + mock_session.call_tool.assert_called_once_with("test_tool", {}, None, progress_callback=per_call_cb, meta=None) + assert result["status"] == "success" + + +def test_call_tool_sync_no_progress_callback_by_default(mock_transport, mock_session): + """Test that progress_callback defaults to None when not set on instance or per-call.""" + mock_content = MCPTextContent(type="text", text="done") + mock_session.call_tool.return_value = MCPCallToolResult(isError=False, content=[mock_content]) + + with MCPClient(mock_transport["transport_callable"]) as client: + client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={}) + + mock_session.call_tool.assert_called_once_with("test_tool", {}, None, progress_callback=None, meta=None) + + @pytest.mark.asyncio async def test_call_tool_async_forwards_meta(mock_transport, mock_session): """Test that call_tool_async forwards meta to ClientSession.call_tool.""" @@ -672,7 +718,7 @@ def test_call_tool_sync_embedded_nested_text(mock_transport, mock_session): with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-text", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, progress_callback=None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 1 assert result["content"][0]["text"] == "inner text" @@ -697,7 +743,7 @@ def test_call_tool_sync_embedded_nested_base64_textual_mime(mock_transport, mock with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-blob", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, progress_callback=None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 1 assert result["content"][0]["text"] == '{"k":"v"}' @@ -723,7 +769,7 @@ def test_call_tool_sync_embedded_image_blob(mock_transport, mock_session): with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-image", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, progress_callback=None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 1 assert "image" in result["content"][0] @@ -748,7 +794,7 @@ def test_call_tool_sync_embedded_non_textual_blob_dropped(mock_transport, mock_s with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-binary", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, progress_callback=None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 0 # Content should be dropped @@ -771,7 +817,7 @@ def test_call_tool_sync_embedded_multiple_textual_mimes(mock_transport, mock_ses with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-yaml", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, progress_callback=None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 1 assert "key: value" in result["content"][0]["text"] @@ -798,7 +844,7 @@ def __init__(self): with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="er-unknown", name="get_file_contents", arguments={}) - mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None) + mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, progress_callback=None, meta=None) assert result["status"] == "success" assert len(result["content"]) == 0 # Unknown resource type should be dropped @@ -850,7 +896,9 @@ def test_call_tool_sync_with_meta_and_structured_content(mock_transport, mock_se with MCPClient(mock_transport["transport_callable"]) as client: result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"}) - mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=None) + mock_session.call_tool.assert_called_once_with( + "test_tool", {"param": "value"}, None, progress_callback=None, meta=None + ) assert result["status"] == "success" assert result["toolUseId"] == "test-123" diff --git a/strands-py/tests/strands/tools/mcp/test_mcp_client_tasks.py b/strands-py/tests/strands/tools/mcp/test_mcp_client_tasks.py index d566ac6f5..e29642424 100644 --- a/strands-py/tests/strands/tools/mcp/test_mcp_client_tasks.py +++ b/strands-py/tests/strands/tools/mcp/test_mcp_client_tasks.py @@ -215,6 +215,25 @@ async def poll(task_id): assert result["status"] == "success" assert "Done" in result["content"][0].get("text", "") + def test_logs_warning_when_task_execution_ignores_progress_callback(self, mock_transport, mock_session, caplog): + """Test warning is logged when task execution ignores progress callbacks.""" + self._setup_task_tool(mock_session, "task_tool") + + def callback(progress: float, total: float | None, message: str | None) -> None: + _ = (progress, total, message) + + with MCPClient(mock_transport["transport_callable"], tasks_config=TasksConfig()) as client: + client.list_tools_sync() + with caplog.at_level("WARNING", logger="strands.tools.mcp.mcp_client"): + client.call_tool_sync( + tool_use_id="test-id", + name="task_tool", + arguments={}, + progress_callback=callback, + ) + + assert "progress callbacks are ignored when task-augmented execution is enabled" in caplog.text + class TestTaskMetaForwarding: """Tests for meta parameter forwarding in task-augmented execution."""