diff --git a/src/agents/mcp/server.py b/src/agents/mcp/server.py index b8c196ce54..edc4695e2e 100644 --- a/src/agents/mcp/server.py +++ b/src/agents/mcp/server.py @@ -361,6 +361,7 @@ def __init__( self.tool_filter = tool_filter self._serialize_session_requests = False + self._get_session_id: GetSessionIdCallback | None = None async def _maybe_serialize_request(self, func: Callable[[], Awaitable[T]]) -> T: if not self._serialize_session_requests: @@ -515,7 +516,9 @@ async def connect(self): # streamablehttp_client returns (read, write, get_session_id) # sse_client returns (read, write) - read, write, *_ = transport + read, write, *rest = transport + # Capture the session-id callback when present (streamablehttp_client only). + self._get_session_id = rest[0] if rest and callable(rest[0]) else None session = await self.exit_stack.enter_async_context( ClientSession( @@ -780,6 +783,7 @@ async def cleanup(self): logger.error(f"Error cleaning up server: {e}") finally: self.session = None + self._get_session_id = None class MCPServerStdioParams(TypedDict): @@ -1348,3 +1352,29 @@ async def call_tool( def name(self) -> str: """A readable name for the server.""" return self._name + + @property + def session_id(self) -> str | None: + """The MCP session ID assigned by the server, or None if not yet connected + or if the server did not issue a session ID. + + The session ID is stable for the lifetime of this server instance's connection. + You can persist it and pass it back via the Mcp-Session-Id request header + (params["headers"]) on a new MCPServerStreamableHttp instance to resume + the same server-side session across process restarts or stateless workers. + + Example:: + + async with MCPServerStreamableHttp(params={"url": url}) as server: + session_id = server.session_id + + # In a new worker / process: + async with MCPServerStreamableHttp( + params={"url": url, "headers": {"Mcp-Session-Id": session_id}} + ) as server: + # Resumes the same server-side session. + ... + """ + if self._get_session_id is None: + return None + return self._get_session_id() diff --git a/tests/mcp/test_streamable_http_session_id.py b/tests/mcp/test_streamable_http_session_id.py new file mode 100644 index 0000000000..a98013b8f1 --- /dev/null +++ b/tests/mcp/test_streamable_http_session_id.py @@ -0,0 +1,115 @@ +"""Tests for MCPServerStreamableHttp.session_id property (issue #924).""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agents.mcp import MCPServerStreamableHttp + + +class TestStreamableHttpSessionId: + """Tests that the session_id property is correctly exposed.""" + + def test_session_id_is_none_before_connect(self): + """session_id should be None when the server has not been connected yet.""" + server = MCPServerStreamableHttp(params={"url": "http://localhost:9999/mcp"}) + assert server.session_id is None + + def test_session_id_returns_none_when_callback_is_none(self): + """session_id should be None when _get_session_id callback is None.""" + server = MCPServerStreamableHttp(params={"url": "http://localhost:9999/mcp"}) + server._get_session_id = None + assert server.session_id is None + + def test_session_id_returns_callback_value(self): + """session_id should return the value from the get_session_id callback.""" + server = MCPServerStreamableHttp(params={"url": "http://localhost:9999/mcp"}) + mock_get_session_id = MagicMock(return_value="test-session-abc123") + server._get_session_id = mock_get_session_id + assert server.session_id == "test-session-abc123" + mock_get_session_id.assert_called_once() + + def test_session_id_returns_none_when_callback_returns_none(self): + """session_id should return None when the callback itself returns None.""" + server = MCPServerStreamableHttp(params={"url": "http://localhost:9999/mcp"}) + mock_get_session_id = MagicMock(return_value=None) + server._get_session_id = mock_get_session_id + assert server.session_id is None + + def test_session_id_reflects_updated_callback_value(self): + """session_id should reflect the latest value from the callback each time.""" + server = MCPServerStreamableHttp(params={"url": "http://localhost:9999/mcp"}) + call_count = 0 + + def changing_callback() -> str | None: + nonlocal call_count + call_count += 1 + return f"session-{call_count}" + + server._get_session_id = changing_callback + assert server.session_id == "session-1" + assert server.session_id == "session-2" + + @pytest.mark.asyncio + async def test_connect_captures_get_session_id_callback(self): + """connect() should capture the third element of the transport tuple as _get_session_id.""" + server = MCPServerStreamableHttp(params={"url": "http://localhost:9999/mcp"}) + + mock_read = AsyncMock() + mock_write = AsyncMock() + mock_get_session_id = MagicMock(return_value="captured-session-xyz") + + mock_initialize_result = MagicMock() + mock_session = AsyncMock() + mock_session.initialize = AsyncMock(return_value=mock_initialize_result) + + # Simulate the full 3-tuple that streamablehttp_client returns + transport_tuple = (mock_read, mock_write, mock_get_session_id) + + with patch("agents.mcp.server.ClientSession") as mock_client_session_cls: + mock_client_session_cls.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_client_session_cls.return_value.__aexit__ = AsyncMock(return_value=None) + + with patch.object( + server, + "create_streams", + ) as mock_create_streams: + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=transport_tuple) + mock_cm.__aexit__ = AsyncMock(return_value=None) + mock_create_streams.return_value = mock_cm + + with patch.object(server.exit_stack, "enter_async_context") as mock_enter: + # First call returns transport, second call returns session + mock_enter.side_effect = [transport_tuple, mock_session] + mock_session.initialize.return_value = mock_initialize_result + + await server.connect() + + # After connect, _get_session_id should be the callable from the transport + assert server._get_session_id is mock_get_session_id + assert server.session_id == "captured-session-xyz" + + +@pytest.mark.asyncio +async def test_session_id_is_none_after_cleanup(): + """session_id must return None after disconnect (cleanup clears _get_session_id).""" + server = MCPServerStreamableHttp(params={"url": "http://localhost:8000/mcp"}) + + mock_get_session_id = MagicMock(return_value="session-to-clear") + # Manually inject a session-id callback to simulate a connected state + server._get_session_id = mock_get_session_id + server.session = MagicMock() # pretend connected + + assert server.session_id == "session-to-clear" + + # Now simulate cleanup completing (exit_stack.aclose is a no-op here) + with patch.object(server.exit_stack, "aclose", new_callable=AsyncMock): + await server.cleanup() + + # After cleanup both session and _get_session_id must be None + assert server.session is None + assert server._get_session_id is None + assert server.session_id is None