Skip to content

Commit a6e259f

Browse files
feat(mcp): forward progress_callback to ClientSession.call_tool
Add progress_callback parameter to call_tool_sync, call_tool_async, and _create_call_tool_coroutine. When provided, the MCP SDK automatically includes a progressToken in the request metadata, enabling MCP servers to send progress notifications via ctx.report_progress(). This is essential for long-running tool calls (e.g. crawls, agent orchestration) where progress data keeps the HTTP/SSE connection alive and provides observability. Closes #1812
1 parent 50b2c79 commit a6e259f

2 files changed

Lines changed: 87 additions & 13 deletions

File tree

src/strands/tools/mcp/mcp_client.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from mcp import ClientSession, ListToolsResult
2727
from mcp.client.session import ElicitationFnT
2828
from mcp.shared.exceptions import McpError
29+
from mcp.shared.session import ProgressFnT
2930
from mcp.types import (
3031
BlobResourceContents,
3132
ElicitationRequiredErrorData,
@@ -569,6 +570,7 @@ def _create_call_tool_coroutine(
569570
arguments: dict[str, Any] | None,
570571
read_timeout_seconds: timedelta | None,
571572
meta: dict[str, Any] | None = None,
573+
progress_callback: ProgressFnT | None = None,
572574
) -> Coroutine[Any, Any, MCPCallToolResult]:
573575
"""Create the appropriate coroutine for calling a tool.
574576
@@ -580,6 +582,7 @@ def _create_call_tool_coroutine(
580582
arguments: Optional arguments to pass to the tool.
581583
read_timeout_seconds: Optional timeout for the tool call.
582584
meta: Optional metadata to pass to the tool call per MCP spec (_meta).
585+
progress_callback: Optional callback for receiving progress notifications from the server.
583586
584587
Returns:
585588
A coroutine that will execute the tool call.
@@ -602,7 +605,7 @@ async def _call_as_task() -> MCPCallToolResult:
602605

603606
async def _call_tool_direct() -> MCPCallToolResult:
604607
return await cast(ClientSession, self._background_thread_session).call_tool(
605-
name, arguments, read_timeout_seconds, meta=meta
608+
name, arguments, read_timeout_seconds, progress_callback=progress_callback, meta=meta
606609
)
607610

608611
return _call_tool_direct()
@@ -614,6 +617,7 @@ def call_tool_sync(
614617
arguments: dict[str, Any] | None = None,
615618
read_timeout_seconds: timedelta | None = None,
616619
meta: dict[str, Any] | None = None,
620+
progress_callback: ProgressFnT | None = None,
617621
) -> MCPToolResult:
618622
"""Synchronously calls a tool on the MCP server.
619623
@@ -626,6 +630,9 @@ def call_tool_sync(
626630
arguments: Optional arguments to pass to the tool
627631
read_timeout_seconds: Optional timeout for the tool call
628632
meta: Optional metadata to pass to the tool call per MCP spec (_meta)
633+
progress_callback: Optional callback for receiving progress notifications from the server.
634+
When provided, a progressToken is automatically included in the request,
635+
enabling the server to send progress updates via ctx.report_progress().
629636
630637
Returns:
631638
MCPToolResult: The result of the tool call
@@ -635,7 +642,9 @@ def call_tool_sync(
635642
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
636643

637644
try:
638-
coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds, meta=meta)
645+
coro = self._create_call_tool_coroutine(
646+
name, arguments, read_timeout_seconds, meta=meta, progress_callback=progress_callback
647+
)
639648
call_tool_result: MCPCallToolResult = self._invoke_on_background_thread(coro).result()
640649
return self._handle_tool_result(tool_use_id, call_tool_result)
641650
except Exception as e:
@@ -649,6 +658,7 @@ async def call_tool_async(
649658
arguments: dict[str, Any] | None = None,
650659
read_timeout_seconds: timedelta | None = None,
651660
meta: dict[str, Any] | None = None,
661+
progress_callback: ProgressFnT | None = None,
652662
) -> MCPToolResult:
653663
"""Asynchronously calls a tool on the MCP server.
654664
@@ -661,6 +671,9 @@ async def call_tool_async(
661671
arguments: Optional arguments to pass to the tool
662672
read_timeout_seconds: Optional timeout for the tool call
663673
meta: Optional metadata to pass to the tool call per MCP spec (_meta)
674+
progress_callback: Optional callback for receiving progress notifications from the server.
675+
When provided, a progressToken is automatically included in the request,
676+
enabling the server to send progress updates via ctx.report_progress().
664677
665678
Returns:
666679
MCPToolResult: The result of the tool call
@@ -670,7 +683,9 @@ async def call_tool_async(
670683
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
671684

672685
try:
673-
coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds, meta=meta)
686+
coro = self._create_call_tool_coroutine(
687+
name, arguments, read_timeout_seconds, meta=meta, progress_callback=progress_callback
688+
)
674689
future = self._invoke_on_background_thread(coro)
675690
call_tool_result: MCPCallToolResult = await asyncio.wrap_future(future)
676691
return self._handle_tool_result(tool_use_id, call_tool_result)

tests/strands/tools/mcp/test_mcp_client.py

Lines changed: 69 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def test_call_tool_sync_status(mock_transport, mock_session, is_error, expected_
124124
with MCPClient(mock_transport["transport_callable"]) as client:
125125
result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"})
126126

127-
mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=None)
127+
mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, progress_callback=None, meta=None)
128128

129129
assert result["status"] == expected_status
130130
assert result["toolUseId"] == "test-123"
@@ -153,7 +153,7 @@ def test_call_tool_sync_with_structured_content(mock_transport, mock_session):
153153
with MCPClient(mock_transport["transport_callable"]) as client:
154154
result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"})
155155

156-
mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=None)
156+
mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, progress_callback=None, meta=None)
157157

158158
assert result["status"] == "success"
159159
assert result["toolUseId"] == "test-123"
@@ -191,7 +191,9 @@ def test_call_tool_sync_forwards_meta(mock_transport, mock_session):
191191
tool_use_id="test-123", name="test_tool", arguments={"param": "value"}, meta=meta
192192
)
193193

194-
mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=meta)
194+
mock_session.call_tool.assert_called_once_with(
195+
"test_tool", {"param": "value"}, None, progress_callback=None, meta=meta
196+
)
195197
assert result["status"] == "success"
196198

197199

@@ -225,6 +227,63 @@ async def mock_awaitable():
225227
assert result["status"] == "success"
226228

227229

230+
def test_call_tool_sync_forwards_progress_callback(mock_transport, mock_session):
231+
"""Test that call_tool_sync forwards progress_callback to ClientSession.call_tool."""
232+
mock_content = MCPTextContent(type="text", text="Test message")
233+
mock_session.call_tool.return_value = MCPCallToolResult(isError=False, content=[mock_content])
234+
235+
async def on_progress(progress: float, total: float | None, message: str | None) -> None:
236+
pass
237+
238+
with MCPClient(mock_transport["transport_callable"]) as client:
239+
result = client.call_tool_sync(
240+
tool_use_id="test-123",
241+
name="test_tool",
242+
arguments={"param": "value"},
243+
progress_callback=on_progress,
244+
)
245+
246+
mock_session.call_tool.assert_called_once_with(
247+
"test_tool", {"param": "value"}, None, progress_callback=on_progress, meta=None
248+
)
249+
assert result["status"] == "success"
250+
251+
252+
@pytest.mark.asyncio
253+
async def test_call_tool_async_forwards_progress_callback(mock_transport, mock_session):
254+
"""Test that call_tool_async forwards progress_callback to ClientSession.call_tool."""
255+
mock_content = MCPTextContent(type="text", text="Test message")
256+
mock_result = MCPCallToolResult(isError=False, content=[mock_content])
257+
mock_session.call_tool.return_value = mock_result
258+
259+
async def on_progress(progress: float, total: float | None, message: str | None) -> None:
260+
pass
261+
262+
with MCPClient(mock_transport["transport_callable"]) as client:
263+
with (
264+
patch("asyncio.run_coroutine_threadsafe") as mock_run_coroutine_threadsafe,
265+
patch("asyncio.wrap_future") as mock_wrap_future,
266+
):
267+
mock_future = MagicMock()
268+
mock_run_coroutine_threadsafe.return_value = mock_future
269+
270+
async def mock_awaitable():
271+
return mock_result
272+
273+
mock_wrap_future.return_value = mock_awaitable()
274+
275+
result = await client.call_tool_async(
276+
tool_use_id="test-123",
277+
name="test_tool",
278+
arguments={"param": "value"},
279+
progress_callback=on_progress,
280+
)
281+
282+
mock_run_coroutine_threadsafe.assert_called_once()
283+
284+
assert result["status"] == "success"
285+
286+
228287
@pytest.mark.asyncio
229288
@pytest.mark.parametrize("is_error,expected_status", [(False, "success"), (True, "error")])
230289
async def test_call_tool_async_status(mock_transport, mock_session, is_error, expected_status):
@@ -629,7 +688,7 @@ def test_call_tool_sync_embedded_nested_text(mock_transport, mock_session):
629688
with MCPClient(mock_transport["transport_callable"]) as client:
630689
result = client.call_tool_sync(tool_use_id="er-text", name="get_file_contents", arguments={})
631690

632-
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None)
691+
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, progress_callback=None, meta=None)
633692
assert result["status"] == "success"
634693
assert len(result["content"]) == 1
635694
assert result["content"][0]["text"] == "inner text"
@@ -654,7 +713,7 @@ def test_call_tool_sync_embedded_nested_base64_textual_mime(mock_transport, mock
654713
with MCPClient(mock_transport["transport_callable"]) as client:
655714
result = client.call_tool_sync(tool_use_id="er-blob", name="get_file_contents", arguments={})
656715

657-
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None)
716+
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, progress_callback=None, meta=None)
658717
assert result["status"] == "success"
659718
assert len(result["content"]) == 1
660719
assert result["content"][0]["text"] == '{"k":"v"}'
@@ -680,7 +739,7 @@ def test_call_tool_sync_embedded_image_blob(mock_transport, mock_session):
680739
with MCPClient(mock_transport["transport_callable"]) as client:
681740
result = client.call_tool_sync(tool_use_id="er-image", name="get_file_contents", arguments={})
682741

683-
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None)
742+
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, progress_callback=None, meta=None)
684743
assert result["status"] == "success"
685744
assert len(result["content"]) == 1
686745
assert "image" in result["content"][0]
@@ -705,7 +764,7 @@ def test_call_tool_sync_embedded_non_textual_blob_dropped(mock_transport, mock_s
705764
with MCPClient(mock_transport["transport_callable"]) as client:
706765
result = client.call_tool_sync(tool_use_id="er-binary", name="get_file_contents", arguments={})
707766

708-
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None)
767+
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, progress_callback=None, meta=None)
709768
assert result["status"] == "success"
710769
assert len(result["content"]) == 0 # Content should be dropped
711770

@@ -728,7 +787,7 @@ def test_call_tool_sync_embedded_multiple_textual_mimes(mock_transport, mock_ses
728787
with MCPClient(mock_transport["transport_callable"]) as client:
729788
result = client.call_tool_sync(tool_use_id="er-yaml", name="get_file_contents", arguments={})
730789

731-
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None)
790+
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, progress_callback=None, meta=None)
732791
assert result["status"] == "success"
733792
assert len(result["content"]) == 1
734793
assert "key: value" in result["content"][0]["text"]
@@ -755,7 +814,7 @@ def __init__(self):
755814
with MCPClient(mock_transport["transport_callable"]) as client:
756815
result = client.call_tool_sync(tool_use_id="er-unknown", name="get_file_contents", arguments={})
757816

758-
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None)
817+
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, progress_callback=None, meta=None)
759818
assert result["status"] == "success"
760819
assert len(result["content"]) == 0 # Unknown resource type should be dropped
761820

@@ -807,7 +866,7 @@ def test_call_tool_sync_with_meta_and_structured_content(mock_transport, mock_se
807866
with MCPClient(mock_transport["transport_callable"]) as client:
808867
result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"})
809868

810-
mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=None)
869+
mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, progress_callback=None, meta=None)
811870

812871
assert result["status"] == "success"
813872
assert result["toolUseId"] == "test-123"

0 commit comments

Comments
 (0)