-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Expand file tree
/
Copy pathtasks.py
More file actions
202 lines (159 loc) · 6.43 KB
/
tasks.py
File metadata and controls
202 lines (159 loc) · 6.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
"""Experimental client-side task support.
This module provides client methods for interacting with MCP tasks.

WARNING: These APIs are experimental and may change without notice.

Example:
    # Call a tool as a task
    result = await session.experimental.call_tool_as_task("tool_name", {"arg": "value"})
    task_id = result.task.task_id

    # Get task status
    status = await session.experimental.get_task(task_id)

    # Get task result when complete
    if status.status == "completed":
        result = await session.experimental.get_task_result(task_id, CallToolResult)

    # List all tasks
    tasks = await session.experimental.list_tasks()

    # Cancel a task
    await session.experimental.cancel_task(task_id)
"""
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Any, TypeVar
import mcp.types as types
from mcp.shared.experimental.tasks.polling import poll_until_terminal
from mcp.types._types import RequestParamsMeta
if TYPE_CHECKING:
    # Needed for annotations only: TYPE_CHECKING is False at runtime, so this
    # import is seen by type checkers but never executed.
    from mcp.client.session import ClientSession

# Type variable for the result model a caller expects back from a task.
# Bounded by ``types.Result`` so only result types can be requested.
ResultT = TypeVar("ResultT", bound=types.Result)
class ExperimentalClientFeatures:
    """Experimental client features for tasks and other experimental APIs.

    Every method is a thin wrapper that builds the matching request object and
    forwards it through the owning session's ``send_request``.

    WARNING: These APIs are experimental and may change without notice.

    Access via session.experimental:
        status = await session.experimental.get_task(task_id)
    """

    def __init__(self, session: "ClientSession") -> None:
        """Bind this helper to the session used to send every request.

        Args:
            session: The client session that requests are sent through
        """
        self._session = session

    async def call_tool_as_task(
        self,
        name: str,
        arguments: dict[str, Any] | None = None,
        *,
        ttl: int = 60000,
        meta: RequestParamsMeta | None = None,
    ) -> types.CreateTaskResult:
        """Call a tool as a task, returning a CreateTaskResult for polling.

        This is a convenience method for calling tools that support task execution.
        The server will return a task reference instead of the immediate result,
        which can then be polled via `get_task()` and retrieved via `get_task_result()`.

        Args:
            name: The tool name
            arguments: Tool arguments
            ttl: Task time-to-live in milliseconds (default: 60000 = 1 minute)
            meta: Optional metadata to include in the request

        Returns:
            CreateTaskResult containing the task reference

        Example:
            # Create task
            result = await session.experimental.call_tool_as_task(
                "long_running_tool", {"input": "data"}
            )
            task_id = result.task.task_id

            # Poll until the task reaches a terminal status; checking only for
            # "completed" would loop forever on a failed or cancelled task
            while True:
                status = await session.experimental.get_task(task_id)
                if status.status in ("completed", "failed", "cancelled"):
                    break
                await asyncio.sleep(0.5)

            # Get result
            if status.status == "completed":
                final = await session.experimental.get_task_result(task_id, CallToolResult)
        """
        return await self._session.send_request(
            types.CallToolRequest(
                params=types.CallToolRequestParams(
                    name=name,
                    arguments=arguments,
                    # Attaching task metadata is what turns this call into a task request.
                    task=types.TaskMetadata(ttl=ttl),
                    _meta=meta,
                ),
            ),
            types.CreateTaskResult,
        )

    async def get_task(self, task_id: str) -> types.GetTaskResult:
        """Get the current status of a task.

        Args:
            task_id: The task identifier

        Returns:
            GetTaskResult containing the task status and metadata
        """
        return await self._session.send_request(
            types.GetTaskRequest(params=types.GetTaskRequestParams(task_id=task_id)),
            types.GetTaskResult,
        )

    async def get_task_result(
        self,
        task_id: str,
        result_type: type[ResultT],
    ) -> ResultT:
        """Get the result of a completed task.

        The result type depends on the original request type:
        - tools/call tasks return CallToolResult
        - Other request types return their corresponding result type

        Args:
            task_id: The task identifier
            result_type: The expected result type (e.g., CallToolResult)

        Returns:
            The task result, validated against result_type
        """
        return await self._session.send_request(
            types.GetTaskPayloadRequest(
                params=types.GetTaskPayloadRequestParams(task_id=task_id),
            ),
            result_type,
        )

    async def list_tasks(
        self,
        cursor: str | None = None,
    ) -> types.ListTasksResult:
        """List all tasks.

        Args:
            cursor: Optional pagination cursor. ``None`` requests the first
                page; any other value is forwarded to the server unchanged.

        Returns:
            ListTasksResult containing tasks and optional next cursor
        """
        # Cursors are opaque tokens, so only ``None`` means "no cursor". A
        # truthiness test would silently drop an empty-string cursor and
        # request the first page again.
        params = types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None
        return await self._session.send_request(
            types.ListTasksRequest(params=params),
            types.ListTasksResult,
        )

    async def cancel_task(self, task_id: str) -> types.CancelTaskResult:
        """Cancel a running task.

        Args:
            task_id: The task identifier

        Returns:
            CancelTaskResult with the updated task state
        """
        return await self._session.send_request(
            types.CancelTaskRequest(
                params=types.CancelTaskRequestParams(task_id=task_id),
            ),
            types.CancelTaskResult,
        )

    async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]:
        """Poll a task until it reaches a terminal status.

        Yields GetTaskResult for each poll, allowing the caller to react to
        status changes (e.g., handle input_required). Exits when task reaches
        a terminal status (completed, failed, cancelled).

        Polling is delegated to `poll_until_terminal`, which is driven by
        `get_task()`. Respects the pollInterval hint from the server.

        Args:
            task_id: The task identifier

        Yields:
            GetTaskResult for each poll

        Example:
            async for status in session.experimental.poll_task(task_id):
                print(f"Status: {status.status}")
                if status.status == "input_required":
                    # Handle elicitation request via tasks/result
                    pass

            # Task is now terminal, get the result
            result = await session.experimental.get_task_result(task_id, CallToolResult)
        """
        async for status in poll_until_terminal(self.get_task, task_id):
            yield status