Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions src/strands/tools/mcp/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from mcp import ClientSession, ListToolsResult
from mcp.client.session import ElicitationFnT
from mcp.shared.exceptions import McpError
from mcp.shared.session import ProgressFnT
from mcp.types import (
BlobResourceContents,
ElicitationRequiredErrorData,
Expand Down Expand Up @@ -569,6 +570,7 @@ def _create_call_tool_coroutine(
arguments: dict[str, Any] | None,
read_timeout_seconds: timedelta | None,
meta: dict[str, Any] | None = None,
progress_callback: ProgressFnT | None = None,
) -> Coroutine[Any, Any, MCPCallToolResult]:
"""Create the appropriate coroutine for calling a tool.

Expand All @@ -580,6 +582,7 @@ def _create_call_tool_coroutine(
arguments: Optional arguments to pass to the tool.
read_timeout_seconds: Optional timeout for the tool call.
meta: Optional metadata to pass to the tool call per MCP spec (_meta).
progress_callback: Optional callback for receiving progress notifications from the server.

Returns:
A coroutine that will execute the tool call.
Expand All @@ -602,7 +605,7 @@ async def _call_as_task() -> MCPCallToolResult:

async def _call_tool_direct() -> MCPCallToolResult:
return await cast(ClientSession, self._background_thread_session).call_tool(
name, arguments, read_timeout_seconds, meta=meta
name, arguments, read_timeout_seconds, progress_callback=progress_callback, meta=meta
)

return _call_tool_direct()
Expand All @@ -614,6 +617,7 @@ def call_tool_sync(
arguments: dict[str, Any] | None = None,
read_timeout_seconds: timedelta | None = None,
meta: dict[str, Any] | None = None,
progress_callback: ProgressFnT | None = None,
) -> MCPToolResult:
"""Synchronously calls a tool on the MCP server.

Expand All @@ -626,6 +630,9 @@ def call_tool_sync(
arguments: Optional arguments to pass to the tool
read_timeout_seconds: Optional timeout for the tool call
meta: Optional metadata to pass to the tool call per MCP spec (_meta)
progress_callback: Optional callback for receiving progress notifications from the server.
When provided, a progressToken is automatically included in the request,
enabling the server to send progress updates via ctx.report_progress().

Returns:
MCPToolResult: The result of the tool call
Expand All @@ -635,7 +642,9 @@ def call_tool_sync(
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)

try:
coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds, meta=meta)
coro = self._create_call_tool_coroutine(
name, arguments, read_timeout_seconds, meta=meta, progress_callback=progress_callback
)
call_tool_result: MCPCallToolResult = self._invoke_on_background_thread(coro).result()
return self._handle_tool_result(tool_use_id, call_tool_result)
except Exception as e:
Expand All @@ -649,6 +658,7 @@ async def call_tool_async(
arguments: dict[str, Any] | None = None,
read_timeout_seconds: timedelta | None = None,
meta: dict[str, Any] | None = None,
progress_callback: ProgressFnT | None = None,
) -> MCPToolResult:
"""Asynchronously calls a tool on the MCP server.

Expand All @@ -661,6 +671,9 @@ async def call_tool_async(
arguments: Optional arguments to pass to the tool
read_timeout_seconds: Optional timeout for the tool call
meta: Optional metadata to pass to the tool call per MCP spec (_meta)
progress_callback: Optional callback for receiving progress notifications from the server.
When provided, a progressToken is automatically included in the request,
enabling the server to send progress updates via ctx.report_progress().

Returns:
MCPToolResult: The result of the tool call
Expand All @@ -670,7 +683,9 @@ async def call_tool_async(
raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)

try:
coro = self._create_call_tool_coroutine(name, arguments, read_timeout_seconds, meta=meta)
coro = self._create_call_tool_coroutine(
name, arguments, read_timeout_seconds, meta=meta, progress_callback=progress_callback
)
future = self._invoke_on_background_thread(coro)
call_tool_result: MCPCallToolResult = await asyncio.wrap_future(future)
return self._handle_tool_result(tool_use_id, call_tool_result)
Expand Down
79 changes: 69 additions & 10 deletions tests/strands/tools/mcp/test_mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_call_tool_sync_status(mock_transport, mock_session, is_error, expected_
with MCPClient(mock_transport["transport_callable"]) as client:
result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"})

mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=None)
mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, progress_callback=None, meta=None)

assert result["status"] == expected_status
assert result["toolUseId"] == "test-123"
Expand Down Expand Up @@ -153,7 +153,7 @@ def test_call_tool_sync_with_structured_content(mock_transport, mock_session):
with MCPClient(mock_transport["transport_callable"]) as client:
result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"})

mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=None)
mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, progress_callback=None, meta=None)

assert result["status"] == "success"
assert result["toolUseId"] == "test-123"
Expand Down Expand Up @@ -191,7 +191,9 @@ def test_call_tool_sync_forwards_meta(mock_transport, mock_session):
tool_use_id="test-123", name="test_tool", arguments={"param": "value"}, meta=meta
)

mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=meta)
mock_session.call_tool.assert_called_once_with(
"test_tool", {"param": "value"}, None, progress_callback=None, meta=meta
)
assert result["status"] == "success"


Expand Down Expand Up @@ -225,6 +227,63 @@ async def mock_awaitable():
assert result["status"] == "success"


def test_call_tool_sync_forwards_progress_callback(mock_transport, mock_session):
"""Test that call_tool_sync forwards progress_callback to ClientSession.call_tool."""
mock_content = MCPTextContent(type="text", text="Test message")
mock_session.call_tool.return_value = MCPCallToolResult(isError=False, content=[mock_content])

async def on_progress(progress: float, total: float | None, message: str | None) -> None:
pass

with MCPClient(mock_transport["transport_callable"]) as client:
result = client.call_tool_sync(
tool_use_id="test-123",
name="test_tool",
arguments={"param": "value"},
progress_callback=on_progress,
)

mock_session.call_tool.assert_called_once_with(
"test_tool", {"param": "value"}, None, progress_callback=on_progress, meta=None
)
assert result["status"] == "success"


@pytest.mark.asyncio
async def test_call_tool_async_forwards_progress_callback(mock_transport, mock_session):
"""Test that call_tool_async forwards progress_callback to ClientSession.call_tool."""
mock_content = MCPTextContent(type="text", text="Test message")
mock_result = MCPCallToolResult(isError=False, content=[mock_content])
mock_session.call_tool.return_value = mock_result

async def on_progress(progress: float, total: float | None, message: str | None) -> None:
pass

with MCPClient(mock_transport["transport_callable"]) as client:
with (
patch("asyncio.run_coroutine_threadsafe") as mock_run_coroutine_threadsafe,
patch("asyncio.wrap_future") as mock_wrap_future,
):
mock_future = MagicMock()
mock_run_coroutine_threadsafe.return_value = mock_future

async def mock_awaitable():
return mock_result

mock_wrap_future.return_value = mock_awaitable()

result = await client.call_tool_async(
tool_use_id="test-123",
name="test_tool",
arguments={"param": "value"},
progress_callback=on_progress,
)

mock_run_coroutine_threadsafe.assert_called_once()

assert result["status"] == "success"


@pytest.mark.asyncio
@pytest.mark.parametrize("is_error,expected_status", [(False, "success"), (True, "error")])
async def test_call_tool_async_status(mock_transport, mock_session, is_error, expected_status):
Expand Down Expand Up @@ -629,7 +688,7 @@ def test_call_tool_sync_embedded_nested_text(mock_transport, mock_session):
with MCPClient(mock_transport["transport_callable"]) as client:
result = client.call_tool_sync(tool_use_id="er-text", name="get_file_contents", arguments={})

mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None)
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, progress_callback=None, meta=None)
assert result["status"] == "success"
assert len(result["content"]) == 1
assert result["content"][0]["text"] == "inner text"
Expand All @@ -654,7 +713,7 @@ def test_call_tool_sync_embedded_nested_base64_textual_mime(mock_transport, mock
with MCPClient(mock_transport["transport_callable"]) as client:
result = client.call_tool_sync(tool_use_id="er-blob", name="get_file_contents", arguments={})

mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None)
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, progress_callback=None, meta=None)
assert result["status"] == "success"
assert len(result["content"]) == 1
assert result["content"][0]["text"] == '{"k":"v"}'
Expand All @@ -680,7 +739,7 @@ def test_call_tool_sync_embedded_image_blob(mock_transport, mock_session):
with MCPClient(mock_transport["transport_callable"]) as client:
result = client.call_tool_sync(tool_use_id="er-image", name="get_file_contents", arguments={})

mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None)
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, progress_callback=None, meta=None)
assert result["status"] == "success"
assert len(result["content"]) == 1
assert "image" in result["content"][0]
Expand All @@ -705,7 +764,7 @@ def test_call_tool_sync_embedded_non_textual_blob_dropped(mock_transport, mock_s
with MCPClient(mock_transport["transport_callable"]) as client:
result = client.call_tool_sync(tool_use_id="er-binary", name="get_file_contents", arguments={})

mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None)
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, progress_callback=None, meta=None)
assert result["status"] == "success"
assert len(result["content"]) == 0 # Content should be dropped

Expand All @@ -728,7 +787,7 @@ def test_call_tool_sync_embedded_multiple_textual_mimes(mock_transport, mock_ses
with MCPClient(mock_transport["transport_callable"]) as client:
result = client.call_tool_sync(tool_use_id="er-yaml", name="get_file_contents", arguments={})

mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None)
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, progress_callback=None, meta=None)
assert result["status"] == "success"
assert len(result["content"]) == 1
assert "key: value" in result["content"][0]["text"]
Expand All @@ -755,7 +814,7 @@ def __init__(self):
with MCPClient(mock_transport["transport_callable"]) as client:
result = client.call_tool_sync(tool_use_id="er-unknown", name="get_file_contents", arguments={})

mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, meta=None)
mock_session.call_tool.assert_called_once_with("get_file_contents", {}, None, progress_callback=None, meta=None)
assert result["status"] == "success"
assert len(result["content"]) == 0 # Unknown resource type should be dropped

Expand Down Expand Up @@ -807,7 +866,7 @@ def test_call_tool_sync_with_meta_and_structured_content(mock_transport, mock_se
with MCPClient(mock_transport["transport_callable"]) as client:
result = client.call_tool_sync(tool_use_id="test-123", name="test_tool", arguments={"param": "value"})

mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, meta=None)
mock_session.call_tool.assert_called_once_with("test_tool", {"param": "value"}, None, progress_callback=None, meta=None)

assert result["status"] == "success"
assert result["toolUseId"] == "test-123"
Expand Down