Skip to content

Commit 6d37b16

Browse files
committed
fix: call correct handlers in task-required call path
1 parent 02e0485 commit 6d37b16

3 files changed

Lines changed: 174 additions & 2 deletions

File tree

src/server/mcp.ts

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,37 @@ export class McpServer {
8282
private _registeredTools: { [name: string]: RegisteredTool } = {};
8383
private _registeredPrompts: { [name: string]: RegisteredPrompt } = {};
8484
private _experimental?: { tasks: ExperimentalMcpServerTasks };
85+
private _taskToolMap: Map<string, string> = new Map();
8586

8687
constructor(serverInfo: Implementation, options?: ServerOptions) {
87-
this.server = new Server(serverInfo, options);
88+
const taskHandlerHooks = {
89+
getTask: async (taskId: string, extra: RequestHandlerExtra<ServerRequest, ServerNotification>) => {
90+
// taskStore is guaranteed to exist here because Protocol only calls hooks when taskStore is configured
91+
const taskStore = extra.taskStore!;
92+
const handler = this._getTaskHandler(taskId);
93+
if (handler) {
94+
return await handler.getTask({ ...extra, taskId, taskStore });
95+
}
96+
return await taskStore.getTask(taskId);
97+
},
98+
getTaskResult: async (taskId: string, extra: RequestHandlerExtra<ServerRequest, ServerNotification>) => {
99+
const taskStore = extra.taskStore!;
100+
const handler = this._getTaskHandler(taskId);
101+
if (handler) {
102+
return await handler.getTaskResult({ ...extra, taskId, taskStore });
103+
}
104+
return await taskStore.getTaskResult(taskId);
105+
}
106+
};
107+
this.server = new Server(serverInfo, { ...options, taskHandlerHooks });
108+
}
109+
110+
private _getTaskHandler(taskId: string): ToolTaskHandler<ZodRawShapeCompat | undefined> | null {
111+
const toolName = this._taskToolMap.get(taskId);
112+
if (!toolName) return null;
113+
const tool = this._registeredTools[toolName];
114+
if (!tool || !('createTask' in (tool.handler as AnyToolHandler<ZodRawShapeCompat>))) return null;
115+
return tool.handler as ToolTaskHandler<ZodRawShapeCompat | undefined>;
88116
}
89117

90118
/**
@@ -215,6 +243,10 @@ export class McpServer {
215243

216244
// Return CreateTaskResult immediately for task requests
217245
if (isTaskRequest) {
246+
const taskResult = result as CreateTaskResult;
247+
if (taskResult.task?.taskId) {
248+
this._taskToolMap.set(taskResult.task.taskId, request.params.name);
249+
}
218250
return result;
219251
}
220252

src/shared/protocol.ts

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,20 @@ export type ProtocolOptions = {
9898
* appropriately (e.g., by failing the task, dropping messages, etc.).
9999
*/
100100
maxTaskQueueSize?: number;
101+
/**
102+
* Optional hooks for customizing task request handling.
103+
* If a hook is provided, it fully owns the behavior (no fallback to TaskStore).
104+
*/
105+
taskHandlerHooks?: {
106+
/**
107+
* Called when tasks/get is received. If provided, must return the task.
108+
*/
109+
getTask?: (taskId: string, extra: RequestHandlerExtra<Request, Notification>) => Promise<GetTaskResult>;
110+
/**
111+
* Called when tasks/payload needs to retrieve the final result. If provided, must return the result.
112+
*/
113+
getTaskResult?: (taskId: string, extra: RequestHandlerExtra<Request, Notification>) => Promise<Result>;
114+
};
101115
};
102116

