-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Expand file tree
/
Copy pathtask_handlers.py
More file actions
291 lines (222 loc) · 10.5 KB
/
task_handlers.py
File metadata and controls
291 lines (222 loc) · 10.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
"""Experimental task handler protocols for server -> client requests.
This module provides Protocol types and default handlers for when servers
send task-related requests to clients (the reverse of normal client -> server flow).
WARNING: These APIs are experimental and may change without notice.
Use cases:
- Server sends task-augmented sampling/elicitation request to client
- Client creates a local task, spawns background work, returns CreateTaskResult
- Server polls client's task status via tasks/get, tasks/result, etc.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Protocol
from pydantic import TypeAdapter
import mcp.types as types
from mcp.shared._context import RequestContext
from mcp.shared.session import RequestResponder
if TYPE_CHECKING:
from mcp.client.session import ClientSession
class GetTaskHandlerFnT(Protocol):
    """Callable signature for answering a server-issued tasks/get request.

    WARNING: This is experimental and may change without notice.

    Args:
        context: Request context bound to the client session.
        params: Parameters carried by the tasks/get request.

    Returns:
        A ``GetTaskResult`` on success, or ``ErrorData`` describing the failure.
    """

    async def __call__(
        self,
        context: RequestContext[ClientSession],
        params: types.GetTaskRequestParams,
    ) -> types.GetTaskResult | types.ErrorData: ...  # pragma: no branch
class GetTaskResultHandlerFnT(Protocol):
    """Callable signature for answering a server-issued tasks/result request.

    WARNING: This is experimental and may change without notice.

    Args:
        context: Request context bound to the client session.
        params: Parameters carried by the tasks/result request.

    Returns:
        A ``GetTaskPayloadResult`` on success, or ``ErrorData`` describing the
        failure.
    """

    async def __call__(
        self,
        context: RequestContext[ClientSession],
        params: types.GetTaskPayloadRequestParams,
    ) -> types.GetTaskPayloadResult | types.ErrorData: ...  # pragma: no branch
class ListTasksHandlerFnT(Protocol):
    """Callable signature for answering a server-issued tasks/list request.

    WARNING: This is experimental and may change without notice.

    Args:
        context: Request context bound to the client session.
        params: Pagination parameters of the tasks/list request; may be
            ``None`` when the request carries no params.

    Returns:
        A ``ListTasksResult`` on success, or ``ErrorData`` describing the
        failure.
    """

    async def __call__(
        self,
        context: RequestContext[ClientSession],
        params: types.PaginatedRequestParams | None,
    ) -> types.ListTasksResult | types.ErrorData: ...  # pragma: no branch
class CancelTaskHandlerFnT(Protocol):
    """Callable signature for answering a server-issued tasks/cancel request.

    WARNING: This is experimental and may change without notice.

    Args:
        context: Request context bound to the client session.
        params: Parameters carried by the tasks/cancel request.

    Returns:
        A ``CancelTaskResult`` on success, or ``ErrorData`` describing the
        failure.
    """

    async def __call__(
        self,
        context: RequestContext[ClientSession],
        params: types.CancelTaskRequestParams,
    ) -> types.CancelTaskResult | types.ErrorData: ...  # pragma: no branch
class TaskAugmentedSamplingFnT(Protocol):
    """Callable signature for task-augmented sampling/createMessage requests.

    Invoked when the server sends a CreateMessageRequest that carries a task
    field. An implementation is expected to register a task, start the actual
    work in the background, and hand back a CreateTaskResult right away rather
    than waiting for the work to finish.

    WARNING: This is experimental and may change without notice.

    Args:
        context: Request context bound to the client session.
        params: Parameters of the sampling/createMessage request.
        task_metadata: Task metadata that accompanied the request.

    Returns:
        A ``CreateTaskResult`` on success, or ``ErrorData`` describing the
        failure.
    """

    async def __call__(
        self,
        context: RequestContext[ClientSession],
        params: types.CreateMessageRequestParams,
        task_metadata: types.TaskMetadata,
    ) -> types.CreateTaskResult | types.ErrorData: ...  # pragma: no branch
class TaskAugmentedElicitationFnT(Protocol):
    """Callable signature for task-augmented elicitation/create requests.

    Invoked when the server sends an ElicitRequest that carries a task field.
    An implementation is expected to register a task, start the actual work in
    the background, and hand back a CreateTaskResult right away rather than
    waiting for the work to finish.

    WARNING: This is experimental and may change without notice.

    Args:
        context: Request context bound to the client session.
        params: Parameters of the elicitation/create request.
        task_metadata: Task metadata that accompanied the request.

    Returns:
        A ``CreateTaskResult`` on success, or ``ErrorData`` describing the
        failure.
    """

    async def __call__(
        self,
        context: RequestContext[ClientSession],
        params: types.ElicitRequestParams,
        task_metadata: types.TaskMetadata,
    ) -> types.CreateTaskResult | types.ErrorData: ...  # pragma: no branch
async def default_get_task_handler(
    context: RequestContext[ClientSession],
    params: types.GetTaskRequestParams,
) -> types.GetTaskResult | types.ErrorData:
    """Fallback tasks/get handler: always reports the method as unsupported.

    Both arguments are accepted only to satisfy the handler signature and are
    not inspected.

    Returns:
        ``ErrorData`` with code ``METHOD_NOT_FOUND``.
    """
    not_supported = types.ErrorData(
        code=types.METHOD_NOT_FOUND,
        message="tasks/get not supported",
    )
    return not_supported
async def default_get_task_result_handler(
    context: RequestContext[ClientSession],
    params: types.GetTaskPayloadRequestParams,
) -> types.GetTaskPayloadResult | types.ErrorData:
    """Fallback tasks/result handler: always reports the method as unsupported.

    Both arguments are accepted only to satisfy the handler signature and are
    not inspected.

    Returns:
        ``ErrorData`` with code ``METHOD_NOT_FOUND``.
    """
    not_supported = types.ErrorData(
        code=types.METHOD_NOT_FOUND,
        message="tasks/result not supported",
    )
    return not_supported
async def default_list_tasks_handler(
    context: RequestContext[ClientSession],
    params: types.PaginatedRequestParams | None,
) -> types.ListTasksResult | types.ErrorData:
    """Fallback tasks/list handler: always reports the method as unsupported.

    Both arguments are accepted only to satisfy the handler signature and are
    not inspected.

    Returns:
        ``ErrorData`` with code ``METHOD_NOT_FOUND``.
    """
    not_supported = types.ErrorData(
        code=types.METHOD_NOT_FOUND,
        message="tasks/list not supported",
    )
    return not_supported
async def default_cancel_task_handler(
    context: RequestContext[ClientSession],
    params: types.CancelTaskRequestParams,
) -> types.CancelTaskResult | types.ErrorData:
    """Fallback tasks/cancel handler: always reports the method as unsupported.

    Both arguments are accepted only to satisfy the handler signature and are
    not inspected.

    Returns:
        ``ErrorData`` with code ``METHOD_NOT_FOUND``.
    """
    not_supported = types.ErrorData(
        code=types.METHOD_NOT_FOUND,
        message="tasks/cancel not supported",
    )
    return not_supported
async def default_task_augmented_sampling(
    context: RequestContext[ClientSession],
    params: types.CreateMessageRequestParams,
    task_metadata: types.TaskMetadata,
) -> types.CreateTaskResult | types.ErrorData:
    """Fallback task-augmented sampling handler: always rejects the request.

    All arguments are accepted only to satisfy the handler signature and are
    not inspected.

    Returns:
        ``ErrorData`` with code ``INVALID_REQUEST``.
    """
    rejection = types.ErrorData(
        code=types.INVALID_REQUEST,
        message="Task-augmented sampling not supported",
    )
    return rejection
async def default_task_augmented_elicitation(
    context: RequestContext[ClientSession],
    params: types.ElicitRequestParams,
    task_metadata: types.TaskMetadata,
) -> types.CreateTaskResult | types.ErrorData:
    """Fallback task-augmented elicitation handler: always rejects the request.

    All arguments are accepted only to satisfy the handler signature and are
    not inspected.

    Returns:
        ``ErrorData`` with code ``INVALID_REQUEST``.
    """
    rejection = types.ErrorData(
        code=types.INVALID_REQUEST,
        message="Task-augmented elicitation not supported",
    )
    return rejection
@dataclass
class ExperimentalTaskHandlers:
"""Container for experimental task handlers.
Groups all task-related handlers that handle server -> client requests.
This includes both pure task requests (get, list, cancel, result) and
task-augmented request handlers (sampling, elicitation with task field).
WARNING: These APIs are experimental and may change without notice.
Example:
handlers = ExperimentalTaskHandlers(
get_task=my_get_task_handler,
list_tasks=my_list_tasks_handler,
)
session = ClientSession(..., experimental_task_handlers=handlers)
"""
# Pure task request handlers
get_task: GetTaskHandlerFnT = field(default=default_get_task_handler)
get_task_result: GetTaskResultHandlerFnT = field(default=default_get_task_result_handler)
list_tasks: ListTasksHandlerFnT = field(default=default_list_tasks_handler)
cancel_task: CancelTaskHandlerFnT = field(default=default_cancel_task_handler)
# Task-augmented request handlers
augmented_sampling: TaskAugmentedSamplingFnT = field(default=default_task_augmented_sampling)
augmented_elicitation: TaskAugmentedElicitationFnT = field(default=default_task_augmented_elicitation)
def build_capability(self) -> types.ClientTasksCapability | None:
"""Build ClientTasksCapability from the configured handlers.
Returns a capability object that reflects which handlers are configured
(i.e., not using the default "not supported" handlers).
Returns:
ClientTasksCapability if any handlers are provided, None otherwise
"""
has_list = self.list_tasks is not default_list_tasks_handler
has_cancel = self.cancel_task is not default_cancel_task_handler
has_sampling = self.augmented_sampling is not default_task_augmented_sampling
has_elicitation = self.augmented_elicitation is not default_task_augmented_elicitation
# If no handlers are provided, return None
if not any([has_list, has_cancel, has_sampling, has_elicitation]):
return None
# Build requests capability if any request handlers are provided
requests_capability: types.ClientTasksRequestsCapability | None = None
if has_sampling or has_elicitation:
requests_capability = types.ClientTasksRequestsCapability(
sampling=types.TasksSamplingCapability(create_message=types.TasksCreateMessageCapability())
if has_sampling
else None,
elicitation=types.TasksElicitationCapability(create=types.TasksCreateElicitationCapability())
if has_elicitation
else None,
)
return types.ClientTasksCapability(
list=types.TasksListCapability() if has_list else None,
cancel=types.TasksCancelCapability() if has_cancel else None,
requests=requests_capability,
)
@staticmethod
def handles_request(request: types.ServerRequest) -> bool:
"""Check if this handler handles the given request type."""
return isinstance(
request,
types.GetTaskRequest | types.GetTaskPayloadRequest | types.ListTasksRequest | types.CancelTaskRequest,
)
async def handle_request(
self,
ctx: RequestContext[ClientSession],
responder: RequestResponder[types.ServerRequest, types.ClientResult],
) -> None:
"""Handle a task-related request from the server.
Call handles_request() first to check if this handler can handle the request.
"""
client_response_type: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(
types.ClientResult | types.ErrorData
)
match responder.request:
case types.GetTaskRequest(params=params):
response = await self.get_task(ctx, params)
client_response = client_response_type.validate_python(response)
await responder.respond(client_response)
case types.GetTaskPayloadRequest(params=params):
response = await self.get_task_result(ctx, params)
client_response = client_response_type.validate_python(response)
await responder.respond(client_response)
case types.ListTasksRequest(params=params):
response = await self.list_tasks(ctx, params)
client_response = client_response_type.validate_python(response)
await responder.respond(client_response)
case types.CancelTaskRequest(params=params):
response = await self.cancel_task(ctx, params)
client_response = client_response_type.validate_python(response)
await responder.respond(client_response)
case _: # pragma: no cover
raise ValueError(f"Unhandled request type: {type(responder.request)}")
# Backwards compatibility aliases: the "*_callback" names are bound to the very
# same function objects as the default handlers above, so code that imports
# either spelling gets identical behavior.
default_task_augmented_sampling_callback = default_task_augmented_sampling
default_task_augmented_elicitation_callback = default_task_augmented_elicitation