Skip to content

Commit be7b9eb

Browse files
fix: schema-arg setRequestHandler bypassed per-method wrapping in Client/Server
Normalize schema-arg to method string + parse-wrapped handler, then fall through to the per-method dispatch (tools/call task validation, elicitation/create capability checks). Previously the schema form short-circuited via _registerCompatRequestHandler, so e.g. setRequestHandler(CallToolRequestSchema, h) and setRequestHandler('tools/call', h) had different runtime behavior. Also reword @deprecated message: schema form is not deprecated for non-spec methods (the method-string overload is constrained to RequestMethod), so the advice now reads 'For spec methods, pass the method string instead.'
1 parent 1c01219 commit be7b9eb

File tree

4 files changed

+77
-18
lines changed

4 files changed

+77
-18
lines changed

packages/client/src/client/client.ts

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ import {
5151
ElicitRequestSchema,
5252
ElicitResultSchema,
5353
EmptyResultSchema,
54+
extractMethodLiteral,
5455
extractTaskManagerOptions,
5556
GetPromptResultSchema,
5657
InitializeResultSchema,
@@ -342,19 +343,23 @@ export class Client extends Protocol<ClientContext> {
342343
method: M,
343344
handler: (request: RequestTypeMap[M], ctx: ClientContext) => ResultTypeMap[M] | Promise<ResultTypeMap[M]>
344345
): void;
345-
/** @deprecated Pass the method string instead. */
346+
/** @deprecated For spec methods, pass the method string instead. */
346347
public override setRequestHandler<T extends ZodLikeRequestSchema>(
347348
requestSchema: T,
348349
handler: (request: ReturnType<T['parse']>, ctx: ClientContext) => Result | Promise<Result>
349350
): void;
350-
public override setRequestHandler(method: string | ZodLikeRequestSchema, schemaHandler: unknown): void {
351-
if (isZodLikeSchema(method)) {
352-
return this._registerCompatRequestHandler(
353-
method,
354-
schemaHandler as (request: unknown, ctx: ClientContext) => Result | Promise<Result>
355-
);
351+
public override setRequestHandler(methodOrSchema: string | ZodLikeRequestSchema, schemaHandler: unknown): void {
352+
let method: string;
353+
let handler: (request: Request, ctx: ClientContext) => ClientResult | Promise<ClientResult>;
354+
if (isZodLikeSchema(methodOrSchema)) {
355+
const schema = methodOrSchema;
356+
const userHandler = schemaHandler as (request: unknown, ctx: ClientContext) => Result | Promise<Result>;
357+
method = extractMethodLiteral(schema);
358+
handler = (req, ctx) => userHandler(schema.parse(req), ctx);
359+
} else {
360+
method = methodOrSchema;
361+
handler = schemaHandler as (request: Request, ctx: ClientContext) => ClientResult | Promise<ClientResult>;
356362
}
357-
const handler = schemaHandler as (request: Request, ctx: ClientContext) => ClientResult | Promise<ClientResult>;
358363
if (method === 'elicitation/create') {
359364
const wrappedHandler = async (request: Request, ctx: ClientContext): Promise<ClientResult> => {
360365
const validatedRequest = parseSchema(ElicitRequestSchema, request);

packages/core/src/shared/protocol.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,7 +1030,7 @@ export abstract class Protocol<ContextT extends BaseContext> {
10301030
method: M,
10311031
handler: (request: RequestTypeMap[M], ctx: ContextT) => Result | Promise<Result>
10321032
): void;
1033-
/** @deprecated Pass the method string instead. */
1033+
/** @deprecated For spec methods, pass the method string instead. */
10341034
setRequestHandler<T extends ZodLikeRequestSchema>(
10351035
requestSchema: T,
10361036
handler: (request: ReturnType<T['parse']>, ctx: ContextT) => Result | Promise<Result>
@@ -1096,7 +1096,7 @@ export abstract class Protocol<ContextT extends BaseContext> {
10961096
method: M,
10971097
handler: (notification: NotificationTypeMap[M]) => void | Promise<void>
10981098
): void;
1099-
/** @deprecated Pass the method string instead. */
1099+
/** @deprecated For spec methods, pass the method string instead. */
11001100
setNotificationHandler<T extends ZodLikeRequestSchema>(
11011101
notificationSchema: T,
11021102
handler: (notification: ReturnType<T['parse']>) => void | Promise<void>

packages/server/src/server/server.ts

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ import {
4646
CreateTaskResultSchema,
4747
ElicitResultSchema,
4848
EmptyResultSchema,
49+
extractMethodLiteral,
4950
extractTaskManagerOptions,
5051
isZodLikeSchema,
5152
LATEST_PROTOCOL_VERSION,
@@ -230,19 +231,23 @@ export class Server extends Protocol<ServerContext> {
230231
method: M,
231232
handler: (request: RequestTypeMap[M], ctx: ServerContext) => ResultTypeMap[M] | Promise<ResultTypeMap[M]>
232233
): void;
233-
/** @deprecated Pass the method string instead. */
234+
/** @deprecated For spec methods, pass the method string instead. */
234235
public override setRequestHandler<T extends ZodLikeRequestSchema>(
235236
requestSchema: T,
236237
handler: (request: ReturnType<T['parse']>, ctx: ServerContext) => Result | Promise<Result>
237238
): void;
238-
public override setRequestHandler(method: string | ZodLikeRequestSchema, schemaHandler: unknown): void {
239-
if (isZodLikeSchema(method)) {
240-
return this._registerCompatRequestHandler(
241-
method,
242-
schemaHandler as (request: unknown, ctx: ServerContext) => Result | Promise<Result>
243-
);
239+
public override setRequestHandler(methodOrSchema: string | ZodLikeRequestSchema, schemaHandler: unknown): void {
240+
let method: string;
241+
let handler: (request: Request, ctx: ServerContext) => ServerResult | Promise<ServerResult>;
242+
if (isZodLikeSchema(methodOrSchema)) {
243+
const schema = methodOrSchema;
244+
const userHandler = schemaHandler as (request: unknown, ctx: ServerContext) => Result | Promise<Result>;
245+
method = extractMethodLiteral(schema);
246+
handler = (req, ctx) => userHandler(schema.parse(req), ctx);
247+
} else {
248+
method = methodOrSchema;
249+
handler = schemaHandler as (request: Request, ctx: ServerContext) => ServerResult | Promise<ServerResult>;
244250
}
245-
const handler = schemaHandler as (request: Request, ctx: ServerContext) => ServerResult | Promise<ServerResult>;
246251
if (method === 'tools/call') {
247252
const wrappedHandler = async (request: Request, ctx: ServerContext): Promise<ServerResult> => {
248253
const validatedRequest = parseSchema(CallToolRequestSchema, request);
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import { describe, expect, it } from 'vitest';
2+
3+
import { CallToolRequestSchema, InMemoryTransport } from '@modelcontextprotocol/core';
4+
5+
import { Server } from '../../src/server/server.js';
6+
7+
/**
8+
* Regression test: setRequestHandler(CallToolRequestSchema, h) and
9+
* setRequestHandler('tools/call', h) must apply the same per-method
10+
* wrapping (task-result validation when params.task is set).
11+
*/
12+
describe('Server.setRequestHandler — Zod-schema form parity', () => {
13+
async function setup(register: (s: Server) => void) {
14+
const server = new Server({ name: 't', version: '1.0' }, { capabilities: { tools: {} } });
15+
register(server);
16+
const [ct, st] = InMemoryTransport.createLinkedPair();
17+
await server.connect(st);
18+
await ct.start();
19+
return { ct };
20+
}
21+
22+
async function callToolWithTask(ct: InMemoryTransport): Promise<{ result?: unknown; error?: unknown }> {
23+
return await new Promise(resolve => {
24+
ct.onmessage = m => {
25+
const msg = m as { result?: unknown; error?: unknown };
26+
if ('result' in msg || 'error' in msg) resolve(msg);
27+
};
28+
ct.send({
29+
jsonrpc: '2.0',
30+
id: 1,
31+
method: 'tools/call',
32+
params: { name: 'x', arguments: {}, task: { ttl: 1000 } }
33+
});
34+
});
35+
}
36+
37+
it('schema form gets the same task-result validation as string form', async () => {
38+
const invalidTaskResult = { content: [{ type: 'text' as const, text: 'not a task result' }] };
39+
40+
const viaString = await setup(s => s.setRequestHandler('tools/call', () => invalidTaskResult));
41+
const viaSchema = await setup(s => s.setRequestHandler(CallToolRequestSchema, () => invalidTaskResult));
42+
43+
const stringRes = await callToolWithTask(viaString.ct);
44+
const schemaRes = await callToolWithTask(viaSchema.ct);
45+
46+
expect(stringRes.error).toBeDefined();
47+
expect(schemaRes.error).toEqual(stringRes.error);
48+
});
49+
});

0 commit comments

Comments
 (0)