Skip to content

Commit 55b1f06

Browse files
refactor(core): _wrapHandler hook so subclasses don't redeclare setRequestHandler (#1976)
Co-authored-by: Konstantin Konstantinov <KKonstantinov@users.noreply.github.com>
1 parent 7cccc2a commit 55b1f06

5 files changed

Lines changed: 134 additions & 90 deletions

File tree

.changeset/wraphandler-hook.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
'@modelcontextprotocol/core': patch
3+
'@modelcontextprotocol/client': patch
4+
'@modelcontextprotocol/server': patch
5+
---
6+
7+
refactor: subclasses override `_wrapHandler` hook instead of redeclaring `setRequestHandler`.

packages/client/src/client/client.ts

Lines changed: 35 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@ import type {
66
ClientContext,
77
ClientNotification,
88
ClientRequest,
9-
ClientResult,
109
CompleteRequest,
1110
GetPromptRequest,
1211
Implementation,
12+
JSONRPCRequest,
1313
JsonSchemaType,
1414
JsonSchemaValidator,
1515
jsonSchemaValidator,
@@ -26,8 +26,7 @@ import type {
2626
ReadResourceRequest,
2727
RequestMethod,
2828
RequestOptions,
29-
RequestTypeMap,
30-
ResultTypeMap,
29+
Result,
3130
ServerCapabilities,
3231
SubscribeRequest,
3332
TaskManagerOptions,
@@ -200,6 +199,28 @@ export type ClientOptions = ProtocolOptions & {
200199
*
201200
* The client will automatically begin the initialization flow with the server when {@linkcode connect} is called.
202201
*
202+
* To handle server-initiated requests (sampling, elicitation, roots), call {@linkcode setRequestHandler}.
203+
* The client must declare the corresponding capability for the handler to be accepted. For
204+
* `sampling/createMessage` and `elicitation/create`, the handler is automatically wrapped with
205+
* schema validation for both the incoming request and the returned result.
206+
*
207+
* @example Handling a sampling request
208+
* ```ts source="./client.examples.ts#Client_setRequestHandler_sampling"
209+
* client.setRequestHandler('sampling/createMessage', async request => {
210+
* const lastMessage = request.params.messages.at(-1);
211+
* console.log('Sampling request:', lastMessage);
212+
*
213+
* // In production, send messages to your LLM here
214+
* return {
215+
* model: 'my-model',
216+
* role: 'assistant' as const,
217+
* content: {
218+
* type: 'text' as const,
219+
* text: 'Response from the model'
220+
* }
221+
* };
222+
* });
223+
* ```
203224
*/
204225
export class Client extends Protocol<ClientContext> {
205226
private _serverCapabilities?: ServerCapabilities;
@@ -308,37 +329,15 @@ export class Client extends Protocol<ClientContext> {
308329
}
309330

310331
/**
311-
* Registers a handler for server-initiated requests (sampling, elicitation, roots).
312-
* The client must declare the corresponding capability for the handler to be accepted.
313-
* Replaces any previously registered handler for the same method.
314-
*
315-
* For `sampling/createMessage` and `elicitation/create`, the handler is automatically
316-
* wrapped with schema validation for both the incoming request and the returned result.
317-
*
318-
* @example Handling a sampling request
319-
* ```ts source="./client.examples.ts#Client_setRequestHandler_sampling"
320-
* client.setRequestHandler('sampling/createMessage', async request => {
321-
* const lastMessage = request.params.messages.at(-1);
322-
* console.log('Sampling request:', lastMessage);
323-
*
324-
* // In production, send messages to your LLM here
325-
* return {
326-
* model: 'my-model',
327-
* role: 'assistant' as const,
328-
* content: {
329-
* type: 'text' as const,
330-
* text: 'Response from the model'
331-
* }
332-
* };
333-
* });
334-
* ```
332+
* Enforces client-side validation for `elicitation/create` and `sampling/createMessage`
333+
* regardless of how the handler was registered.
335334
*/
336-
public override setRequestHandler<M extends RequestMethod>(
337-
method: M,
338-
handler: (request: RequestTypeMap[M], ctx: ClientContext) => ResultTypeMap[M] | Promise<ResultTypeMap[M]>
339-
): void {
335+
protected override _wrapHandler(
336+
method: string,
337+
handler: (request: JSONRPCRequest, ctx: ClientContext) => Promise<Result>
338+
): (request: JSONRPCRequest, ctx: ClientContext) => Promise<Result> {
340339
if (method === 'elicitation/create') {
341-
const wrappedHandler = async (request: RequestTypeMap[M], ctx: ClientContext): Promise<ClientResult> => {
340+
return async (request, ctx) => {
342341
const validatedRequest = parseSchema(ElicitRequestSchema, request);
343342
if (!validatedRequest.success) {
344343
// Type guard: if success is false, error is guaranteed to exist
@@ -359,7 +358,7 @@ export class Client extends Protocol<ClientContext> {
359358
throw new ProtocolError(ProtocolErrorCode.InvalidParams, 'Client does not support URL-mode elicitation requests');
360359
}
361360

362-
const result = await Promise.resolve(handler(request, ctx));
361+
const result = await handler(request, ctx);
363362

364363
// When task creation is requested, validate and return CreateTaskResult
365364
if (params.task) {
@@ -402,13 +401,10 @@ export class Client extends Protocol<ClientContext> {
402401

403402
return validatedResult;
404403
};
405-
406-
// Install the wrapped handler
407-
return super.setRequestHandler(method, wrappedHandler);
408404
}
409405

410406
if (method === 'sampling/createMessage') {
411-
const wrappedHandler = async (request: RequestTypeMap[M], ctx: ClientContext): Promise<ClientResult> => {
407+
return async (request, ctx) => {
412408
const validatedRequest = parseSchema(CreateMessageRequestSchema, request);
413409
if (!validatedRequest.success) {
414410
const errorMessage =
@@ -418,7 +414,7 @@ export class Client extends Protocol<ClientContext> {
418414

419415
const { params } = validatedRequest.data;
420416

421-
const result = await Promise.resolve(handler(request, ctx));
417+
const result = await handler(request, ctx);
422418

423419
// When task creation is requested, validate and return CreateTaskResult
424420
if (params.task) {
@@ -445,13 +441,9 @@ export class Client extends Protocol<ClientContext> {
445441

446442
return validationResult.data;
447443
};
448-
449-
// Install the wrapped handler
450-
return super.setRequestHandler(method, wrappedHandler);
451444
}
452445

453-
// Other handlers use default behavior
454-
return super.setRequestHandler(method, handler);
446+
return handler;
455447
}
456448

457449
protected assertCapability(capability: keyof ServerCapabilities, method: string): void {

packages/core/src/shared/protocol.ts

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,15 +1007,30 @@ export abstract class Protocol<ContextT extends BaseContext> {
10071007
*/
10081008
setRequestHandler<M extends RequestMethod>(
10091009
method: M,
1010-
handler: (request: RequestTypeMap[M], ctx: ContextT) => Result | Promise<Result>
1010+
handler: (request: RequestTypeMap[M], ctx: ContextT) => ResultTypeMap[M] | Promise<ResultTypeMap[M]>
10111011
): void {
10121012
this.assertRequestHandlerCapability(method);
10131013
const schema = getRequestSchema(method);
10141014

1015-
this._requestHandlers.set(method, (request, ctx) => {
1015+
const stored = (request: JSONRPCRequest, ctx: ContextT): Promise<Result> => {
10161016
const parsed = schema.parse(request) as RequestTypeMap[M];
10171017
return Promise.resolve(handler(parsed, ctx));
1018-
});
1018+
};
1019+
this._requestHandlers.set(method, this._wrapHandler(method, stored));
1020+
}
1021+
1022+
/**
1023+
* Hook for subclasses to wrap a registered request handler with role-specific
1024+
* validation or behavior (e.g. `Server` validates `tools/call` results, `Client`
1025+
* validates `elicitation/create` mode and result). The default implementation is identity.
1026+
*
1027+
* Subclasses overriding this hook avoid redeclaring `setRequestHandler` and its JSDoc.
1028+
*/
1029+
protected _wrapHandler(
1030+
_method: string,
1031+
handler: (request: JSONRPCRequest, ctx: ContextT) => Promise<Result>
1032+
): (request: JSONRPCRequest, ctx: ContextT) => Promise<Result> {
1033+
return handler;
10191034
}
10201035

10211036
/**
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import { describe, expect, it } from 'vitest';
2+
3+
import { Protocol } from '../../src/shared/protocol.js';
4+
import type { BaseContext, JSONRPCRequest, Result } from '../../src/exports/public/index.js';
5+
6+
class TestProtocol extends Protocol<BaseContext> {
7+
protected buildContext(ctx: BaseContext): BaseContext {
8+
return ctx;
9+
}
10+
protected assertCapabilityForMethod(): void {}
11+
protected assertNotificationCapability(): void {}
12+
protected assertRequestHandlerCapability(): void {}
13+
protected assertTaskCapability(): void {}
14+
protected assertTaskHandlerCapability(): void {}
15+
}
16+
17+
describe('Protocol._wrapHandler', () => {
18+
it('routes setRequestHandler registration through _wrapHandler', () => {
19+
const seen: string[] = [];
20+
class SpyProtocol extends TestProtocol {
21+
protected override _wrapHandler(
22+
method: string,
23+
handler: (request: JSONRPCRequest, ctx: BaseContext) => Promise<Result>
24+
): (request: JSONRPCRequest, ctx: BaseContext) => Promise<Result> {
25+
seen.push(method);
26+
return handler;
27+
}
28+
}
29+
const p = new SpyProtocol();
30+
seen.length = 0;
31+
p.setRequestHandler('tools/list', () => ({ tools: [] }));
32+
p.setRequestHandler('resources/list', () => ({ resources: [] }));
33+
expect(seen).toEqual(['tools/list', 'resources/list']);
34+
});
35+
});

packages/server/src/server/server.ts

Lines changed: 39 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import type {
1212
Implementation,
1313
InitializeRequest,
1414
InitializeResult,
15+
JSONRPCRequest,
1516
JsonSchemaType,
1617
jsonSchemaValidator,
1718
ListRootsRequest,
@@ -23,12 +24,10 @@ import type {
2324
ProtocolOptions,
2425
RequestMethod,
2526
RequestOptions,
26-
RequestTypeMap,
2727
ResourceUpdatedNotification,
28-
ResultTypeMap,
28+
Result,
2929
ServerCapabilities,
3030
ServerContext,
31-
ServerResult,
3231
TaskManagerOptions,
3332
ToolResultContent,
3433
ToolUseContent
@@ -220,55 +219,51 @@ export class Server extends Protocol<ServerContext> {
220219
}
221220

222221
/**
223-
* Override request handler registration to enforce server-side validation for `tools/call`.
222+
* Enforces server-side validation for `tools/call` results regardless of how the
223+
* handler was registered.
224224
*/
225-
public override setRequestHandler<M extends RequestMethod>(
226-
method: M,
227-
handler: (request: RequestTypeMap[M], ctx: ServerContext) => ResultTypeMap[M] | Promise<ResultTypeMap[M]>
228-
): void {
229-
if (method === 'tools/call') {
230-
const wrappedHandler = async (request: RequestTypeMap[M], ctx: ServerContext): Promise<ServerResult> => {
231-
const validatedRequest = parseSchema(CallToolRequestSchema, request);
232-
if (!validatedRequest.success) {
233-
const errorMessage =
234-
validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error);
235-
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call request: ${errorMessage}`);
236-
}
237-
238-
const { params } = validatedRequest.data;
225+
protected override _wrapHandler(
226+
method: string,
227+
handler: (request: JSONRPCRequest, ctx: ServerContext) => Promise<Result>
228+
): (request: JSONRPCRequest, ctx: ServerContext) => Promise<Result> {
229+
if (method !== 'tools/call') {
230+
return handler;
231+
}
232+
return async (request, ctx) => {
233+
const validatedRequest = parseSchema(CallToolRequestSchema, request);
234+
if (!validatedRequest.success) {
235+
const errorMessage =
236+
validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error);
237+
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call request: ${errorMessage}`);
238+
}
239239

240-
const result = await Promise.resolve(handler(request, ctx));
240+
const { params } = validatedRequest.data;
241241

242-
// When task creation is requested, validate and return CreateTaskResult
243-
if (params.task) {
244-
const taskValidationResult = parseSchema(CreateTaskResultSchema, result);
245-
if (!taskValidationResult.success) {
246-
const errorMessage =
247-
taskValidationResult.error instanceof Error
248-
? taskValidationResult.error.message
249-
: String(taskValidationResult.error);
250-
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`);
251-
}
252-
return taskValidationResult.data;
253-
}
242+
const result = await handler(request, ctx);
254243

255-
// For non-task requests, validate against CallToolResultSchema
256-
const validationResult = parseSchema(CallToolResultSchema, result);
257-
if (!validationResult.success) {
244+
// When task creation is requested, validate and return CreateTaskResult
245+
if (params.task) {
246+
const taskValidationResult = parseSchema(CreateTaskResultSchema, result);
247+
if (!taskValidationResult.success) {
258248
const errorMessage =
259-
validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error);
260-
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call result: ${errorMessage}`);
249+
taskValidationResult.error instanceof Error
250+
? taskValidationResult.error.message
251+
: String(taskValidationResult.error);
252+
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`);
261253
}
254+
return taskValidationResult.data;
255+
}
262256

263-
return validationResult.data;
264-
};
265-
266-
// Install the wrapped handler
267-
return super.setRequestHandler(method, wrappedHandler);
268-
}
257+
// For non-task requests, validate against CallToolResultSchema
258+
const validationResult = parseSchema(CallToolResultSchema, result);
259+
if (!validationResult.success) {
260+
const errorMessage =
261+
validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error);
262+
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call result: ${errorMessage}`);
263+
}
269264

270-
// Other handlers use default behavior
271-
return super.setRequestHandler(method, handler);
265+
return validationResult.data;
266+
};
272267
}
273268

274269
protected assertCapabilityForMethod(method: RequestMethod): void {

0 commit comments

Comments
 (0)