Skip to content

Commit e814535

Browse files
test(client): cover resume path and streamable-http get-stream guard
1 parent d1f22c5 commit e814535

File tree

2 files changed

+62
-0
lines changed

2 files changed

+62
-0
lines changed

tests/client/test_session.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,28 @@ async def mock_server():
606606
assert result.protocol_version == LATEST_PROTOCOL_VERSION
607607

608608

609+
@pytest.mark.anyio
610+
async def test_client_session_resume_sets_initialize_result():
611+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
612+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
613+
614+
session = ClientSession(server_to_client_receive, client_to_server_send)
615+
assert session.initialize_result is None
616+
617+
resumed_result = InitializeResult(
618+
protocol_version=LATEST_PROTOCOL_VERSION,
619+
capabilities=ServerCapabilities(),
620+
server_info=Implementation(name="mock-server", version="0.1.0"),
621+
)
622+
session.resume(resumed_result)
623+
assert session.initialize_result == resumed_result
624+
625+
await client_to_server_send.aclose()
626+
await client_to_server_receive.aclose()
627+
await server_to_client_send.aclose()
628+
await server_to_client_receive.aclose()
629+
630+
609631
@pytest.mark.anyio
610632
@pytest.mark.parametrize(argnames="meta", argvalues=[None, {"toolMeta": "value"}])
611633
async def test_client_tool_call_with_meta(meta: RequestParamsMeta | None):

tests/shared/test_streamable_http.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1809,6 +1809,46 @@ def test_streamable_http_transport_includes_seeded_session_id_header():
18091809
assert headers["mcp-session-id"] == "resume-session-id"
18101810

18111811

1812+
@pytest.mark.anyio
1813+
async def test_streamable_http_client_resumption_starts_get_stream_once(monkeypatch: pytest.MonkeyPatch):
1814+
start_count = 0
1815+
1816+
async def fake_handle_get_stream(
1817+
self: StreamableHTTPTransport,
1818+
client: httpx.AsyncClient,
1819+
read_stream_writer: anyio.abc.ObjectSendStream[SessionMessage | Exception],
1820+
) -> None:
1821+
nonlocal start_count
1822+
start_count += 1
1823+
await anyio.sleep(0)
1824+
1825+
async def fake_post_writer(
1826+
self: StreamableHTTPTransport,
1827+
client: httpx.AsyncClient,
1828+
write_stream_reader: anyio.abc.ObjectReceiveStream[SessionMessage],
1829+
read_stream_writer: anyio.abc.ObjectSendStream[SessionMessage | Exception],
1830+
write_stream: anyio.abc.ObjectSendStream[SessionMessage],
1831+
start_get_stream: Any,
1832+
tg: anyio.abc.TaskGroup,
1833+
) -> None:
1834+
# Call twice; the second call should hit the early return guard.
1835+
start_get_stream()
1836+
start_get_stream()
1837+
await anyio.sleep(0)
1838+
1839+
monkeypatch.setattr(StreamableHTTPTransport, "handle_get_stream", fake_handle_get_stream)
1840+
monkeypatch.setattr(StreamableHTTPTransport, "post_writer", fake_post_writer)
1841+
1842+
async with streamable_http_client(
1843+
"http://localhost:8000/mcp",
1844+
session_id="resume-session-id",
1845+
terminate_on_close=False,
1846+
):
1847+
await anyio.sleep(0)
1848+
1849+
assert start_count == 1
1850+
1851+
18121852
@pytest.mark.anyio
18131853
async def test_priming_event_not_sent_for_old_protocol_version():
18141854
"""Test that _maybe_send_priming_event skips for old protocol versions (backwards compat)."""

0 commit comments

Comments
 (0)