Skip to content

Commit 8f12176

Browse files
refactor(core): subclasses override _wrapHandler hook instead of redeclaring setRequestHandler
1 parent 2a7611d commit 8f12176

5 files changed

Lines changed: 137 additions & 96 deletions

File tree

.changeset/wraphandler-hook.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
---
2+
'@modelcontextprotocol/core': patch
3+
'@modelcontextprotocol/client': patch
4+
'@modelcontextprotocol/server': patch
5+
---
6+
7+
refactor: subclasses override `_wrapHandler` hook instead of redeclaring `setRequestHandler`.

packages/client/src/client/client.ts

Lines changed: 37 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@ import type {
66
ClientContext,
77
ClientNotification,
88
ClientRequest,
9-
ClientResult,
109
CompleteRequest,
1110
GetPromptRequest,
1211
Implementation,
12+
JSONRPCRequest,
1313
JsonSchemaType,
1414
JsonSchemaValidator,
1515
jsonSchemaValidator,
@@ -24,10 +24,8 @@ import type {
2424
NotificationMethod,
2525
ProtocolOptions,
2626
ReadResourceRequest,
27-
RequestMethod,
2827
RequestOptions,
29-
RequestTypeMap,
30-
ResultTypeMap,
28+
Result,
3129
ServerCapabilities,
3230
SubscribeRequest,
3331
TaskManagerOptions,
@@ -200,6 +198,28 @@ export type ClientOptions = ProtocolOptions & {
200198
*
201199
* The client will automatically begin the initialization flow with the server when {@linkcode connect} is called.
202200
*
201+
* To handle server-initiated requests (sampling, elicitation, roots), call {@linkcode setRequestHandler}.
202+
* The client must declare the corresponding capability for the handler to be accepted. For
203+
* `sampling/createMessage` and `elicitation/create`, the handler is automatically wrapped with
204+
* schema validation for both the incoming request and the returned result.
205+
*
206+
* @example Handling a sampling request
207+
* ```ts source="./client.examples.ts#Client_setRequestHandler_sampling"
208+
* client.setRequestHandler('sampling/createMessage', async request => {
209+
* const lastMessage = request.params.messages.at(-1);
210+
* console.log('Sampling request:', lastMessage);
211+
*
212+
* // In production, send messages to your LLM here
213+
* return {
214+
* model: 'my-model',
215+
* role: 'assistant' as const,
216+
* content: {
217+
* type: 'text' as const,
218+
* text: 'Response from the model'
219+
* }
220+
* };
221+
* });
222+
* ```
203223
*/
204224
export class Client extends Protocol<ClientContext> {
205225
private _serverCapabilities?: ServerCapabilities;
@@ -308,37 +328,15 @@ export class Client extends Protocol<ClientContext> {
308328
}
309329

310330
/**
311-
* Registers a handler for server-initiated requests (sampling, elicitation, roots).
312-
* The client must declare the corresponding capability for the handler to be accepted.
313-
* Replaces any previously registered handler for the same method.
314-
*
315-
* For `sampling/createMessage` and `elicitation/create`, the handler is automatically
316-
* wrapped with schema validation for both the incoming request and the returned result.
317-
*
318-
* @example Handling a sampling request
319-
* ```ts source="./client.examples.ts#Client_setRequestHandler_sampling"
320-
* client.setRequestHandler('sampling/createMessage', async request => {
321-
* const lastMessage = request.params.messages.at(-1);
322-
* console.log('Sampling request:', lastMessage);
323-
*
324-
* // In production, send messages to your LLM here
325-
* return {
326-
* model: 'my-model',
327-
* role: 'assistant' as const,
328-
* content: {
329-
* type: 'text' as const,
330-
* text: 'Response from the model'
331-
* }
332-
* };
333-
* });
334-
* ```
331+
* Enforces client-side validation for `elicitation/create` and `sampling/createMessage`
332+
* regardless of how the handler was registered.
335333
*/
336-
public override setRequestHandler<M extends RequestMethod>(
337-
method: M,
338-
handler: (request: RequestTypeMap[M], ctx: ClientContext) => ResultTypeMap[M] | Promise<ResultTypeMap[M]>
339-
): void {
334+
protected override _wrapHandler(
335+
method: string,
336+
handler: (request: JSONRPCRequest, ctx: ClientContext) => Promise<Result>
337+
): (request: JSONRPCRequest, ctx: ClientContext) => Promise<Result> {
340338
if (method === 'elicitation/create') {
341-
const wrappedHandler = async (request: RequestTypeMap[M], ctx: ClientContext): Promise<ClientResult> => {
339+
return async (request, ctx) => {
342340
const validatedRequest = parseSchema(ElicitRequestSchema, request);
343341
if (!validatedRequest.success) {
344342
// Type guard: if success is false, error is guaranteed to exist
@@ -359,7 +357,7 @@ export class Client extends Protocol<ClientContext> {
359357
throw new ProtocolError(ProtocolErrorCode.InvalidParams, 'Client does not support URL-mode elicitation requests');
360358
}
361359

362-
const result = await Promise.resolve(handler(request, ctx));
360+
const result = await handler(request, ctx);
363361

364362
// When task creation is requested, validate and return CreateTaskResult
365363
if (params.task) {
@@ -402,13 +400,10 @@ export class Client extends Protocol<ClientContext> {
402400

403401
return validatedResult;
404402
};
405-
406-
// Install the wrapped handler
407-
return super.setRequestHandler(method, wrappedHandler);
408403
}
409404

410405
if (method === 'sampling/createMessage') {
411-
const wrappedHandler = async (request: RequestTypeMap[M], ctx: ClientContext): Promise<ClientResult> => {
406+
return async (request, ctx) => {
412407
const validatedRequest = parseSchema(CreateMessageRequestSchema, request);
413408
if (!validatedRequest.success) {
414409
const errorMessage =
@@ -418,7 +413,7 @@ export class Client extends Protocol<ClientContext> {
418413

419414
const { params } = validatedRequest.data;
420415

421-
const result = await Promise.resolve(handler(request, ctx));
416+
const result = await handler(request, ctx);
422417

423418
// When task creation is requested, validate and return CreateTaskResult
424419
if (params.task) {
@@ -445,13 +440,9 @@ export class Client extends Protocol<ClientContext> {
445440

446441
return validationResult.data;
447442
};
448-
449-
// Install the wrapped handler
450-
return super.setRequestHandler(method, wrappedHandler);
451443
}
452444

453-
// Other handlers use default behavior
454-
return super.setRequestHandler(method, handler);
445+
return handler;
455446
}
456447

457448
protected assertCapability(capability: keyof ServerCapabilities, method: string): void {
@@ -578,7 +569,7 @@ export class Client extends Protocol<ClientContext> {
578569
return this._instructions;
579570
}
580571

581-
protected assertCapabilityForMethod(method: RequestMethod): void {
572+
protected assertCapabilityForMethod(method: string): void {
582573
switch (method as ClientRequest['method']) {
583574
case 'logging/setLevel': {
584575
if (!this._serverCapabilities?.logging) {
@@ -641,7 +632,7 @@ export class Client extends Protocol<ClientContext> {
641632
}
642633
}
643634

644-
protected assertNotificationCapability(method: NotificationMethod): void {
635+
protected assertNotificationCapability(method: string): void {
645636
switch (method as ClientNotification['method']) {
646637
case 'notifications/roots/list_changed': {
647638
if (!this._capabilities.roots?.listChanged) {

packages/core/src/shared/protocol.ts

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,10 +1012,25 @@ export abstract class Protocol<ContextT extends BaseContext> {
10121012
this.assertRequestHandlerCapability(method);
10131013
const schema = getRequestSchema(method);
10141014

1015-
this._requestHandlers.set(method, (request, ctx) => {
1015+
const stored = (request: JSONRPCRequest, ctx: ContextT): Promise<Result> => {
10161016
const parsed = schema.parse(request) as RequestTypeMap[M];
10171017
return Promise.resolve(handler(parsed, ctx));
1018-
});
1018+
};
1019+
this._requestHandlers.set(method, this._wrapHandler(method, stored));
1020+
}
1021+
1022+
/**
1023+
* Hook for subclasses to wrap a registered request handler with role-specific
1024+
* validation or behavior (e.g. `Server` validates `tools/call` results, `Client`
1025+
* validates `elicitation/create` mode and result). The default implementation is identity.
1026+
*
1027+
* Subclasses overriding this hook avoid redeclaring `setRequestHandler` and its JSDoc.
1028+
*/
1029+
protected _wrapHandler(
1030+
_method: string,
1031+
handler: (request: JSONRPCRequest, ctx: ContextT) => Promise<Result>
1032+
): (request: JSONRPCRequest, ctx: ContextT) => Promise<Result> {
1033+
return handler;
10191034
}
10201035

10211036
/**
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import { describe, expect, it } from 'vitest';
2+
3+
import { Protocol } from '../../src/shared/protocol.js';
4+
import type { BaseContext, JSONRPCRequest, Result } from '../../src/exports/public/index.js';
5+
6+
class TestProtocol extends Protocol<BaseContext> {
7+
protected buildContext(ctx: BaseContext): BaseContext {
8+
return ctx;
9+
}
10+
protected assertCapabilityForMethod(): void {}
11+
protected assertNotificationCapability(): void {}
12+
protected assertRequestHandlerCapability(): void {}
13+
protected assertTaskCapability(): void {}
14+
protected assertTaskHandlerCapability(): void {}
15+
}
16+
17+
describe('Protocol._wrapHandler', () => {
18+
it('routes setRequestHandler registration through _wrapHandler', () => {
19+
const seen: string[] = [];
20+
class SpyProtocol extends TestProtocol {
21+
protected override _wrapHandler(
22+
method: string,
23+
handler: (request: JSONRPCRequest, ctx: BaseContext) => Promise<Result>
24+
): (request: JSONRPCRequest, ctx: BaseContext) => Promise<Result> {
25+
seen.push(method);
26+
return handler;
27+
}
28+
}
29+
const p = new SpyProtocol();
30+
seen.length = 0;
31+
p.setRequestHandler('tools/list', () => ({ tools: [] }));
32+
p.setRequestHandler('resources/list', () => ({ resources: [] }));
33+
expect(seen).toEqual(['tools/list', 'resources/list']);
34+
});
35+
});

packages/server/src/server/server.ts

Lines changed: 41 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,20 @@ import type {
1212
Implementation,
1313
InitializeRequest,
1414
InitializeResult,
15+
JSONRPCRequest,
1516
JsonSchemaType,
1617
jsonSchemaValidator,
1718
ListRootsRequest,
1819
LoggingLevel,
1920
LoggingMessageNotification,
2021
MessageExtraInfo,
21-
NotificationMethod,
2222
NotificationOptions,
2323
ProtocolOptions,
24-
RequestMethod,
2524
RequestOptions,
26-
RequestTypeMap,
2725
ResourceUpdatedNotification,
28-
ResultTypeMap,
26+
Result,
2927
ServerCapabilities,
3028
ServerContext,
31-
ServerResult,
3229
TaskManagerOptions,
3330
ToolResultContent,
3431
ToolUseContent
@@ -220,58 +217,54 @@ export class Server extends Protocol<ServerContext> {
220217
}
221218

222219
/**
223-
* Override request handler registration to enforce server-side validation for `tools/call`.
220+
* Enforces server-side validation for `tools/call` results regardless of how the
221+
* handler was registered.
224222
*/
225-
public override setRequestHandler<M extends RequestMethod>(
226-
method: M,
227-
handler: (request: RequestTypeMap[M], ctx: ServerContext) => ResultTypeMap[M] | Promise<ResultTypeMap[M]>
228-
): void {
229-
if (method === 'tools/call') {
230-
const wrappedHandler = async (request: RequestTypeMap[M], ctx: ServerContext): Promise<ServerResult> => {
231-
const validatedRequest = parseSchema(CallToolRequestSchema, request);
232-
if (!validatedRequest.success) {
233-
const errorMessage =
234-
validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error);
235-
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call request: ${errorMessage}`);
236-
}
237-
238-
const { params } = validatedRequest.data;
223+
protected override _wrapHandler(
224+
method: string,
225+
handler: (request: JSONRPCRequest, ctx: ServerContext) => Promise<Result>
226+
): (request: JSONRPCRequest, ctx: ServerContext) => Promise<Result> {
227+
if (method !== 'tools/call') {
228+
return handler;
229+
}
230+
return async (request, ctx) => {
231+
const validatedRequest = parseSchema(CallToolRequestSchema, request);
232+
if (!validatedRequest.success) {
233+
const errorMessage =
234+
validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error);
235+
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call request: ${errorMessage}`);
236+
}
239237

240-
const result = await Promise.resolve(handler(request, ctx));
238+
const { params } = validatedRequest.data;
241239

242-
// When task creation is requested, validate and return CreateTaskResult
243-
if (params.task) {
244-
const taskValidationResult = parseSchema(CreateTaskResultSchema, result);
245-
if (!taskValidationResult.success) {
246-
const errorMessage =
247-
taskValidationResult.error instanceof Error
248-
? taskValidationResult.error.message
249-
: String(taskValidationResult.error);
250-
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`);
251-
}
252-
return taskValidationResult.data;
253-
}
240+
const result = await handler(request, ctx);
254241

255-
// For non-task requests, validate against CallToolResultSchema
256-
const validationResult = parseSchema(CallToolResultSchema, result);
257-
if (!validationResult.success) {
242+
// When task creation is requested, validate and return CreateTaskResult
243+
if (params.task) {
244+
const taskValidationResult = parseSchema(CreateTaskResultSchema, result);
245+
if (!taskValidationResult.success) {
258246
const errorMessage =
259-
validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error);
260-
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call result: ${errorMessage}`);
247+
taskValidationResult.error instanceof Error
248+
? taskValidationResult.error.message
249+
: String(taskValidationResult.error);
250+
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`);
261251
}
252+
return taskValidationResult.data;
253+
}
262254

263-
return validationResult.data;
264-
};
265-
266-
// Install the wrapped handler
267-
return super.setRequestHandler(method, wrappedHandler);
268-
}
255+
// For non-task requests, validate against CallToolResultSchema
256+
const validationResult = parseSchema(CallToolResultSchema, result);
257+
if (!validationResult.success) {
258+
const errorMessage =
259+
validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error);
260+
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call result: ${errorMessage}`);
261+
}
269262

270-
// Other handlers use default behavior
271-
return super.setRequestHandler(method, handler);
263+
return validationResult.data;
264+
};
272265
}
273266

274-
protected assertCapabilityForMethod(method: RequestMethod): void {
267+
protected assertCapabilityForMethod(method: string): void {
275268
switch (method) {
276269
case 'sampling/createMessage': {
277270
if (!this._clientCapabilities?.sampling) {
@@ -304,7 +297,7 @@ export class Server extends Protocol<ServerContext> {
304297
}
305298
}
306299

307-
protected assertNotificationCapability(method: NotificationMethod): void {
300+
protected assertNotificationCapability(method: string): void {
308301
switch (method) {
309302
case 'notifications/message': {
310303
if (!this._capabilities.logging) {

0 commit comments

Comments
 (0)