Skip to content

Commit 7503353

Browse files
LucaButBoringfelixweinberger
authored andcommitted
fix: call correct handlers in backwards-compat task poll path
Previously, the code only called the underlying task store, and the tests were not complex enough to validate that the handlers were being called, so they missed this.
1 parent ddadaa6 commit 7503353

File tree

2 files changed

+47
-6
lines changed

2 files changed

+47
-6
lines changed

src/server/mcp.ts

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -374,27 +374,28 @@ export class McpServer {
374374
const handler = tool.handler as ToolTaskHandler<ZodRawShapeCompat | undefined>;
375375
const taskExtra = { ...extra, taskStore: extra.taskStore };
376376

377-
const createTaskResult: CreateTaskResult = args // undefined only if tool.inputSchema is undefined
378-
? await Promise.resolve((handler as ToolTaskHandler<ZodRawShapeCompat>).createTask(args, taskExtra))
379-
: // eslint-disable-next-line @typescript-eslint/no-explicit-any
380-
await Promise.resolve(((handler as ToolTaskHandler<undefined>).createTask as any)(taskExtra));
377+
const wrappedHandler = toolTaskHandlerByArgs(handler, args);
378+
379+
const createTaskResult = await wrappedHandler.createTask(taskExtra);
381380

382381
// Poll until completion
383382
const taskId = createTaskResult.task.taskId;
383+
const taskExtraComplete = { ...extra, taskId, taskStore: extra.taskStore };
384384
let task = createTaskResult.task;
385385
const pollInterval = task.pollInterval ?? 5000;
386386

387387
while (task.status !== 'completed' && task.status !== 'failed' && task.status !== 'cancelled') {
388388
await new Promise(resolve => setTimeout(resolve, pollInterval));
389-
const updatedTask = await extra.taskStore.getTask(taskId);
389+
const getTaskResult = await wrappedHandler.getTask(taskExtraComplete);
390+
const updatedTask = getTaskResult;
390391
if (!updatedTask) {
391392
throw new McpError(ErrorCode.InternalError, `Task ${taskId} not found during polling`);
392393
}
393394
task = updatedTask;
394395
}
395396

396397
// Return the final result
397-
return (await extra.taskStore.getTaskResult(taskId)) as CallToolResult;
398+
return await wrappedHandler.getTaskResult(taskExtraComplete);
398399
}
399400

400401
private _completionHandlerInitialized = false;
@@ -1545,3 +1546,29 @@ const EMPTY_COMPLETION_RESULT: CompleteResult = {
15451546
hasMore: false
15461547
}
15471548
};
1549+
1550+
/**
1551+
* Wraps a tool task handler such that it can be used without checking if it needs to be called in a one-arg manner.
1552+
* @param handler The task handler to wrap.
1553+
* @param args The tool arguments.
1554+
* @returns A wrapped task handler for a tool, which only exposes a no-args interface.
1555+
*/
1556+
function toolTaskHandlerByArgs<Args extends AnySchema | ZodRawShapeCompat | undefined>(
1557+
handler: ToolTaskHandler<Args>,
1558+
args: unknown
1559+
): ToolTaskHandler<undefined> {
1560+
return {
1561+
createTask: extra =>
1562+
args // undefined only if tool.inputSchema is undefined
1563+
? Promise.resolve((handler as ToolTaskHandler<ZodRawShapeCompat>).createTask(args, extra))
1564+
: Promise.resolve((handler as ToolTaskHandler<undefined>).createTask(extra)),
1565+
getTask: extra =>
1566+
args
1567+
? (handler as ToolTaskHandler<ZodRawShapeCompat>).getTask(args, extra)
1568+
: (handler as ToolTaskHandler<undefined>).getTask(extra),
1569+
getTaskResult: extra =>
1570+
args
1571+
? (handler as ToolTaskHandler<ZodRawShapeCompat>).getTaskResult(args, extra)
1572+
: (handler as ToolTaskHandler<undefined>).getTaskResult(extra)
1573+
};
1574+
}

test/server/mcp.test.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6386,6 +6386,11 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => {
63866386
const taskStore = new InMemoryTaskStore();
63876387
const { releaseLatch, waitForLatch } = createLatch();
63886388

6389+
// Spies to verify handler invocations
6390+
const createTaskSpy = vi.fn();
6391+
const getTaskSpy = vi.fn();
6392+
const getTaskResultSpy = vi.fn();
6393+
63896394
const mcpServer = new McpServer(
63906395
{
63916396
name: 'test server',
@@ -6438,6 +6443,7 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => {
64386443
},
64396444
{
64406445
createTask: async ({ value }, extra) => {
6446+
createTaskSpy({ value });
64416447
const task = await extra.taskStore.createTask({ ttl: 60000, pollInterval: 100 });
64426448

64436449
// Capture taskStore for use in setTimeout
@@ -6454,13 +6460,15 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => {
64546460
return { task };
64556461
},
64566462
getTask: async (_args, extra) => {
6463+
getTaskSpy(extra.taskId);
64576464
const task = await extra.taskStore.getTask(extra.taskId);
64586465
if (!task) {
64596466
throw new Error('Task not found');
64606467
}
64616468
return task;
64626469
},
64636470
getTaskResult: async (_value, extra) => {
6471+
getTaskResultSpy(extra.taskId);
64646472
const result = await extra.taskStore.getTaskResult(extra.taskId);
64656473
return result as CallToolResult;
64666474
}
@@ -6485,6 +6493,12 @@ describe.each(zodTestMatrix)('$zodVersionLabel', (entry: ZodMatrixEntry) => {
64856493
expect(result.content).toEqual([{ type: 'text' as const, text: 'Result: 42' }]);
64866494
expect(result).not.toHaveProperty('task');
64876495

6496+
// Verify all three handler methods were called
6497+
expect(createTaskSpy).toHaveBeenCalledOnce();
6498+
expect(createTaskSpy).toHaveBeenCalledWith({ value: 21 });
6499+
expect(getTaskSpy).toHaveBeenCalled(); // Called at least once during polling
6500+
expect(getTaskResultSpy).toHaveBeenCalledOnce();
6501+
64886502
// Wait for async operations to complete
64896503
await waitForLatch();
64906504
taskStore.cleanup();

0 commit comments

Comments
 (0)