Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .changeset/wraphandler-hook.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
'@modelcontextprotocol/core': patch
'@modelcontextprotocol/client': patch
'@modelcontextprotocol/server': patch
---

refactor: subclasses override `_wrapHandler` hook instead of redeclaring `setRequestHandler`.
78 changes: 35 additions & 43 deletions packages/client/src/client/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ import type {
ClientContext,
ClientNotification,
ClientRequest,
ClientResult,
CompleteRequest,
GetPromptRequest,
Implementation,
JSONRPCRequest,
JsonSchemaType,
JsonSchemaValidator,
jsonSchemaValidator,
Expand All @@ -26,8 +26,7 @@ import type {
ReadResourceRequest,
RequestMethod,
RequestOptions,
RequestTypeMap,
ResultTypeMap,
Result,
ServerCapabilities,
SubscribeRequest,
TaskManagerOptions,
Expand Down Expand Up @@ -200,6 +199,28 @@ export type ClientOptions = ProtocolOptions & {
*
* The client will automatically begin the initialization flow with the server when {@linkcode connect} is called.
*
* To handle server-initiated requests (sampling, elicitation, roots), call {@linkcode setRequestHandler}.
* The client must declare the corresponding capability for the handler to be accepted. For
* `sampling/createMessage` and `elicitation/create`, the handler is automatically wrapped with
* schema validation for both the incoming request and the returned result.
*
* @example Handling a sampling request
* ```ts source="./client.examples.ts#Client_setRequestHandler_sampling"
* client.setRequestHandler('sampling/createMessage', async request => {
* const lastMessage = request.params.messages.at(-1);
* console.log('Sampling request:', lastMessage);
*
* // In production, send messages to your LLM here
* return {
* model: 'my-model',
* role: 'assistant' as const,
* content: {
* type: 'text' as const,
* text: 'Response from the model'
* }
* };
* });
* ```
*/
export class Client extends Protocol<ClientContext> {
private _serverCapabilities?: ServerCapabilities;
Expand Down Expand Up @@ -308,37 +329,15 @@ export class Client extends Protocol<ClientContext> {
}

/**
* Registers a handler for server-initiated requests (sampling, elicitation, roots).
* The client must declare the corresponding capability for the handler to be accepted.
* Replaces any previously registered handler for the same method.
*
* For `sampling/createMessage` and `elicitation/create`, the handler is automatically
* wrapped with schema validation for both the incoming request and the returned result.
*
* @example Handling a sampling request
* ```ts source="./client.examples.ts#Client_setRequestHandler_sampling"
* client.setRequestHandler('sampling/createMessage', async request => {
* const lastMessage = request.params.messages.at(-1);
* console.log('Sampling request:', lastMessage);
*
* // In production, send messages to your LLM here
* return {
* model: 'my-model',
* role: 'assistant' as const,
* content: {
* type: 'text' as const,
* text: 'Response from the model'
* }
* };
* });
* ```
* Enforces client-side validation for `elicitation/create` and `sampling/createMessage`
* regardless of how the handler was registered.
*/
public override setRequestHandler<M extends RequestMethod>(
method: M,
handler: (request: RequestTypeMap[M], ctx: ClientContext) => ResultTypeMap[M] | Promise<ResultTypeMap[M]>
): void {
protected override _wrapHandler(
method: string,
handler: (request: JSONRPCRequest, ctx: ClientContext) => Promise<Result>
): (request: JSONRPCRequest, ctx: ClientContext) => Promise<Result> {
if (method === 'elicitation/create') {
const wrappedHandler = async (request: RequestTypeMap[M], ctx: ClientContext): Promise<ClientResult> => {
return async (request, ctx) => {
const validatedRequest = parseSchema(ElicitRequestSchema, request);
if (!validatedRequest.success) {
// Type guard: if success is false, error is guaranteed to exist
Expand All @@ -359,7 +358,7 @@ export class Client extends Protocol<ClientContext> {
throw new ProtocolError(ProtocolErrorCode.InvalidParams, 'Client does not support URL-mode elicitation requests');
}

const result = await Promise.resolve(handler(request, ctx));
const result = await handler(request, ctx);

// When task creation is requested, validate and return CreateTaskResult
if (params.task) {
Expand Down Expand Up @@ -402,13 +401,10 @@ export class Client extends Protocol<ClientContext> {

return validatedResult;
};

// Install the wrapped handler
return super.setRequestHandler(method, wrappedHandler);
}

if (method === 'sampling/createMessage') {
const wrappedHandler = async (request: RequestTypeMap[M], ctx: ClientContext): Promise<ClientResult> => {
return async (request, ctx) => {
const validatedRequest = parseSchema(CreateMessageRequestSchema, request);
if (!validatedRequest.success) {
const errorMessage =
Expand All @@ -418,7 +414,7 @@ export class Client extends Protocol<ClientContext> {

const { params } = validatedRequest.data;

const result = await Promise.resolve(handler(request, ctx));
const result = await handler(request, ctx);

// When task creation is requested, validate and return CreateTaskResult
if (params.task) {
Expand All @@ -445,13 +441,9 @@ export class Client extends Protocol<ClientContext> {

return validationResult.data;
};

// Install the wrapped handler
return super.setRequestHandler(method, wrappedHandler);
}

// Other handlers use default behavior
return super.setRequestHandler(method, handler);
return handler;
}

protected assertCapability(capability: keyof ServerCapabilities, method: string): void {
Expand Down
21 changes: 18 additions & 3 deletions packages/core/src/shared/protocol.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1007,15 +1007,30 @@ export abstract class Protocol<ContextT extends BaseContext> {
*/
setRequestHandler<M extends RequestMethod>(
method: M,
handler: (request: RequestTypeMap[M], ctx: ContextT) => Result | Promise<Result>
handler: (request: RequestTypeMap[M], ctx: ContextT) => ResultTypeMap[M] | Promise<ResultTypeMap[M]>
): void {
this.assertRequestHandlerCapability(method);
const schema = getRequestSchema(method);

this._requestHandlers.set(method, (request, ctx) => {
const stored = (request: JSONRPCRequest, ctx: ContextT): Promise<Result> => {
const parsed = schema.parse(request) as RequestTypeMap[M];
return Promise.resolve(handler(parsed, ctx));
});
};
this._requestHandlers.set(method, this._wrapHandler(method, stored));
Comment thread
claude[bot] marked this conversation as resolved.
}

/**
* Hook for subclasses to wrap a registered request handler with role-specific
* validation or behavior (e.g. `Server` validates `tools/call` results, `Client`
* validates `elicitation/create` mode and result). The default implementation is identity.
*
* Subclasses overriding this hook avoid redeclaring `setRequestHandler` and its JSDoc.
*/
protected _wrapHandler(
_method: string,
handler: (request: JSONRPCRequest, ctx: ContextT) => Promise<Result>
): (request: JSONRPCRequest, ctx: ContextT) => Promise<Result> {
return handler;
}

/**
Expand Down
35 changes: 35 additions & 0 deletions packages/core/test/shared/wrapHandler.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import { describe, expect, it } from 'vitest';

import { Protocol } from '../../src/shared/protocol.js';
import type { BaseContext, JSONRPCRequest, Result } from '../../src/exports/public/index.js';

class TestProtocol extends Protocol<BaseContext> {
protected buildContext(ctx: BaseContext): BaseContext {
return ctx;
}
protected assertCapabilityForMethod(): void {}
protected assertNotificationCapability(): void {}
protected assertRequestHandlerCapability(): void {}
protected assertTaskCapability(): void {}
protected assertTaskHandlerCapability(): void {}
}

describe('Protocol._wrapHandler', () => {
it('routes setRequestHandler registration through _wrapHandler', () => {
const seen: string[] = [];
class SpyProtocol extends TestProtocol {
protected override _wrapHandler(
method: string,
handler: (request: JSONRPCRequest, ctx: BaseContext) => Promise<Result>
): (request: JSONRPCRequest, ctx: BaseContext) => Promise<Result> {
seen.push(method);
return handler;
}
}
const p = new SpyProtocol();
seen.length = 0;
p.setRequestHandler('tools/list', () => ({ tools: [] }));
p.setRequestHandler('resources/list', () => ({ resources: [] }));
expect(seen).toEqual(['tools/list', 'resources/list']);
});
});
83 changes: 39 additions & 44 deletions packages/server/src/server/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import type {
Implementation,
InitializeRequest,
InitializeResult,
JSONRPCRequest,
JsonSchemaType,
jsonSchemaValidator,
ListRootsRequest,
Expand All @@ -23,12 +24,10 @@ import type {
ProtocolOptions,
RequestMethod,
RequestOptions,
RequestTypeMap,
ResourceUpdatedNotification,
ResultTypeMap,
Result,
ServerCapabilities,
ServerContext,
ServerResult,
TaskManagerOptions,
ToolResultContent,
ToolUseContent
Expand Down Expand Up @@ -220,55 +219,51 @@ export class Server extends Protocol<ServerContext> {
}

/**
* Override request handler registration to enforce server-side validation for `tools/call`.
* Enforces server-side validation for `tools/call` results regardless of how the
* handler was registered.
*/
public override setRequestHandler<M extends RequestMethod>(
method: M,
handler: (request: RequestTypeMap[M], ctx: ServerContext) => ResultTypeMap[M] | Promise<ResultTypeMap[M]>
): void {
if (method === 'tools/call') {
const wrappedHandler = async (request: RequestTypeMap[M], ctx: ServerContext): Promise<ServerResult> => {
const validatedRequest = parseSchema(CallToolRequestSchema, request);
if (!validatedRequest.success) {
const errorMessage =
validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error);
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call request: ${errorMessage}`);
}

const { params } = validatedRequest.data;
protected override _wrapHandler(
method: string,
handler: (request: JSONRPCRequest, ctx: ServerContext) => Promise<Result>
): (request: JSONRPCRequest, ctx: ServerContext) => Promise<Result> {
if (method !== 'tools/call') {
return handler;
}
return async (request, ctx) => {
const validatedRequest = parseSchema(CallToolRequestSchema, request);
if (!validatedRequest.success) {
const errorMessage =
validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error);
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call request: ${errorMessage}`);
}

const result = await Promise.resolve(handler(request, ctx));
const { params } = validatedRequest.data;

// When task creation is requested, validate and return CreateTaskResult
if (params.task) {
const taskValidationResult = parseSchema(CreateTaskResultSchema, result);
if (!taskValidationResult.success) {
const errorMessage =
taskValidationResult.error instanceof Error
? taskValidationResult.error.message
: String(taskValidationResult.error);
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`);
}
return taskValidationResult.data;
}
const result = await handler(request, ctx);

// For non-task requests, validate against CallToolResultSchema
const validationResult = parseSchema(CallToolResultSchema, result);
if (!validationResult.success) {
// When task creation is requested, validate and return CreateTaskResult
if (params.task) {
const taskValidationResult = parseSchema(CreateTaskResultSchema, result);
if (!taskValidationResult.success) {
const errorMessage =
validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error);
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call result: ${errorMessage}`);
taskValidationResult.error instanceof Error
? taskValidationResult.error.message
: String(taskValidationResult.error);
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid task creation result: ${errorMessage}`);
}
return taskValidationResult.data;
}

return validationResult.data;
};

// Install the wrapped handler
return super.setRequestHandler(method, wrappedHandler);
}
// For non-task requests, validate against CallToolResultSchema
const validationResult = parseSchema(CallToolResultSchema, result);
if (!validationResult.success) {
const errorMessage =
validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error);
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid tools/call result: ${errorMessage}`);
}

// Other handlers use default behavior
return super.setRequestHandler(method, handler);
return validationResult.data;
};
}

protected assertCapabilityForMethod(method: RequestMethod): void {
Expand Down
Loading