Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/mcp/client/_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import uuid
from collections.abc import AsyncIterator
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from types import TracebackType
Expand Down Expand Up @@ -50,12 +51,14 @@ async def _connect(self) -> AsyncIterator[TransportStreams]:

async with anyio.create_task_group() as tg:
# Start server in background
memory_session_id = uuid.uuid4().hex
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit weird to be forced to create a session_id before running the server, does stdio does that as well?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stdio doesn't have session IDs either. After thinking through it I think the problem is we actually shouldn't require session_ids for the default TaskStore, and instead no session ID means no task isolation.

Instead the default task support should only allow None session_ids if the server isn't in stateless mode. I think I was wrong before in assuming that a None session ID meant the server is in stateless mode, but in reality it could just mean we're in a different transport.

tg.start_soon(
lambda: actual_server.run(
server_read,
server_write,
actual_server.create_initialization_options(),
raise_exceptions=self._raise_exceptions,
session_id=memory_session_id,
)
)

Expand Down
5 changes: 4 additions & 1 deletion src/mcp/server/experimental/request_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,10 @@ async def work(task: ServerTaskContext) -> CallToolResult:
# Access task_group via TaskSupport - raises if not in run() context
task_group = support.task_group

task = await support.store.create_task(self.task_metadata, task_id)
session_id = self._session.session_id
if session_id is None:
raise RuntimeError("Session ID is required for task operations but session has no ID.")
task = await support.store.create_task(self.task_metadata, task_id, session_id=session_id)

