-
Notifications
You must be signed in to change notification settings - Fork 11k
Expand file tree
/
Copy pathtrigger-sampling-request-async.ts
More file actions
236 lines (217 loc) · 7.31 KB
/
Copy pathtrigger-sampling-request-async.ts
File metadata and controls
236 lines (217 loc) · 7.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";
import {
CallToolResult,
CreateMessageRequest,
} from "@modelcontextprotocol/sdk/types.js";
import { z } from "zod";
// Tool input schema
/**
 * Input schema for the 'trigger-sampling-request-async' tool.
 *
 * - `prompt`: text sent to the client's LLM as the user message.
 * - `maxTokens`: generation cap, defaulting to 100 when omitted. It is
 *   forwarded verbatim as the `maxTokens` of `sampling/createMessage`, which
 *   the MCP schema types as an integer. Non-integer or non-positive values are
 *   therefore rejected here rather than producing an invalid request to the
 *   client.
 */
const TriggerSamplingRequestAsyncSchema = z.object({
  prompt: z.string().describe("The prompt to send to the LLM"),
  maxTokens: z
    .number()
    .int()
    .positive()
    .default(100)
    .describe("Maximum number of tokens to generate"),
});
// Tool configuration
/** Registered tool name. It is also interpolated into the sampling prompt text. */
const name = "trigger-sampling-request-async";

/**
 * Metadata handed to `server.registerTool`: display title, description,
 * input schema and behavioral annotations. The annotations declare the tool
 * not read-only, not destructive, not idempotent, and open-world.
 */
const config = {
  title: "Trigger Async Sampling Request Tool",
  description: [
    "Trigger an async sampling request that the CLIENT executes as a background task.",
    "Demonstrates bidirectional MCP tasks where the server sends a request and the client",
    "executes it asynchronously, allowing the server to poll for progress and results.",
  ].join(" "),
  inputSchema: TriggerSamplingRequestAsyncSchema,
  annotations: {
    readOnlyHint: false,
    destructiveHint: false,
    idempotentHint: false,
    openWorldHint: true,
  },
};
// Poll interval in milliseconds
/** Poll interval, in milliseconds, for checking client-side task status. */
const POLL_INTERVAL = 1_000;
/** Number of status polls attempted before the tool reports a timeout. */
const MAX_POLL_ATTEMPTS = 60;
/** Task statuses after which a client-side task will not change again. */
const TERMINAL_TASK_STATUSES: ReadonlySet<string> = new Set([
  "completed",
  "failed",
  "cancelled",
]);

/**
 * Result schema for the outgoing sampling/createMessage request.
 * The client may answer with either variant:
 * - CreateTaskResult: it accepted the request as a task.
 * - CreateMessageResult: it ran the request synchronously.
 */
const SamplingOrTaskResultSchema = z.union([
  // CreateTaskResult - client created a task
  z.object({
    task: z.object({
      taskId: z.string(),
      status: z.string(),
      pollInterval: z.number().optional(),
      statusMessage: z.string().optional(),
    }),
  }),
  // CreateMessageResult - synchronous execution
  z.object({
    role: z.string(),
    content: z.any(),
    model: z.string(),
    stopReason: z.string().optional(),
  }),
]);

/** Result schema for tasks/get; fields not listed here are kept (passthrough). */
const TaskStatusResultSchema = z
  .object({
    status: z.string(),
    statusMessage: z.string().optional(),
    pollInterval: z.number().optional(),
  })
  .passthrough();

/** Resolves after `ms` milliseconds. */
const sleep = (ms: number): Promise<void> =>
  new Promise<void>((resolve) => setTimeout(resolve, ms));

/**
 * Returns the client's suggested poll interval when it is a finite,
 * non-negative number; otherwise falls back to POLL_INTERVAL.
 */
const resolvePollInterval = (suggested: number | undefined): number =>
  suggested !== undefined && Number.isFinite(suggested) && suggested >= 0
    ? suggested
    : POLL_INTERVAL;

/** Wraps a single text string as a CallToolResult. */
const textResult = (text: string): CallToolResult => ({
  content: [{ type: "text", text }],
});