103117
/**
@@ -383,6 +397,16 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
383397
this._taskMessageQueue = _options?.taskMessageQueue;
384398
if (this._taskStore) {
385399
this.setRequestHandler(GetTaskRequestSchema, async (request, extra) => {
400+
// Use hook if provided, otherwise fall back to TaskStore
401+
if (_options?.taskHandlerHooks?.getTask) {
402+
const hookResult = await _options.taskHandlerHooks.getTask(
403+
request.params.taskId,
404+
extra as unknown as RequestHandlerExtra<Request, Notification>
405+
);
406+
// @ts-expect-error SendResultT cannot contain GetTaskResult
407+
return hookResult as SendResultT;
408+
}
409+
386410
const task = await this._taskStore!.getTask(request.params.taskId, extra.sessionId);
387411
if (!task) {
388412
throw new McpError(ErrorCode.InvalidParams, 'Failed to retrieve task: Task not found');
@@ -462,7 +486,13 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
462486

463487
// If task is terminal, return the result
464488
if (isTerminal(task.status)) {
465-
const result = await this._taskStore!.getTaskResult(taskId, extra.sessionId);
489+
// Use hook if provided, otherwise fall back to TaskStore
490+
const result = this._options?.taskHandlerHooks?.getTaskResult
491+
? await this._options.taskHandlerHooks.getTaskResult(
492+
taskId,
493+
extra as unknown as RequestHandlerExtra<Request, Notification>
494+
)
495+
: await this._taskStore!.getTaskResult(taskId, extra.sessionId);
466496

467497
this._clearTaskQueue(taskId);
468498

test/server/mcp.test.ts

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@ import {
66
CallToolResultSchema,
77
type CallToolResult,
88
CompleteResultSchema,
9+
CreateTaskResultSchema,
910
ElicitRequestSchema,
1011
GetPromptResultSchema,
12+
GetTaskResultSchema,
1113
ListPromptsResultSchema,
1214
ListResourcesResultSchema,
1315
ListResourceTemplatesResultSchema,
@@ -6859,5 +6861,113 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => {
68596861

68606862
taskStore.cleanup();
68616863
});
6864+
6865+
test('should call custom getTask and getTaskResult handlers when client polls task directly', async () => {
6866+
const taskStore = new InMemoryTaskStore();
6867+
6868+
const getTaskSpy = vi.fn();
6869+
const getTaskResultSpy = vi.fn();
6870+
let taskCreatedAt: number;
6871+
6872+
const mcpServer = new McpServer(
6873+
{
6874+
name: 'test server',
6875+
version: '1.0'
6876+
},
6877+
{
6878+
capabilities: {
6879+
tools: {},
6880+
tasks: {
6881+
requests: {
6882+
tools: {
6883+
call: {}
6884+
}
6885+
}
6886+
}
6887+
},
6888+
taskStore
6889+
}
6890+
);
6891+
6892+
const client = new Client(
6893+
{
6894+
name: 'test client',
6895+
version: '1.0'
6896+
},
6897+
{
6898+
capabilities: {
6899+
tasks: {
6900+
requests: {
6901+
tools: {
6902+
call: {}
6903+
}
6904+
}
6905+
}
6906+
}
6907+
}
6908+
);
6909+
6910+
mcpServer.experimental.tasks.registerToolTask(
6911+
'task-tool',
6912+
{
6913+
description: 'A task tool',
6914+
inputSchema: { data: z.string() },
6915+
execution: { taskSupport: 'required' }
6916+
},
6917+
{
6918+
createTask: async (_args, extra) => {
6919+
const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 });
6920+
taskCreatedAt = Date.now();
6921+
return { task };
6922+
},
6923+
getTask: async extra => {
6924+
getTaskSpy(extra.taskId);
6925+
// Complete the task after 50ms - only occurs if getTask is actually called
6926+
if (Date.now() - taskCreatedAt >= 50) {
6927+
await extra.taskStore.storeTaskResult(extra.taskId, 'completed', {
6928+
content: [{ type: 'text' as const, text: 'Done' }]
6929+
});
6930+
}
6931+
return await extra.taskStore.getTask(extra.taskId);
6932+
},
6933+
getTaskResult: async extra => {
6934+
getTaskResultSpy(extra.taskId);
6935+
return (await extra.taskStore.getTaskResult(extra.taskId)) as CallToolResult;
6936+
}
6937+
}
6938+
);
6939+
6940+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
6941+
await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]);
6942+
6943+
// Create task
6944+
const createResult = await client.request(
6945+
{
6946+
method: 'tools/call',
6947+
params: {
6948+
name: 'task-tool',
6949+
arguments: { data: 'test' },
6950+
task: { ttl: 60000 }
6951+
}
6952+
},
6953+
CreateTaskResultSchema
6954+
);
6955+
const taskId = createResult.task.taskId;
6956+
6957+
// Wait for task to be ready to complete
6958+
await new Promise(resolve => setTimeout(resolve, 60));
6959+
6960+
// Client directly calls tasks/get - should invoke custom handler which completes the task
6961+
const getResult = await client.request({ method: 'tasks/get', params: { taskId } }, GetTaskResultSchema);
6962+
expect(getResult.status).toBe('completed');
6963+
expect(getTaskSpy).toHaveBeenCalledWith(taskId);
6964+
6965+
// Client directly calls tasks/result - should invoke custom handler
6966+
const payloadResult = await client.request({ method: 'tasks/result', params: { taskId } }, CallToolResultSchema);
6967+
expect(payloadResult.content).toBeDefined();
6968+
expect(getTaskResultSpy).toHaveBeenCalledWith(taskId);
6969+
6970+
taskStore.cleanup();
6971+
});
68626972
});
68636973
});

0 commit comments

Comments
 (0)