task_ctx = ServerTaskContext(
task=task,
Expand Down
36 changes: 20 additions & 16 deletions src/mcp/server/experimental/task_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,11 @@ def __init__(
queue: The message queue for elicitation/sampling
handler: The result handler for response routing (required for elicit/create_message)
"""
self._ctx = TaskContext(task=task, store=store)
session_id = session.session_id
if session_id is None:
raise RuntimeError("Session ID is required for task operations but session has no ID.")
self._session_id = session_id
self._ctx = TaskContext(task=task, store=store, session_id=session_id)
self._session = session
self._queue = queue
self._handler = handler
Expand Down Expand Up @@ -212,7 +216,7 @@ async def elicit(
raise RuntimeError("handler is required for elicit(). Pass handler= to ServerTaskContext.")

# Update status to input_required
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED, session_id=self._session_id)

# Build the request using session's helper
request = self._session._build_elicit_form_request( # pyright: ignore[reportPrivateUsage]
Expand All @@ -236,12 +240,12 @@ async def elicit(
try:
# Wait for response (routed back via TaskResultHandler)
response_data = await resolver.wait()
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING, session_id=self._session_id)
return ElicitResult.model_validate(response_data)
except anyio.get_cancelled_exc_class():
# This path is tested in test_elicit_restores_status_on_cancellation
# which verifies status is restored to "working" after cancellation.
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING, session_id=self._session_id)
raise

async def elicit_url(
Expand Down Expand Up @@ -281,7 +285,7 @@ async def elicit_url(
raise RuntimeError("handler is required for elicit_url(). Pass handler= to ServerTaskContext.")

# Update status to input_required
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED, session_id=self._session_id)

# Build the request using session's helper
request = self._session._build_elicit_url_request( # pyright: ignore[reportPrivateUsage]
Expand All @@ -306,10 +310,10 @@ async def elicit_url(
try:
# Wait for response (routed back via TaskResultHandler)
response_data = await resolver.wait()
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING, session_id=self._session_id)
return ElicitResult.model_validate(response_data)
except anyio.get_cancelled_exc_class(): # pragma: no cover
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING, session_id=self._session_id)
raise

async def create_message(
Expand Down Expand Up @@ -364,7 +368,7 @@ async def create_message(
raise RuntimeError("handler is required for create_message(). Pass handler= to ServerTaskContext.")

# Update status to input_required
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED, session_id=self._session_id)

# Build the request using session's helper
request = self._session._build_create_message_request( # pyright: ignore[reportPrivateUsage]
Expand Down Expand Up @@ -396,12 +400,12 @@ async def create_message(
try:
# Wait for response (routed back via TaskResultHandler)
response_data = await resolver.wait()
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING, session_id=self._session_id)
return CreateMessageResult.model_validate(response_data)
except anyio.get_cancelled_exc_class():
# This path is tested in test_create_message_restores_status_on_cancellation
# which verifies status is restored to "working" after cancellation.
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING, session_id=self._session_id)
raise

async def elicit_as_task(
Expand Down Expand Up @@ -437,7 +441,7 @@ async def elicit_as_task(
raise RuntimeError("handler is required for elicit_as_task()")

# Update status to input_required
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED, session_id=self._session_id)

request = self._session._build_elicit_form_request( # pyright: ignore[reportPrivateUsage]
message=message,
Expand Down Expand Up @@ -474,11 +478,11 @@ async def elicit_as_task(
ElicitResult,
)

await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING, session_id=self._session_id)
return result

except anyio.get_cancelled_exc_class(): # pragma: no cover
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING, session_id=self._session_id)
raise

async def create_message_as_task(
Expand Down Expand Up @@ -533,7 +537,7 @@ async def create_message_as_task(
raise RuntimeError("handler is required for create_message_as_task()")

# Update status to input_required
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED)
await self._store.update_task(self.task_id, status=TASK_STATUS_INPUT_REQUIRED, session_id=self._session_id)

# Build request WITH task field for task-augmented sampling
request = self._session._build_create_message_request( # pyright: ignore[reportPrivateUsage]
Expand Down Expand Up @@ -579,9 +583,9 @@ async def create_message_as_task(
CreateMessageResult,
)

await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING, session_id=self._session_id)
return result

except anyio.get_cancelled_exc_class(): # pragma: no cover
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING)
await self._store.update_task(self.task_id, status=TASK_STATUS_WORKING, session_id=self._session_id)
raise
6 changes: 4 additions & 2 deletions src/mcp/server/experimental/task_result_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ async def handle(
request: GetTaskPayloadRequest,
session: ServerSession,
request_id: RequestId,
session_id: str,
) -> GetTaskPayloadResult:
"""Handle a tasks/result request.

Expand All @@ -94,22 +95,23 @@ async def handle(
request: The GetTaskPayloadRequest
session: The server session for sending messages
request_id: The request ID for relatedRequestId routing
session_id: Session identifier for access control.

Returns:
GetTaskPayloadResult with the task's final payload
"""
task_id = request.params.task_id

while True:
task = await self._store.get_task(task_id)
task = await self._store.get_task(task_id, session_id=session_id)
if task is None:
raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {task_id}")

await self._deliver_queued_messages(task_id, session, request_id)

# If task is terminal, return result
if is_terminal(task.status):
result = await self._store.get_result(task_id)
result = await self._store.get_result(task_id, session_id=session_id)
# GetTaskPayloadResult is a Result with extra="allow"
# The stored result contains the actual payload data
# Per spec: tasks/result MUST include _meta with related-task metadata
Expand Down
21 changes: 17 additions & 4 deletions src/mcp/server/lowlevel/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,23 @@ def enable_tasks(
if on_cancel_task is not None:
self._add_request_handler("tasks/cancel", on_cancel_task)

def _require_session_id(ctx: ServerRequestContext[LifespanResultT]) -> str:
session_id = ctx.session.session_id
if session_id is None:
raise MCPError(
code=INVALID_PARAMS,
message="Session ID is required for task operations.",
)
return session_id

# Fill in defaults for any not provided
if not self._has_handler("tasks/get"):

async def _default_get_task(
ctx: ServerRequestContext[LifespanResultT], params: GetTaskRequestParams
) -> GetTaskResult:
task = await task_support.store.get_task(params.task_id)
session_id = _require_session_id(ctx)
task = await task_support.store.get_task(params.task_id, session_id=session_id)
if task is None:
raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {params.task_id}")
return GetTaskResult(
Expand All @@ -180,8 +190,9 @@ async def _default_get_task_result(
ctx: ServerRequestContext[LifespanResultT], params: GetTaskPayloadRequestParams
) -> GetTaskPayloadResult:
assert ctx.request_id is not None
session_id = _require_session_id(ctx)
req = GetTaskPayloadRequest(params=params)
result = await task_support.handler.handle(req, ctx.session, ctx.request_id)
result = await task_support.handler.handle(req, ctx.session, ctx.request_id, session_id=session_id)
return result

self._add_request_handler("tasks/result", _default_get_task_result)
Expand All @@ -192,7 +203,8 @@ async def _default_list_tasks(
ctx: ServerRequestContext[LifespanResultT], params: PaginatedRequestParams | None
) -> ListTasksResult:
cursor = params.cursor if params else None
tasks, next_cursor = await task_support.store.list_tasks(cursor)
session_id = _require_session_id(ctx)
tasks, next_cursor = await task_support.store.list_tasks(cursor, session_id=session_id)
return ListTasksResult(tasks=tasks, next_cursor=next_cursor)

self._add_request_handler("tasks/list", _default_list_tasks)
Expand All @@ -202,7 +214,8 @@ async def _default_list_tasks(
async def _default_cancel_task(
ctx: ServerRequestContext[LifespanResultT], params: CancelTaskRequestParams
) -> CancelTaskResult:
result = await cancel_task(task_support.store, params.task_id)
session_id = _require_session_id(ctx)
result = await cancel_task(task_support.store, params.task_id, session_id=session_id)
return result

self._add_request_handler("tasks/cancel", _default_cancel_task)
Expand Down
2 changes: 2 additions & 0 deletions src/mcp/server/lowlevel/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ async def run(
# the initialization lifecycle, but can do so with any available node
# rather than requiring initialization for each connection.
stateless: bool = False,
session_id: str | None = None,
):
async with AsyncExitStack() as stack:
lifespan_context = await stack.enter_async_context(self.lifespan(self))
Expand All @@ -380,6 +381,7 @@ async def run(
write_stream,
initialization_options,
stateless=stateless,
session_id=session_id,
)
)

Expand Down
3 changes: 3 additions & 0 deletions src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,12 @@ def __init__(
write_stream: MemoryObjectSendStream[SessionMessage],
init_options: InitializationOptions,
stateless: bool = False,
*,
session_id: str | None = None,
) -> None:
super().__init__(read_stream, write_stream)
self._stateless = stateless
self.session_id = session_id
self._initialization_state = (
InitializationState.Initialized if stateless else InitializationState.NotInitialized
)
Expand Down
2 changes: 2 additions & 0 deletions src/mcp/server/streamable_http_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA
write_stream,
self.app.create_initialization_options(),
stateless=True,
session_id=None, # No session in stateless mode
)
except Exception: # pragma: no cover
logger.exception("Stateless session crashed")
Expand Down Expand Up @@ -240,6 +241,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
write_stream,
self.app.create_initialization_options(),
stateless=False,
session_id=http_transport.mcp_session_id,
)

if idle_scope.cancelled_caught:
Expand Down
14 changes: 9 additions & 5 deletions src/mcp/shared/experimental/tasks/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,20 @@ class TaskContext:
use ServerTaskContext from mcp.server.experimental.

Example (distributed worker):
async def worker_job(task_id: str):
async def worker_job(task_id: str, session_id: str):
store = RedisTaskStore(redis_url)
task = await store.get_task(task_id)
ctx = TaskContext(task=task, store=store)
task = await store.get_task(task_id, session_id=session_id)
ctx = TaskContext(task=task, store=store, session_id=session_id)

await ctx.update_status("Working...")
result = await do_work()
await ctx.complete(result)
"""

def __init__(self, task: Task, store: TaskStore):
def __init__(self, task: Task, store: TaskStore, *, session_id: str):
self._task = task
self._store = store
self._session_id = session_id
self._cancelled = False

@property
Expand Down Expand Up @@ -68,6 +69,7 @@ async def update_status(self, message: str) -> None:
self._task = await self._store.update_task(
self.task_id,
status_message=message,
session_id=self._session_id,
)

async def complete(self, result: Result) -> None:
Expand All @@ -76,10 +78,11 @@ async def complete(self, result: Result) -> None:
Args:
result: The task result
"""
await self._store.store_result(self.task_id, result)
await self._store.store_result(self.task_id, result, session_id=self._session_id)
self._task = await self._store.update_task(
self.task_id,
status=TASK_STATUS_COMPLETED,
session_id=self._session_id,
)

async def fail(self, error: str) -> None:
Expand All @@ -92,4 +95,5 @@ async def fail(self, error: str) -> None:
self.task_id,
status=TASK_STATUS_FAILED,
status_message=error,
session_id=self._session_id,
)
Loading