Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion src/agents/mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ def __init__(

self.tool_filter = tool_filter
self._serialize_session_requests = False
self._get_session_id: GetSessionIdCallback | None = None

async def _maybe_serialize_request(self, func: Callable[[], Awaitable[T]]) -> T:
if not self._serialize_session_requests:
Expand Down Expand Up @@ -515,7 +516,9 @@ async def connect(self):
# streamablehttp_client returns (read, write, get_session_id)
# sse_client returns (read, write)

read, write, *_ = transport
read, write, *rest = transport
# Capture the session-id callback when present (streamablehttp_client only).
self._get_session_id = rest[0] if rest and callable(rest[0]) else None

session = await self.exit_stack.enter_async_context(
ClientSession(
Expand Down Expand Up @@ -780,6 +783,7 @@ async def cleanup(self):
logger.error(f"Error cleaning up server: {e}")
finally:
self.session = None
self._get_session_id = None


class MCPServerStdioParams(TypedDict):
Expand Down Expand Up @@ -1348,3 +1352,29 @@ async def call_tool(
def name(self) -> str:
"""A readable name for the server."""
return self._name

@property
def session_id(self) -> str | None:
"""The MCP session ID assigned by the server, or None if not yet connected
or if the server did not issue a session ID.

The session ID is stable for the lifetime of this server instance's connection.
You can persist it and pass it back via the Mcp-Session-Id request header
(params["headers"]) on a new MCPServerStreamableHttp instance to resume
the same server-side session across process restarts or stateless workers.

Example::

async with MCPServerStreamableHttp(params={"url": url}) as server:
session_id = server.session_id

# In a new worker / process:
async with MCPServerStreamableHttp(
params={"url": url, "headers": {"Mcp-Session-Id": session_id}}
) as server:
# Resumes the same server-side session.
...
"""
if self._get_session_id is None:
return None
return self._get_session_id()
Comment thread
seratch marked this conversation as resolved.
115 changes: 115 additions & 0 deletions tests/mcp/test_streamable_http_session_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""Tests for MCPServerStreamableHttp.session_id property (issue #924)."""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from agents.mcp import MCPServerStreamableHttp


class TestStreamableHttpSessionId:
"""Tests that the session_id property is correctly exposed."""

def test_session_id_is_none_before_connect(self):
"""session_id should be None when the server has not been connected yet."""
server = MCPServerStreamableHttp(params={"url": "http://localhost:9999/mcp"})
assert server.session_id is None

def test_session_id_returns_none_when_callback_is_none(self):
"""session_id should be None when _get_session_id callback is None."""
server = MCPServerStreamableHttp(params={"url": "http://localhost:9999/mcp"})
server._get_session_id = None
assert server.session_id is None

def test_session_id_returns_callback_value(self):
"""session_id should return the value from the get_session_id callback."""
server = MCPServerStreamableHttp(params={"url": "http://localhost:9999/mcp"})
mock_get_session_id = MagicMock(return_value="test-session-abc123")
server._get_session_id = mock_get_session_id
assert server.session_id == "test-session-abc123"
mock_get_session_id.assert_called_once()

def test_session_id_returns_none_when_callback_returns_none(self):
"""session_id should return None when the callback itself returns None."""
server = MCPServerStreamableHttp(params={"url": "http://localhost:9999/mcp"})
mock_get_session_id = MagicMock(return_value=None)
server._get_session_id = mock_get_session_id
assert server.session_id is None

def test_session_id_reflects_updated_callback_value(self):
"""session_id should reflect the latest value from the callback each time."""
server = MCPServerStreamableHttp(params={"url": "http://localhost:9999/mcp"})
call_count = 0

def changing_callback() -> str | None:
nonlocal call_count
call_count += 1
return f"session-{call_count}"

server._get_session_id = changing_callback
assert server.session_id == "session-1"
assert server.session_id == "session-2"

@pytest.mark.asyncio
async def test_connect_captures_get_session_id_callback(self):
"""connect() should capture the third element of the transport tuple as _get_session_id."""
server = MCPServerStreamableHttp(params={"url": "http://localhost:9999/mcp"})

mock_read = AsyncMock()
mock_write = AsyncMock()
mock_get_session_id = MagicMock(return_value="captured-session-xyz")

mock_initialize_result = MagicMock()
mock_session = AsyncMock()
mock_session.initialize = AsyncMock(return_value=mock_initialize_result)

# Simulate the full 3-tuple that streamablehttp_client returns
transport_tuple = (mock_read, mock_write, mock_get_session_id)

with patch("agents.mcp.server.ClientSession") as mock_client_session_cls:
mock_client_session_cls.return_value.__aenter__ = AsyncMock(return_value=mock_session)
mock_client_session_cls.return_value.__aexit__ = AsyncMock(return_value=None)

with patch.object(
server,
"create_streams",
) as mock_create_streams:
mock_cm = MagicMock()
mock_cm.__aenter__ = AsyncMock(return_value=transport_tuple)
mock_cm.__aexit__ = AsyncMock(return_value=None)
mock_create_streams.return_value = mock_cm

with patch.object(server.exit_stack, "enter_async_context") as mock_enter:
# First call returns transport, second call returns session
mock_enter.side_effect = [transport_tuple, mock_session]
mock_session.initialize.return_value = mock_initialize_result

await server.connect()

# After connect, _get_session_id should be the callable from the transport
assert server._get_session_id is mock_get_session_id
assert server.session_id == "captured-session-xyz"


@pytest.mark.asyncio
async def test_session_id_is_none_after_cleanup():
"""session_id must return None after disconnect (cleanup clears _get_session_id)."""
server = MCPServerStreamableHttp(params={"url": "http://localhost:8000/mcp"})

mock_get_session_id = MagicMock(return_value="session-to-clear")
# Manually inject a session-id callback to simulate a connected state
server._get_session_id = mock_get_session_id
server.session = MagicMock() # pretend connected

assert server.session_id == "session-to-clear"

# Now simulate cleanup completing (exit_stack.aclose is a no-op here)
with patch.object(server.exit_stack, "aclose", new_callable=AsyncMock):
await server.cleanup()

# After cleanup both session and _get_session_id must be None
assert server.session is None
assert server._get_session_id is None
assert server.session_id is None