Skip to content

Commit 8f82697

Browse files
Piyushmryawukath
authored andcommitted
feat(mcp): add sampling callback support for MCP sessions
Merge #4718 ### Link to Issue or Description of Change **1. Link to an existing issue (if applicable):** - Closes: N/A - Related: N/A **2. Or, if no issue exists, describe the change:** **Problem** ADK’s MCP integration currently does not expose the MCP sampling callback capability. This prevents agent-side LLM sampling handlers from being used when interacting with MCP servers that support sampling. The MCP Python SDK supports sampling callbacks, but these parameters are not propagated through the ADK MCP integration layers. **Solution** Add sampling callback support by propagating the parameters through the MCP stack: - Add `sampling_callback` and `sampling_capabilities` parameters to `McpToolset` - Forward them to `MCPSessionManager` - Forward them to `SessionContext` - Pass them into `ClientSession` initialization This enables agent-side sampling handling when interacting with MCP servers. --- ### Testing Plan **Unit Tests** - [x] I have added or updated unit tests for my change. - [x] All unit tests pass locally. Added `test_mcp_sampling_callback.py` to verify that the sampling callback is correctly invoked. Example result: pytest tests/unittests/tools/mcp_tool/test_mcp_sampling_callback.py 1 passed **Manual End-to-End (E2E) Tests** Manual testing was performed using a FastMCP sampling example server where the sampling callback was invoked from the agent side and returned the expected response. --- ### Checklist - [x] I have read the CONTRIBUTING.md document. - [x] I have performed a self-review of my own code. - [x] I have commented my code where necessary. - [x] I have added tests proving the feature works. - [x] Unit tests pass locally. - [x] I have manually tested the change end-to-end. --- ### Additional context This change aligns ADK MCP support with the sampling capabilities available in the MCP Python SDK and enables agent implementations to handle sampling requests via a callback. Co-authored-by: Kathy Wu <wukathy@google.com> COPYBARA_INTEGRATE_REVIEW=#4718 from Piyushmrya:fix-mcp-sampling-callback 18f477f PiperOrigin-RevId: 883401178
1 parent faafac9 commit 8f82697

File tree

4 files changed

+73
-0
lines changed

4 files changed

+73
-0
lines changed

src/google/adk/tools/mcp_tool/mcp_session_manager.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@
3333
from typing import Union
3434

3535
from mcp import ClientSession
36+
from mcp import SamplingCapability
3637
from mcp import StdioServerParameters
38+
from mcp.client.session import SamplingFnT
3739
from mcp.client.sse import sse_client
3840
from mcp.client.stdio import stdio_client
3941
from mcp.client.streamable_http import create_mcp_http_client
@@ -195,6 +197,9 @@ def __init__(
195197
StreamableHTTPConnectionParams,
196198
],
197199
errlog: TextIO = sys.stderr,
200+
*,
201+
sampling_callback: Optional[SamplingFnT] = None,
202+
sampling_capabilities: Optional[SamplingCapability] = None,
198203
):
199204
"""Initializes the MCP session manager.
200205
@@ -204,7 +209,13 @@ def __init__(
204209
parameters but it's not configurable for now.
205210
errlog: (Optional) TextIO stream for error logging. Use only for
206211
initializing a local stdio MCP session.
212+
sampling_callback: Optional callback to handle sampling requests from the
213+
MCP server.
214+
sampling_capabilities: Optional capabilities for sampling.
207215
"""
216+
self._sampling_callback = sampling_callback
217+
self._sampling_capabilities = sampling_capabilities
218+
208219
if isinstance(connection_params, StdioServerParameters):
209220
# So far timeout is not configurable. Given MCP is still evolving, we
210221
# would expect stdio_client to evolve to accept timeout parameter like
@@ -475,6 +486,8 @@ async def create_session(
475486
timeout=timeout_in_seconds,
476487
sse_read_timeout=sse_read_timeout_in_seconds,
477488
is_stdio=is_stdio,
489+
sampling_callback=self._sampling_callback,
490+
sampling_capabilities=self._sampling_capabilities,
478491
)
479492
),
480493
timeout=timeout_in_seconds,

src/google/adk/tools/mcp_tool/mcp_toolset.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@
2929
from typing import Union
3030
import warnings
3131

32+
from mcp import SamplingCapability
3233
from mcp import StdioServerParameters
34+
from mcp.client.session import SamplingFnT
3335
from mcp.shared.session import ProgressFnT
3436
from mcp.types import ListResourcesResult
3537
from mcp.types import ListToolsResult
@@ -114,6 +116,8 @@ def __init__(
114116
Union[ProgressFnT, ProgressCallbackFactory]
115117
] = None,
116118
use_mcp_resources: Optional[bool] = False,
119+
sampling_callback: Optional[SamplingFnT] = None,
120+
sampling_capabilities: Optional[SamplingCapability] = None,
117121
):
118122
"""Initializes the McpToolset.
119123
@@ -150,10 +154,16 @@ def __init__(
150154
use_mcp_resources: Whether the agent should have access to MCP resources.
151155
This will add a `load_mcp_resource` tool to the toolset and include
152156
available resources in the agent context. Defaults to False.
157+
sampling_callback: Optional callback to handle sampling requests from the
158+
MCP server.
159+
sampling_capabilities: Optional capabilities for sampling.
153160
"""
154161

