-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Expand file tree
/
Copy pathsession_features.py
More file actions
209 lines (168 loc) · 7.43 KB
/
session_features.py
File metadata and controls
209 lines (168 loc) · 7.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
"""Experimental server session features for server→client task operations.
This module provides the server-side equivalent of ExperimentalClientFeatures,
allowing the server to send task-augmented requests to the client and poll for results.
WARNING: These APIs are experimental and may change without notice.
"""
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Any, TypeVar

from mcp import types
from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages
from mcp.shared.experimental.tasks.capabilities import (
    require_task_augmented_elicitation,
    require_task_augmented_sampling,
)
from mcp.shared.experimental.tasks.polling import poll_until_terminal

if TYPE_CHECKING:
    from mcp.server.session import ServerSession
# Type variable used by get_task_result(): it ties the returned object's type to
# the `result_type` class the caller passes in. Bound to types.Result so only
# result models are accepted.
ResultT = TypeVar("ResultT", bound=types.Result)
class ExperimentalServerSessionFeatures:
    """Experimental server session features for server→client task operations.

    This provides the server-side equivalent of ExperimentalClientFeatures,
    allowing the server to send task-augmented requests to the client and
    poll for results.

    WARNING: These APIs are experimental and may change without notice.

    Access via session.experimental:
        result = await session.experimental.elicit_as_task(...)
    """

    def __init__(self, session: "ServerSession") -> None:
        # The owning server session; every request below is sent through it.
        self._session = session

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _client_capabilities(self) -> "types.ClientCapabilities | None":
        """Return the capabilities the client declared, or None if unknown.

        None is returned when the session has no client_params yet. It is
        passed straight to the require_* / validate_* helpers, which decide
        how to treat a missing capability set.
        """
        client_params = self._session.client_params
        return client_params.capabilities if client_params else None

    async def _wait_for_task_result(
        self,
        task_id: str,
        result_type: type[ResultT],
    ) -> ResultT:
        """Poll a client task until it is terminal, then fetch its final result.

        Args:
            task_id: The client-side task identifier
            result_type: The expected result type for tasks/result

        Returns:
            The task result, validated against result_type

        NOTE: the terminal status reported by polling is not inspected.
        tasks/result is requested regardless of which terminal status the task
        reached. Whatever the client returns (or raises) for that request
        propagates to the caller.
        """
        # Only the side effect of waiting matters here; intermediate statuses
        # are discarded.
        async for _ in self.poll_task(task_id):
            pass
        return await self.get_task_result(task_id, result_type)

    # ------------------------------------------------------------------
    # Low-level task operations
    # ------------------------------------------------------------------

    async def get_task(self, task_id: str) -> types.GetTaskResult:
        """Send tasks/get to the client to get task status.

        Args:
            task_id: The task identifier

        Returns:
            GetTaskResult containing the task status
        """
        return await self._session.send_request(
            types.GetTaskRequest(params=types.GetTaskRequestParams(task_id=task_id)),
            types.GetTaskResult,
        )

    async def get_task_result(
        self,
        task_id: str,
        result_type: type[ResultT],
    ) -> ResultT:
        """Send tasks/result to the client to retrieve the final result.

        Args:
            task_id: The task identifier
            result_type: The expected result type

        Returns:
            The task result, validated against result_type
        """
        return await self._session.send_request(
            types.GetTaskPayloadRequest(params=types.GetTaskPayloadRequestParams(task_id=task_id)),
            result_type,
        )

    async def poll_task(self, task_id: str) -> AsyncIterator[types.GetTaskResult]:
        """Poll a client task until it reaches terminal status.

        Yields GetTaskResult for each poll, allowing the caller to react to
        status changes. Exits when the task reaches a terminal status.

        Polling cadence and the terminal-status check are delegated to
        poll_until_terminal, which repeatedly calls get_task. That helper is
        expected to respect the pollInterval hint from the client.

        Args:
            task_id: The task identifier

        Yields:
            GetTaskResult for each poll
        """
        async for status in poll_until_terminal(self.get_task, task_id):
            yield status

    # ------------------------------------------------------------------
    # Task-augmented server→client requests
    # ------------------------------------------------------------------

    async def elicit_as_task(
        self,
        message: str,
        requested_schema: types.ElicitRequestedSchema,
        *,
        ttl: int = 60000,
    ) -> types.ElicitResult:
        """Send a task-augmented elicitation to the client and poll until complete.

        The client will create a local task, process the elicitation
        asynchronously, and return the result when ready. This method handles
        the full flow:
            1. Send elicitation with task field
            2. Receive CreateTaskResult from client
            3. Poll client's task until terminal
            4. Retrieve and return the final ElicitResult

        Args:
            message: The message to present to the user
            requested_schema: Schema defining the expected response
            ttl: Task time-to-live in milliseconds

        Returns:
            The client's elicitation response

        Raises:
            MCPError: If client doesn't support task-augmented elicitation
                (raised by require_task_augmented_elicitation)
        """
        require_task_augmented_elicitation(self._client_capabilities())

        # The `task` field turns this into a task-augmented request. The client
        # replies immediately with a CreateTaskResult instead of an ElicitResult.
        create_result = await self._session.send_request(
            types.ElicitRequest(
                params=types.ElicitRequestFormParams(
                    message=message,
                    requested_schema=requested_schema,
                    task=types.TaskMetadata(ttl=ttl),
                )
            ),
            types.CreateTaskResult,
        )

        return await self._wait_for_task_result(create_result.task.task_id, types.ElicitResult)

    async def create_message_as_task(
        self,
        messages: list[types.SamplingMessage],
        *,
        max_tokens: int,
        ttl: int = 60000,
        system_prompt: str | None = None,
        include_context: types.IncludeContext | None = None,
        temperature: float | None = None,
        stop_sequences: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
        model_preferences: types.ModelPreferences | None = None,
        tools: list[types.Tool] | None = None,
        tool_choice: types.ToolChoice | None = None,
    ) -> types.CreateMessageResult:
        """Send a task-augmented sampling request and poll until complete.

        The client will create a local task, process the sampling request
        asynchronously, and return the result when ready.

        Args:
            messages: The conversation messages for sampling
            max_tokens: Maximum tokens in the response
            ttl: Task time-to-live in milliseconds
            system_prompt: Optional system prompt
            include_context: Context inclusion strategy
            temperature: Sampling temperature
            stop_sequences: Stop sequences
            metadata: Additional metadata
            model_preferences: Model selection preferences
            tools: Optional list of tools the LLM can use during sampling
            tool_choice: Optional control over tool usage behavior

        Returns:
            The sampling result from the client

        Raises:
            MCPError: If client doesn't support task-augmented sampling or tools
            ValueError: If tool_use or tool_result message structure is invalid
        """
        client_caps = self._client_capabilities()
        # All validation runs before anything is sent, so a rejected request
        # never creates a task on the client.
        require_task_augmented_sampling(client_caps)
        validate_sampling_tools(client_caps, tools, tool_choice)
        validate_tool_use_result_messages(messages)

        create_result = await self._session.send_request(
            types.CreateMessageRequest(
                params=types.CreateMessageRequestParams(
                    messages=messages,
                    max_tokens=max_tokens,
                    system_prompt=system_prompt,
                    include_context=include_context,
                    temperature=temperature,
                    stop_sequences=stop_sequences,
                    metadata=metadata,
                    model_preferences=model_preferences,
                    tools=tools,
                    tool_choice=tool_choice,
                    task=types.TaskMetadata(ttl=ttl),
                )
            ),
            types.CreateTaskResult,
        )

        return await self._wait_for_task_result(create_result.task.task_id, types.CreateMessageResult)