Skip to content

Commit 8ae7217

Browse files
LucaButBoringfelixweinberger
authored andcommitted
fix: call correct handlers in task-required call path
1 parent fbe5df4 commit 8ae7217

File tree

3 files changed

+174
-2
lines changed

3 files changed

+174
-2
lines changed

src/server/mcp.ts

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,37 @@ export class McpServer {
8282
private _registeredTools: { [name: string]: RegisteredTool } = {};
8383
private _registeredPrompts: { [name: string]: RegisteredPrompt } = {};
8484
private _experimental?: { tasks: ExperimentalMcpServerTasks };
85+
private _taskToolMap: Map<string, string> = new Map();
8586

8687
constructor(serverInfo: Implementation, options?: ServerOptions) {
87-
this.server = new Server(serverInfo, options);
88+
const taskHandlerHooks = {
89+
getTask: async (taskId: string, extra: RequestHandlerExtra<ServerRequest, ServerNotification>) => {
90+
// taskStore is guaranteed to exist here because Protocol only calls hooks when taskStore is configured
91+
const taskStore = extra.taskStore!;
92+
const handler = this._getTaskHandler(taskId);
93+
if (handler) {
94+
return await handler.getTask({ ...extra, taskId, taskStore });
95+
}
96+
return await taskStore.getTask(taskId);
97+
},
98+
getTaskResult: async (taskId: string, extra: RequestHandlerExtra<ServerRequest, ServerNotification>) => {
99+
const taskStore = extra.taskStore!;
100+
const handler = this._getTaskHandler(taskId);
101+
if (handler) {
102+
return await handler.getTaskResult({ ...extra, taskId, taskStore });
103+
}
104+
return await taskStore.getTaskResult(taskId);
105+
}
106+
};
107+
this.server = new Server(serverInfo, { ...options, taskHandlerHooks });
108+
}
109+
110+
private _getTaskHandler(taskId: string): ToolTaskHandler<ZodRawShapeCompat | undefined> | null {
111+
const toolName = this._taskToolMap.get(taskId);
112+
if (!toolName) return null;
113+
const tool = this._registeredTools[toolName];
114+
if (!tool || !('createTask' in (tool.handler as AnyToolHandler<ZodRawShapeCompat>))) return null;
115+
return tool.handler as ToolTaskHandler<ZodRawShapeCompat | undefined>;
88116
}
89117

90118
/**
@@ -215,6 +243,10 @@ export class McpServer {
215243

216244
// Return CreateTaskResult immediately for task requests
217245
if (isTaskRequest) {
246+
const taskResult = result as CreateTaskResult;
247+
if (taskResult.task?.taskId) {
248+
this._taskToolMap.set(taskResult.task.taskId, request.params.name);
249+
}
218250
return result;
219251
}
220252

src/shared/protocol.ts

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,20 @@ export type ProtocolOptions = {
9898
* appropriately (e.g., by failing the task, dropping messages, etc.).
9999
*/
100100
maxTaskQueueSize?: number;
101+
/**
102+
* Optional hooks for customizing task request handling.
103+
* If a hook is provided, it fully owns the behavior (no fallback to TaskStore).
104+
*/
105+
taskHandlerHooks?: {
106+
/**
107+
* Called when tasks/get is received. If provided, must return the task.
108+
*/
109+
getTask?: (taskId: string, extra: RequestHandlerExtra<Request, Notification>) => Promise<GetTaskResult>;
110+
/**
111+
* Called when tasks/payload needs to retrieve the final result. If provided, must return the result.
112+
*/
113+
getTaskResult?: (taskId: string, extra: RequestHandlerExtra<Request, Notification>) => Promise<Result>;
114+
};
101115
};
102116

