-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Expand file tree
/
Copy pathtask_result_handler.py
More file actions
221 lines (182 loc) · 7.99 KB
/
task_result_handler.py
File metadata and controls
221 lines (182 loc) · 7.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
"""TaskResultHandler - Integrated handler for tasks/result endpoint.
This implements the dequeue-send-wait pattern from the MCP Tasks spec:
1. Dequeue all pending messages for the task
2. Send them to the client via transport with relatedRequestId routing
3. Wait if task is not in terminal state
4. Return final result when task completes
This is the core of the task message queue pattern.
"""
import logging
from typing import Any
import anyio
from mcp.server.session import ServerSession
from mcp.shared.exceptions import MCPError
from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY, is_terminal
from mcp.shared.experimental.tasks.message_queue import TaskMessageQueue
from mcp.shared.experimental.tasks.resolver import Resolver
from mcp.shared.experimental.tasks.store import TaskStore
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.types import (
INVALID_PARAMS,
ErrorData,
GetTaskPayloadRequest,
GetTaskPayloadResult,
RelatedTaskMetadata,
RequestId,
)
logger = logging.getLogger(__name__)


class TaskResultHandler:
    """Handler for tasks/result that implements the message queue pattern.

    This handler:
    1. Dequeues pending messages (elicitations, notifications) for the task
    2. Sends them to the client via the response stream
    3. Waits for responses and resolves them back to callers
    4. Blocks until task reaches terminal state
    5. Returns the final result

    Usage:
        # Create handler with store and queue
        handler = TaskResultHandler(task_store, message_queue)

        # Register as a handler with the lowlevel server
        async def handle_task_result(ctx, params):
            return await handler.handle(
                GetTaskPayloadRequest(params=params), ctx.session, ctx.request_id
            )

        server = Server(handlers=[
            RequestHandler("tasks/result", handler=handle_task_result),
        ])
    """

    def __init__(
        self,
        store: TaskStore,
        queue: TaskMessageQueue,
    ):
        """Create a handler bound to a task store and a task message queue.

        Args:
            store: Provides task status (`get_task`), final results
                (`get_result`) and update notification (`wait_for_update`).
            queue: Per-task queue of messages (`dequeue`, `wait_for_message`)
                that must be delivered to the client.
        """
        self._store = store
        self._queue = queue
        # Map from a queued request's original request ID to the resolver waiting
        # on its response. Filled by _deliver_queued_messages; drained by
        # route_response / route_error (or on a failed send).
        self._pending_requests: dict[RequestId, Resolver[dict[str, Any]]] = {}

    async def send_message(
        self,
        session: ServerSession,
        message: SessionMessage,
    ) -> None:
        """Send a message via the session.

        This is a helper for delivering queued task messages.

        Args:
            session: The session to send on.
            message: The already-wrapped session message.
        """
        await session.send_message(message)

    async def handle(
        self,
        request: GetTaskPayloadRequest,
        session: ServerSession,
        request_id: RequestId,
    ) -> GetTaskPayloadResult:
        """Handle a tasks/result request.

        This implements the dequeue-send-wait loop:
        1. Dequeue all pending messages
        2. Send each via transport with relatedRequestId = this request's ID
        3. If task not terminal, wait for status change
        4. Loop until task is terminal
        5. Return final result

        Args:
            request: The GetTaskPayloadRequest
            session: The server session for sending messages
            request_id: The request ID for relatedRequestId routing

        Returns:
            GetTaskPayloadResult with the task's final payload. Its `_meta`
            always carries the related-task metadata for `task_id`, merged on top
            of any `_meta` already present in the stored result.

        Raises:
            MCPError: with code INVALID_PARAMS if the task does not exist (checked
                on every loop iteration).
        """
        task_id = request.params.task_id

        while True:
            task = await self._store.get_task(task_id)
            if task is None:
                raise MCPError(code=INVALID_PARAMS, message=f"Task not found: {task_id}")

            # The status snapshot above is taken BEFORE draining the queue, so if
            # it is terminal, everything enqueued before that status was written
            # has been delivered here (assuming producers enqueue before marking
            # the task terminal — NOTE(review): confirm with TaskStore users).
            await self._deliver_queued_messages(task_id, session, request_id)

            if is_terminal(task.status):
                result = await self._store.get_result(task_id)
                # GetTaskPayloadResult is a Result with extra="allow"
                # The stored result contains the actual payload data
                # Per spec: tasks/result MUST include _meta with related-task metadata
                related_task = RelatedTaskMetadata(task_id=task_id)
                related_task_meta: dict[str, Any] = {
                    RELATED_TASK_METADATA_KEY: related_task.model_dump(by_alias=True)
                }
                if result is not None:
                    result_data = result.model_dump(by_alias=True)
                    existing_meta: dict[str, Any] = result_data.get("_meta") or {}
                    # Related-task metadata wins over any same-named key in the stored _meta.
                    result_data["_meta"] = {**existing_meta, **related_task_meta}
                    return GetTaskPayloadResult.model_validate(result_data)
                return GetTaskPayloadResult.model_validate({"_meta": related_task_meta})

            # Wait for task update (status change or new messages)
            await self._wait_for_task_update(task_id)

    async def _deliver_queued_messages(
        self,
        task_id: str,
        session: ServerSession,
        request_id: RequestId,
    ) -> None:
        """Dequeue and send all pending messages for a task.

        Each message is sent via the session with relatedRequestId set so
        responses route back to this stream. For queued requests that carry a
        resolver, the resolver is registered under the request's original ID so
        route_response / route_error can complete it.

        If sending fails, the just-registered resolver is removed and failed with
        the send error (the client never saw the request, so no response can
        arrive), and the error is re-raised.
        """
        while True:
            message = await self._queue.dequeue(task_id)
            if message is None:
                break

            # Only requests (not notifications) expect a response to route back.
            registered_id: RequestId | None = None
            if message.type == "request" and message.resolver is not None:
                original_id = message.original_request_id
                if original_id is not None:
                    self._pending_requests[original_id] = message.resolver
                    registered_id = original_id

            logger.debug("Delivering queued message for task %s: %s", task_id, message.type)

            # Send the message with relatedRequestId for routing
            session_message = SessionMessage(
                message=message.message,
                metadata=ServerMessageMetadata(related_request_id=request_id),
            )
            try:
                await self.send_message(session, session_message)
            except Exception as exc:
                # Without this, the routing entry would leak and the waiter hang.
                if registered_id is not None:
                    self._fail_pending_request(registered_id, exc)
                raise

    async def _wait_for_task_update(self, task_id: str) -> None:
        """Wait for task to be updated (status change or new message).

        Races between store update and queue message - first one wins: whichever
        waiter finishes (or fails) first cancels the other. Failures are logged
        rather than raised; the caller re-reads the task on the next loop anyway.
        """
        async with anyio.create_task_group() as tg:

            async def wait_for_store() -> None:
                try:
                    await self._store.wait_for_update(task_id)
                except Exception:
                    logger.warning("Waiting for store update failed for task %s", task_id, exc_info=True)
                finally:
                    tg.cancel_scope.cancel()

            async def wait_for_queue() -> None:
                try:
                    await self._queue.wait_for_message(task_id)
                except Exception:
                    logger.warning("Waiting for queued message failed for task %s", task_id, exc_info=True)
                finally:
                    tg.cancel_scope.cancel()

            tg.start_soon(wait_for_store)
            tg.start_soon(wait_for_queue)

    def _take_pending_resolver(self, request_id: RequestId) -> "Resolver[dict[str, Any]] | None":
        """Remove the resolver for `request_id`; return it only if still unresolved."""
        resolver = self._pending_requests.pop(request_id, None)
        if resolver is None or resolver.done():
            return None
        return resolver

    def _fail_pending_request(self, request_id: RequestId, exc: Exception) -> None:
        """Drop the routing entry for a request that never reached the client and fail its waiter."""
        resolver = self._take_pending_resolver(request_id)
        if resolver is not None:
            resolver.set_exception(exc)

    def route_response(self, request_id: RequestId, response: dict[str, Any]) -> bool:
        """Route a response back to the waiting resolver.

        This is called when a response arrives for a queued request.

        Args:
            request_id: The request ID from the response
            response: The response data

        Returns:
            True if response was routed, False if no pending request (or the
            resolver was already completed; the entry is removed either way)
        """
        resolver = self._take_pending_resolver(request_id)
        if resolver is None:
            return False
        resolver.set_result(response)
        return True

    def route_error(self, request_id: RequestId, error: ErrorData) -> bool:
        """Route an error back to the waiting resolver.

        Args:
            request_id: The request ID from the error response
            error: The error data

        Returns:
            True if error was routed, False if no pending request (or the
            resolver was already completed; the entry is removed either way)
        """
        resolver = self._take_pending_resolver(request_id)
        if resolver is None:
            return False
        resolver.set_exception(MCPError.from_error_data(error))
        return True