Skip to content

Commit 14a4883

Browse files
committed
fix: fail fast when session is not started
1 parent 5d82649 commit 14a4883

2 files changed

Lines changed: 38 additions & 3 deletions

File tree

src/mcp/shared/session.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any, Generic, Protocol, TypeVar
88

99
import anyio
10+
from anyio.abc import TaskGroup
1011
from anyio.streams.memory import MemoryObjectSendStream
1112
from opentelemetry.trace import SpanKind
1213
from pydantic import BaseModel, TypeAdapter
@@ -156,10 +157,22 @@ def __init__(
156157
self._session_read_timeout_seconds = read_timeout_seconds
157158
self._progress_callbacks = {}
158159
self._exit_stack = AsyncExitStack()
160+
self._task_group: TaskGroup = anyio.create_task_group()
161+
self._started = False
162+
163+
def _require_started(self) -> None:
164+
if not self._started:
165+
raise RuntimeError(
166+
"Session is not running. Use it as an async context manager "
167+
"(e.g. `async with ClientSession(...) as session:`)."
168+
)
159169

160170
async def __aenter__(self) -> Self:
171+
if self._started:
172+
raise RuntimeError("Session is already running")
161173
self._task_group = anyio.create_task_group()
162174
await self._task_group.__aenter__()
175+
self._started = True
163176
self._task_group.start_soon(self._receive_loop)
164177
return self
165178

@@ -174,9 +187,12 @@ async def __aexit__(
174187
# would be very surprising behavior), so make sure to cancel the tasks
175188
# in the task group.
176189
self._task_group.cancel_scope.cancel()
177-
result = await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
178-
await resync_tracer()
179-
return result
190+
try:
191+
result = await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
192+
await resync_tracer()
193+
return result
194+
finally:
195+
self._started = False
180196

181197
async def send_request(
182198
self,
@@ -193,6 +209,7 @@ async def send_request(
193209
194210
Do not use this method to emit notifications! Use send_notification() instead.
195211
"""
212+
self._require_started()
196213
request_id = self._request_id
197214
self._request_id = request_id + 1
198215

@@ -255,6 +272,7 @@ async def send_notification(
255272
related_request_id: RequestId | None = None,
256273
) -> None:
257274
"""Emits a notification, which is a one-way message that does not expect a response."""
275+
self._require_started()
258276
# Some transport implementations may need to set the related_request_id
259277
# to attribute to the notifications to the request that triggered them.
260278
jsonrpc_notification = JSONRPCNotification(

tests/client/test_session.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,23 @@ async def message_handler( # pragma: no cover
140140
assert isinstance(initialized_notification, InitializedNotification)
141141

142142

143+
@pytest.mark.anyio
144+
async def test_client_session_requires_context_manager():
145+
client_to_server_send, _client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)
146+
_server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1)
147+
148+
async with (
149+
client_to_server_send,
150+
_client_to_server_receive,
151+
_server_to_client_send,
152+
server_to_client_receive,
153+
):
154+
session = ClientSession(server_to_client_receive, client_to_server_send)
155+
156+
with pytest.raises(RuntimeError, match="async context manager"):
157+
await session.initialize()
158+
159+
143160
@pytest.mark.anyio
144161
async def test_client_session_custom_client_info():
145162
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1)

0 commit comments

Comments
 (0)