103117
/**
@@ -383,6 +397,16 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
383397
this._taskMessageQueue = _options?.taskMessageQueue;
384398
if (this._taskStore) {
385399
this.setRequestHandler(GetTaskRequestSchema, async (request, extra) => {
400+
// Use hook if provided, otherwise fall back to TaskStore
401+
if (_options?.taskHandlerHooks?.getTask) {
402+
const hookResult = await _options.taskHandlerHooks.getTask(
403+
request.params.taskId,
404+
extra as unknown as RequestHandlerExtra<Request, Notification>
405+
);
406+
// @ts-expect-error SendResultT cannot contain GetTaskResult
407+
return hookResult as SendResultT;
408+
}
409+
386410
const task = await this._taskStore!.getTask(request.params.taskId, extra.sessionId);
387411
if (!task) {
388412
throw new McpError(ErrorCode.InvalidParams, 'Failed to retrieve task: Task not found');
@@ -462,7 +486,13 @@ export abstract class Protocol<SendRequestT extends Request, SendNotificationT e
462486

463487
// If task is terminal, return the result
464488
if (isTerminal(task.status)) {
465-
const result = await this._taskStore!.getTaskResult(taskId, extra.sessionId);
489+
// Use hook if provided, otherwise fall back to TaskStore
490+
const result = this._options?.taskHandlerHooks?.getTaskResult
491+
? await this._options.taskHandlerHooks.getTaskResult(
492+
taskId,
493+
extra as unknown as RequestHandlerExtra<Request, Notification>
494+
)
495+
: await this._taskStore!.getTaskResult(taskId, extra.sessionId);
466496

467497
this._clearTaskQueue(taskId);
468498

test/server/mcp.test.ts

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@ import {
66
CallToolResultSchema,
77
type CallToolResult,
88
CompleteResultSchema,
9+
CreateTaskResultSchema,
910
ElicitRequestSchema,
1011
GetPromptResultSchema,
12+
GetTaskResultSchema,
1113
ListPromptsResultSchema,
1214
ListResourcesResultSchema,
1315
ListResourceTemplatesResultSchema,
@@ -6890,5 +6892,113 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => {
68906892

68916893
taskStore.cleanup();
68926894
});
6895+
6896+
test('should call custom getTask and getTaskResult handlers when client polls task directly', async () => {
6897+
const taskStore = new InMemoryTaskStore();
6898+
6899+
const getTaskSpy = vi.fn();
6900+
const getTaskResultSpy = vi.fn();
6901+
let taskCreatedAt: number;
6902+
6903+
const mcpServer = new McpServer(
6904+
{
6905+
name: 'test server',
6906+
version: '1.0'
6907+
},
6908+
{
6909+
capabilities: {
6910+
tools: {},
6911+
tasks: {
6912+
requests: {
6913+
tools: {
6914+
call: {}
6915+
}
6916+
}
6917+
}
6918+
},
6919+
taskStore
6920+
}
6921+
);
6922+
6923+
const client = new Client(
6924+
{
6925+
name: 'test client',
6926+
version: '1.0'
6927+
},
6928+
{
6929+
capabilities: {
6930+
tasks: {
6931+
requests: {
6932+
tools: {
6933+
call: {}
6934+
}
6935+
}
6936+
}
6937+
}
6938+
}
6939+
);
6940+
6941+
mcpServer.experimental.tasks.registerToolTask(
6942+
'task-tool',
6943+
{
6944+
description: 'A task tool',
6945+
inputSchema: { data: z.string() },
6946+
execution: { taskSupport: 'required' }
6947+
},
6948+
{
6949+
createTask: async (_args, extra) => {
6950+
const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 });
6951+
taskCreatedAt = Date.now();
6952+
return { task };
6953+
},
6954+
getTask: async extra => {
6955+
getTaskSpy(extra.taskId);
6956+
// Complete the task after 50ms - only occurs if getTask is actually called
6957+
if (Date.now() - taskCreatedAt >= 50) {
6958+
await extra.taskStore.storeTaskResult(extra.taskId, 'completed', {
6959+
content: [{ type: 'text' as const, text: 'Done' }]
6960+
});
6961+
}
6962+
return await extra.taskStore.getTask(extra.taskId);
6963+
},
6964+
getTaskResult: async extra => {
6965+
getTaskResultSpy(extra.taskId);
6966+
return (await extra.taskStore.getTaskResult(extra.taskId)) as CallToolResult;
6967+
}
6968+
}
6969+
);
6970+
6971+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
6972+
await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]);
6973+
6974+
// Create task
6975+
const createResult = await client.request(
6976+
{
6977+
method: 'tools/call',
6978+
params: {
6979+
name: 'task-tool',
6980+
arguments: { data: 'test' },
6981+
task: { ttl: 60000 }
6982+
}
6983+
},
6984+
CreateTaskResultSchema
6985+
);
6986+
const taskId = createResult.task.taskId;
6987+
6988+
// Wait for task to be ready to complete
6989+
await new Promise(resolve => setTimeout(resolve, 60));
6990+
6991+
// Client directly calls tasks/get - should invoke custom handler which completes the task
6992+
const getResult = await client.request({ method: 'tasks/get', params: { taskId } }, GetTaskResultSchema);
6993+
expect(getResult.status).toBe('completed');
6994+
expect(getTaskSpy).toHaveBeenCalledWith(taskId);
6995+
6996+
// Client directly calls tasks/result - should invoke custom handler
6997+
const payloadResult = await client.request({ method: 'tasks/result', params: { taskId } }, CallToolResultSchema);
6998+
expect(payloadResult.content).toBeDefined();
6999+
expect(getTaskResultSpy).toHaveBeenCalledWith(taskId);
7000+
7001+
taskStore.cleanup();
7002+
});
68937003
});
68947004
});

0 commit comments

Comments
 (0)