155162
super().__init__(tool_filter=tool_filter, tool_name_prefix=tool_name_prefix)
156163

164+
self._sampling_callback = sampling_callback
165+
self._sampling_capabilities = sampling_capabilities
166+
157167
if not connection_params:
158168
raise ValueError("Missing connection params in McpToolset.")
159169

@@ -166,6 +176,8 @@ def __init__(
166176
self._mcp_session_manager = MCPSessionManager(
167177
connection_params=self._connection_params,
168178
errlog=self._errlog,
179+
sampling_callback=self._sampling_callback,
180+
sampling_capabilities=self._sampling_capabilities,
169181
)
170182
self._auth_scheme = auth_scheme
171183
self._auth_credential = auth_credential

src/google/adk/tools/mcp_tool/session_context.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from typing import Optional
2323

2424
from mcp import ClientSession
25+
from mcp import SamplingCapability
26+
from mcp.client.session import SamplingFnT
2527

2628
logger = logging.getLogger('google_adk.' + __name__)
2729

@@ -54,6 +56,9 @@ def __init__(
5456
timeout: Optional[float],
5557
sse_read_timeout: Optional[float],
5658
is_stdio: bool = False,
59+
*,
60+
sampling_callback: Optional[SamplingFnT] = None,
61+
sampling_capabilities: Optional[SamplingCapability] = None,
5762
):
5863
"""
5964
Args:
@@ -63,6 +68,9 @@ def __init__(
6368
sse_read_timeout: Timeout in seconds for reading data from the MCP SSE
6469
server.
6570
is_stdio: Whether this is a stdio connection (affects read timeout).
71+
sampling_callback: Optional callback to handle sampling requests from the
72+
MCP server.
73+
sampling_capabilities: Optional capabilities for sampling.
6674
"""
6775
self._client = client
6876
self._timeout = timeout
@@ -73,6 +81,8 @@ def __init__(
7381
self._close_event = asyncio.Event()
7482
self._task: Optional[asyncio.Task] = None
7583
self._task_lock = asyncio.Lock()
84+
self._sampling_callback = sampling_callback
85+
self._sampling_capabilities = sampling_capabilities
7686

7787
@property
7888
def session(self) -> Optional[ClientSession]:
@@ -165,6 +175,8 @@ async def _run(self):
165175
read_timeout_seconds=timedelta(seconds=self._timeout)
166176
if self._timeout is not None
167177
else None,
178+
sampling_callback=self._sampling_callback,
179+
sampling_capabilities=self._sampling_capabilities,
168180
)
169181
)
170182
else:
@@ -176,6 +188,8 @@ async def _run(self):
176188
read_timeout_seconds=timedelta(seconds=self._sse_read_timeout)
177189
if self._sse_read_timeout is not None
178190
else None,
191+
sampling_callback=self._sampling_callback,
192+
sampling_capabilities=self._sampling_capabilities,
179193
)
180194
)
181195
await asyncio.wait_for(session.initialize(), timeout=self._timeout)

tests/unittests/tools/mcp_tool/test_mcp_toolset.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,3 +612,37 @@ async def test_read_resource(self, name, mime_type, content, encoding):
612612
assert result == contents
613613
self.mock_session.list_resources.assert_called_once()
614614
self.mock_session.read_resource.assert_called_once_with(uri=uri)
615+
616+
@pytest.mark.asyncio
617+
async def test_sampling_callback_invoked(self):
618+
619+
called = {"value": False}
620+
621+
async def mock_sampling_handler(messages, params=None, context=None):
622+
called["value"] = True
623+
624+
assert isinstance(messages, list)
625+
assert messages[0]["role"] == "user"
626+
627+
return {
628+
"model": "test-model",
629+
"role": "assistant",
630+
"content": {"type": "text", "text": "sampling response"},
631+
"stopReason": "endTurn",
632+
}
633+
634+
toolset = McpToolset(
635+
connection_params=StreamableHTTPConnectionParams(
636+
url="http://localhost:9999",
637+
timeout=10,
638+
),
639+
sampling_callback=mock_sampling_handler,
640+
)
641+
642+
messages = [{"role": "user", "content": {"type": "text", "text": "hello"}}]
643+
644+
result = await toolset._sampling_callback(messages)
645+
646+
assert called["value"] is True
647+
assert result["role"] == "assistant"
648+
assert result["content"]["text"] == "sampling response"

0 commit comments

Comments
 (0)