Skip to content

Commit 83ca383

Browse files
committed
fix: call correct handlers in backwards-compat task poll path
Previously, the code only called the underlying task store, and the tests were not complex enough to validate that the handlers were being called, so they missed this.
1 parent 384311b commit 83ca383

2 files changed

Lines changed: 47 additions & 6 deletions

File tree

src/server/mcp.ts

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -374,27 +374,28 @@ export class McpServer {
374374
const handler = tool.handler as ToolTaskHandler<ZodRawShapeCompat | undefined>;
375375
const taskExtra = { ...extra, taskStore: extra.taskStore };
376376

377-
const createTaskResult: CreateTaskResult = args // undefined only if tool.inputSchema is undefined
378-
? await Promise.resolve((handler as ToolTaskHandler<ZodRawShapeCompat>).createTask(args, taskExtra))
379-
: // eslint-disable-next-line @typescript-eslint/no-explicit-any
380-
await Promise.resolve(((handler as ToolTaskHandler<undefined>).createTask as any)(taskExtra));
377+
const wrappedHandler = toolTaskHandlerByArgs(handler, args);
378+
379+
const createTaskResult = await wrappedHandler.createTask(taskExtra);
381380

382381
// Poll until completion
383382
const taskId = createTaskResult.task.taskId;
383+
const taskExtraComplete = { ...extra, taskId, taskStore: extra.taskStore };
384384
let task = createTaskResult.task;
385385
const pollInterval = task.pollInterval ?? 5000;
386386

387387
while (task.status !== 'completed' && task.status !== 'failed' && task.status !== 'cancelled') {
388388
await new Promise(resolve => setTimeout(resolve, pollInterval));
389-
const updatedTask = await extra.taskStore.getTask(taskId);
389+
const getTaskResult = await wrappedHandler.getTask(taskExtraComplete);
390+
const updatedTask = getTaskResult;
390391
if (!updatedTask) {
391392
throw new McpError(ErrorCode.InternalError, `Task ${taskId} not found during polling`);
392393
}
393394
task = updatedTask;
394395
}
395396

396397
// Return the final result
397-
return (await extra.taskStore.getTaskResult(taskId)) as CallToolResult;
398+
return await wrappedHandler.getTaskResult(taskExtraComplete);
398399
}
399400

400401
private _completionHandlerInitialized = false;
@@ -1540,3 +1541,29 @@ const EMPTY_COMPLETION_RESULT: CompleteResult = {
15401541
hasMore: false
15411542
}
15421543
};
1544+
1545+
/**
1546+
* Wraps a tool task handler such that it can be used without checking if it needs to be called in a one-arg manner.
1547+
* @param handler The task handler to wrap.
1548+
* @param args The tool arguments.
1549+
* @returns A wrapped task handler for a tool, which only exposes a no-args interface.
1550+
*/
1551+
function toolTaskHandlerByArgs<Args extends AnySchema | ZodRawShapeCompat | undefined>(
1552+
handler: ToolTaskHandler<Args>,
1553+
args: unknown
1554+
): ToolTaskHandler<undefined> {
1555+
return {
1556+
createTask: extra =>
1557+
args // undefined only if tool.inputSchema is undefined
1558+
? Promise.resolve((handler as ToolTaskHandler<ZodRawShapeCompat>).createTask(args, extra))
1559+
: Promise.resolve((handler as ToolTaskHandler<undefined>).createTask(extra)),
1560+
getTask: extra =>
1561+
args
1562+
? (handler as ToolTaskHandler<ZodRawShapeCompat>).getTask(args, extra)
1563+
: (handler as ToolTaskHandler<undefined>).getTask(extra),
1564+
getTaskResult: extra =>
1565+
args
1566+
? (handler as ToolTaskHandler<ZodRawShapeCompat>).getTaskResult(args, extra)
1567+
: (handler as ToolTaskHandler<undefined>).getTaskResult(extra)
1568+
};
1569+
}

test/server/mcp.test.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6355,6 +6355,11 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => {
63556355
const taskStore = new InMemoryTaskStore();
63566356
const { releaseLatch, waitForLatch } = createLatch();
63576357

6358+
// Spies to verify handler invocations
6359+
const createTaskSpy = vi.fn();
6360+
const getTaskSpy = vi.fn();
6361+
const getTaskResultSpy = vi.fn();
6362+
63586363
const mcpServer = new McpServer(
63596364
{
63606365
name: 'test server',
@@ -6407,6 +6412,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => {
64076412
},
64086413
{
64096414
createTask: async ({ value }, extra) => {
6415+
createTaskSpy({ value });
64106416
const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 });
64116417

64126418
// Capture taskStore for use in setTimeout
@@ -6423,13 +6429,15 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => {
64236429
return { task };
64246430
},
64256431
getTask: async (_args, extra) => {
6432+
getTaskSpy(extra.taskId);
64266433
const task = await extra.taskStore.getTask(extra.taskId);
64276434
if (!task) {
64286435
throw new Error('Task not found');
64296436
}
64306437
return task;
64316438
},
64326439
getTaskResult: async (_value, extra) => {
6440+
getTaskResultSpy(extra.taskId);
64336441
const result = await extra.taskStore.getTaskResult(extra.taskId);
64346442
return result as CallToolResult;
64356443
}
@@ -6454,6 +6462,12 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => {
64546462
expect(result.content).toEqual([{ type: 'text' as const, text: 'Result: 42' }]);
64556463
expect(result).not.toHaveProperty('task');
64566464

6465+
// Verify all three handler methods were called
6466+
expect(createTaskSpy).toHaveBeenCalledOnce();
6467+
expect(createTaskSpy).toHaveBeenCalledWith({ value: 21 });
6468+
expect(getTaskSpy).toHaveBeenCalled(); // Called at least once during polling
6469+
expect(getTaskResultSpy).toHaveBeenCalledOnce();
6470+
64576471
// Wait for async operations to complete
64586472
await waitForLatch();
64596473
taskStore.cleanup();

0 commit comments

Comments
 (0)