Skip to content

Commit 2253918

Browse files
author
Ubuntu
committed
feat(mcp): expose session_id on MCPServerStreamableHttp
Capture the get_session_id callback returned by streamablehttp_client in the base connect() method and expose it as a session_id property on MCPServerStreamableHttp. Closes #924
1 parent 8d3aa15 commit 2253918

2 files changed

Lines changed: 125 additions & 1 deletion

File tree

src/agents/mcp/server.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,7 @@ def __init__(
336336
self._tools_list: list[MCPTool] | None = None
337337

338338
self.tool_filter = tool_filter
339+
self._get_session_id: GetSessionIdCallback | None = None
339340

340341
async def _apply_tool_filter(
341342
self,
@@ -490,7 +491,9 @@ async def connect(self):
490491
# streamablehttp_client returns (read, write, get_session_id)
491492
# sse_client returns (read, write)
492493

493-
read, write, *_ = transport
494+
read, write, *rest = transport
495+
# Capture the session-id callback when present (streamablehttp_client only).
496+
self._get_session_id = rest[0] if rest and callable(rest[0]) else None
494497

495498
session = await self.exit_stack.enter_async_context(
496499
ClientSession(
@@ -1123,3 +1126,29 @@ def create_streams(
11231126
def name(self) -> str:
11241127
"""A readable name for the server."""
11251128
return self._name
1129+
1130+
@property
1131+
def session_id(self) -> str | None:
1132+
"""The MCP session ID assigned by the server, or None if not yet connected
1133+
or if the server did not issue a session ID.
1134+
1135+
The session ID is stable for the lifetime of this server instance's connection.
1136+
You can persist it and pass it back via the Mcp-Session-Id request header
1137+
(params["headers"]) on a new MCPServerStreamableHttp instance to resume
1138+
the same server-side session across process restarts or stateless workers.
1139+
1140+
Example::
1141+
1142+
async with MCPServerStreamableHttp(params={"url": url}) as server:
1143+
session_id = server.session_id
1144+
1145+
# In a new worker / process:
1146+
async with MCPServerStreamableHttp(
1147+
params={"url": url, "headers": {"Mcp-Session-Id": session_id}}
1148+
) as server:
1149+
# Resumes the same server-side session.
1150+
...
1151+
"""
1152+
if self._get_session_id is None:
1153+
return None
1154+
return self._get_session_id()
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
"""Tests for MCPServerStreamableHttp.session_id property (issue #924)."""
2+
3+
from __future__ import annotations
4+
5+
from unittest.mock import AsyncMock, MagicMock, patch
6+
7+
import pytest
8+
9+
from agents.mcp import MCPServerStreamableHttp
10+
11+
12+
class TestStreamableHttpSessionId:
13+
"""Tests that the session_id property is correctly exposed."""
14+
15+
def test_session_id_is_none_before_connect(self):
16+
"""session_id should be None when the server has not been connected yet."""
17+
server = MCPServerStreamableHttp(params={"url": "http://localhost:9999/mcp"})
18+
assert server.session_id is None
19+
20+
def test_session_id_returns_none_when_callback_is_none(self):
21+
"""session_id should be None when _get_session_id callback is None."""
22+
server = MCPServerStreamableHttp(params={"url": "http://localhost:9999/mcp"})
23+
server._get_session_id = None
24+
assert server.session_id is None
25+
26+
def test_session_id_returns_callback_value(self):
27+
"""session_id should return the value from the get_session_id callback."""
28+
server = MCPServerStreamableHttp(params={"url": "http://localhost:9999/mcp"})
29+
mock_get_session_id = MagicMock(return_value="test-session-abc123")
30+
server._get_session_id = mock_get_session_id
31+
assert server.session_id == "test-session-abc123"
32+
mock_get_session_id.assert_called_once()
33+
34+
def test_session_id_returns_none_when_callback_returns_none(self):
35+
"""session_id should return None when the callback itself returns None."""
36+
server = MCPServerStreamableHttp(params={"url": "http://localhost:9999/mcp"})
37+
mock_get_session_id = MagicMock(return_value=None)
38+
server._get_session_id = mock_get_session_id
39+
assert server.session_id is None
40+
41+
def test_session_id_reflects_updated_callback_value(self):
42+
"""session_id should reflect the latest value from the callback each time."""
43+
server = MCPServerStreamableHttp(params={"url": "http://localhost:9999/mcp"})
44+
call_count = 0
45+
46+
def changing_callback() -> str | None:
47+
nonlocal call_count
48+
call_count += 1
49+
return f"session-{call_count}"
50+
51+
server._get_session_id = changing_callback
52+
assert server.session_id == "session-1"
53+
assert server.session_id == "session-2"
54+
55+
@pytest.mark.asyncio
56+
async def test_connect_captures_get_session_id_callback(self):
57+
"""connect() should capture the third element of the transport tuple as _get_session_id."""
58+
server = MCPServerStreamableHttp(params={"url": "http://localhost:9999/mcp"})
59+
60+
mock_read = AsyncMock()
61+
mock_write = AsyncMock()
62+
mock_get_session_id = MagicMock(return_value="captured-session-xyz")
63+
64+
mock_initialize_result = MagicMock()
65+
mock_session = AsyncMock()
66+
mock_session.initialize = AsyncMock(return_value=mock_initialize_result)
67+
68+
# Simulate the full 3-tuple that streamablehttp_client returns
69+
transport_tuple = (mock_read, mock_write, mock_get_session_id)
70+
71+
with patch("agents.mcp.server.ClientSession") as mock_client_session_cls:
72+
mock_client_session_cls.return_value.__aenter__ = AsyncMock(
73+
return_value=mock_session
74+
)
75+
mock_client_session_cls.return_value.__aexit__ = AsyncMock(return_value=None)
76+
77+
with patch.object(
78+
server,
79+
"create_streams",
80+
) as mock_create_streams:
81+
mock_cm = MagicMock()
82+
mock_cm.__aenter__ = AsyncMock(return_value=transport_tuple)
83+
mock_cm.__aexit__ = AsyncMock(return_value=None)
84+
mock_create_streams.return_value = mock_cm
85+
86+
with patch.object(server.exit_stack, "enter_async_context") as mock_enter:
87+
# First call returns transport, second call returns session
88+
mock_enter.side_effect = [transport_tuple, mock_session]
89+
mock_session.initialize.return_value = mock_initialize_result
90+
91+
await server.connect()
92+
93+
# After connect, _get_session_id should be the callable from the transport
94+
assert server._get_session_id is mock_get_session_id
95+
assert server.session_id == "captured-session-xyz"

0 commit comments

Comments
 (0)