/**
 * Registers the 'trigger-sampling-request-async' tool.
 *
 * This tool demonstrates bidirectional MCP tasks:
 * - Server sends sampling request to client with task metadata
 * - Client creates a task and returns CreateTaskResult
 * - Server polls client's tasks/get endpoint for status, waiting the client's
 *   suggested `pollInterval` (or POLL_INTERVAL) between polls, for at most
 *   MAX_POLL_ATTEMPTS polls
 * - Server fetches final result from client's tasks/result endpoint
 *
 * If the client answers the sampling request synchronously, that result is
 * returned directly. Registration is skipped entirely unless the client's
 * capabilities include both `sampling` and
 * `tasks.requests.sampling.createMessage`.
 *
 * @param server - The McpServer instance where the tool will be registered.
 */
export const registerTriggerSamplingRequestAsyncTool = (server: McpServer) => {
  const clientCapabilities = server.server.getClientCapabilities() || {};

  // Client must support sampling AND tasks.requests.sampling
  const clientSupportsSampling = clientCapabilities.sampling !== undefined;
  const clientTasksCapability = clientCapabilities.tasks as
    | {
        requests?: { sampling?: { createMessage?: object } };
      }
    | undefined;
  const clientSupportsAsyncSampling =
    clientTasksCapability?.requests?.sampling?.createMessage !== undefined;

  if (!clientSupportsSampling || !clientSupportsAsyncSampling) {
    return;
  }

  server.registerTool(
    name,
    config,
    async (args, extra): Promise<CallToolResult> => {
      const { prompt, maxTokens } =
        TriggerSamplingRequestAsyncSchema.parse(args);

      // The params.task field signals to the client that this request
      // should be executed as a task.
      const request: CreateMessageRequest & {
        params: { task?: { ttl: number } };
      } = {
        method: "sampling/createMessage",
        params: {
          task: {
            ttl: 300000, // 5 minutes
          },
          messages: [
            {
              role: "user",
              content: {
                type: "text",
                text: `Resource ${name} context: ${prompt}`,
              },
            },
          ],
          systemPrompt: "You are a helpful test server.",
          maxTokens,
          temperature: 0.7,
        },
      };

      const samplingResponse = await extra.sendRequest(
        request,
        SamplingOrTaskResultSchema
      );

      // No task object: the client executed synchronously.
      if (!("task" in samplingResponse)) {
        return textResult(
          `[SYNC] Client executed synchronously:\n${JSON.stringify(
            samplingResponse,
            null,
            2
          )}`
        );
      }

      const { taskId } = samplingResponse.task;
      const statusMessages: string[] = [`Task created: ${taskId}`];

      let attempts = 0;
      let taskStatus = samplingResponse.task.status;
      // Seed from CreateTaskResult so an immediately-terminal task keeps its message.
      let taskStatusMessage = samplingResponse.task.statusMessage;
      let pollInterval = resolvePollInterval(
        samplingResponse.task.pollInterval
      );

      while (
        !TERMINAL_TASK_STATUSES.has(taskStatus) &&
        attempts < MAX_POLL_ATTEMPTS
      ) {
        await sleep(pollInterval);
        attempts++;

        const pollResult = await extra.sendRequest(
          {
            method: "tasks/get",
            params: { taskId },
          },
          TaskStatusResultSchema
        );
        taskStatus = pollResult.status;
        taskStatusMessage = pollResult.statusMessage;
        // Keep the previous interval if the client stops suggesting one.
        if (pollResult.pollInterval !== undefined) {
          pollInterval = resolvePollInterval(pollResult.pollInterval);
        }
        statusMessages.push(
          `Poll ${attempts}: ${taskStatus}${
            taskStatusMessage ? ` - ${taskStatusMessage}` : ""
          }`
        );
      }

      // Timed out only if the task is still non-terminal. A task that finished
      // on the last allowed poll is not a timeout.
      if (!TERMINAL_TASK_STATUSES.has(taskStatus)) {
        return textResult(
          `[TIMEOUT] Task timed out after ${MAX_POLL_ATTEMPTS} poll attempts\n\nProgress:\n${statusMessages.join(
            "\n"
          )}`
        );
      }

      if (taskStatus === "failed" || taskStatus === "cancelled") {
        return textResult(
          `[${taskStatus.toUpperCase()}] ${
            taskStatusMessage || "No message"
          }\n\nProgress:\n${statusMessages.join("\n")}`
        );
      }

      const result = await extra.sendRequest(
        {
          method: "tasks/result",
          params: { taskId },
        },
        z.any()
      );

      return textResult(
        `[COMPLETED] Async sampling completed!\n\n**Progress:**\n${statusMessages.join(
          "\n"
        )}\n\n**Result:**\n${JSON.stringify(result, null, 2)}`
      );
    }
  );
};