diff --git a/.changeset/three-arg-custom-methods.md b/.changeset/three-arg-custom-methods.md new file mode 100644 index 000000000..e4f08b27d --- /dev/null +++ b/.changeset/three-arg-custom-methods.md @@ -0,0 +1,6 @@ +--- +'@modelcontextprotocol/client': minor +'@modelcontextprotocol/server': minor +--- + +`setRequestHandler`/`setNotificationHandler` gain a 3-arg `(method: string, paramsSchema, handler)` form for custom (non-spec) methods. `paramsSchema` is any Standard Schema (Zod, Valibot, ArkType, etc.); the handler receives validated `params`. diff --git a/docs/migration-SKILL.md b/docs/migration-SKILL.md index 2c753f4ab..c5f66c8b8 100644 --- a/docs/migration-SKILL.md +++ b/docs/migration-SKILL.md @@ -377,14 +377,16 @@ Schema to method string mapping: Request/notification params remain fully typed. Remove unused schema imports after migration. -**Custom (non-standard) methods** — vendor extensions or sub-protocols whose method strings are not in the MCP spec — work on `Client`/`Server` directly using the same v1 Zod-schema form: - -| Form | Notes | -| ------------------------------------------------------------ | --------------------------------------------------------------------- | -| `setRequestHandler(CustomReqSchema, (req, ctx) => ...)` | unchanged | -| `setNotificationHandler(CustomNotifSchema, n => ...)` | unchanged | -| `this.request({ method: 'vendor/x', params }, ResultSchema)` | unchanged | -| `this.notification({ method: 'vendor/x', params })` | unchanged | +**Custom (non-standard) methods** — vendor extensions or sub-protocols whose method strings are not in the MCP spec — work on `Client`/`Server` directly. The v1 Zod-schema forms continue to work; the three-arg `(method, paramsSchema, handler)` form is the alternative: + +| v1 (still supported) | v2 alternative | +| ------------------------------------------------------------ | ------------------------------------------------------------------------ | +| `setRequestHandler(CustomReqSchema, (req, ctx) => ...)` | `setRequestHandler('vendor/method', ParamsSchema, (params, ctx) => ...)` | +| `setNotificationHandler(CustomNotifSchema, n => ...)` | `setNotificationHandler('vendor/method', ParamsSchema, params => ...)` | +| `this.request({ method: 'vendor/x', params }, ResultSchema)` | unchanged | +| `this.notification({ method: 'vendor/x', params })` | unchanged | + +For the three-arg form, the v1 schema's `.shape.params` becomes the `ParamsSchema` argument and the `method: z.literal('...')` value becomes the string argument. ## 10. Request Handler Context Types diff --git a/docs/migration.md b/docs/migration.md index 7be893290..091377ad2 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -384,15 +384,19 @@ Common method string replacements: ### Custom (non-standard) protocol methods -Vendor-specific methods are registered directly on `Client` or `Server` using the same Zod-schema form as v1: `setRequestHandler(zodSchemaWithMethodLiteral, handler)`. `request({ method, params }, ResultSchema)` and `notification({ method, params })` are unchanged from v1. +Vendor-specific methods are registered directly on `Client` or `Server`. The v1 form `setRequestHandler(zodSchemaWithMethodLiteral, handler)` continues to work; the three-arg `(methodString, paramsSchema, handler)` form is the v2 alternative. `request({ method, params }, ResultSchema)` and `notification({ method, params })` are unchanged from v1. ```typescript import { Server } from '@modelcontextprotocol/server'; const server = new Server({ name: 'app', version: '1.0.0' }, { capabilities: {} }); +// v1 form (still supported): server.setRequestHandler(SearchRequestSchema, req => ({ hits: [req.params.query] })); +// v2 alternative — pass method string + params schema; handler receives validated params: +server.setRequestHandler('acme/search', SearchParams, params => ({ hits: [params.query] })); + // Calling from a Client — unchanged from v1: const result = await client.request({ method: 'acme/search', params: { query: 'x' } }, SearchResult); ``` diff --git a/examples/client/src/customMethodExample.ts b/examples/client/src/customMethodExample.ts index d0ce0e994..08bd8f498 100644 --- a/examples/client/src/customMethodExample.ts +++ b/examples/client/src/customMethodExample.ts @@ -3,8 +3,11 @@ * Calling vendor-specific (non-spec) JSON-RPC methods from a `Client`. * * - Send a custom request: `client.request({ method, params }, resultSchema)` - * - Send a custom notification: `client.notification({ method, params })` - * - Receive a custom notification: `client.setNotificationHandler(ZodSchemaWithMethodLiteral, handler)` + * - Send a custom notification: `client.notification({ method, params })` (unchanged from v1) + * - Receive a custom notification: 3-arg `client.setNotificationHandler(method, paramsSchema, handler)` + * + * These overloads are on `Client` and `Server` directly — you do NOT need a raw + * `Protocol` instance for custom methods. * * Pair with the server in examples/server/src/customMethodExample.ts. */ @@ -13,16 +16,12 @@ import { Client, StdioClientTransport } from '@modelcontextprotocol/client'; import { z } from 'zod'; const SearchResult = z.object({ hits: z.array(z.string()) }); - -const ProgressNotification = z.object({ - method: z.literal('acme/searchProgress'), - params: z.object({ stage: z.string(), pct: z.number() }) -}); +const ProgressParams = z.object({ stage: z.string(), pct: z.number() }); const client = new Client({ name: 'custom-method-client', version: '1.0.0' }, { capabilities: {} }); -client.setNotificationHandler(ProgressNotification, n => { - console.log(`[client] progress: ${n.params.stage} ${n.params.pct}%`); +client.setNotificationHandler('acme/searchProgress', ProgressParams, p => { + console.log(`[client] progress: ${p.stage} ${p.pct}%`); }); await client.connect(new StdioClientTransport({ command: 'npx', args: ['tsx', '../server/src/customMethodExample.ts'] })); diff --git a/examples/server/src/customMethodExample.ts b/examples/server/src/customMethodExample.ts index b8b2e222f..3fc6099a7 100644 --- a/examples/server/src/customMethodExample.ts +++ b/examples/server/src/customMethodExample.ts @@ -2,9 +2,10 @@ /** * Registering vendor-specific (non-spec) JSON-RPC methods on a `Server`. * - * Custom methods use the Zod-schema form of `setRequestHandler` / `setNotificationHandler`: - * pass a Zod object schema whose `method` field is `z.literal('')`. The same overload - * is available on `Client` (for server→client custom methods). + * Custom methods use the 3-arg form of `setRequestHandler` / `setNotificationHandler`: + * pass the method string, a params schema, and the handler. The same overload is + * available on `Client` (for server→client custom methods) — you do NOT need a raw + * `Protocol` instance for this. * * To call these from the client side, use: * await client.request({ method: 'acme/search', params: { query: 'widgets' } }, SearchResult) @@ -15,28 +16,21 @@ import { Server, StdioServerTransport } from '@modelcontextprotocol/server'; import { z } from 'zod'; -const SearchRequest = z.object({ - method: z.literal('acme/search'), - params: z.object({ query: z.string() }) -}); - -const TickNotification = z.object({ - method: z.literal('acme/tick'), - params: z.object({ n: z.number() }) -}); +const SearchParams = z.object({ query: z.string() }); +const TickParams = z.object({ n: z.number() }); const server = new Server({ name: 'custom-method-server', version: '1.0.0' }, { capabilities: {} }); -server.setRequestHandler(SearchRequest, async (request, ctx) => { - console.error('[server] acme/search query=' + request.params.query); +server.setRequestHandler('acme/search', SearchParams, async (params, ctx) => { + console.error('[server] acme/search query=' + params.query); await ctx.mcpReq.notify({ method: 'acme/searchProgress', params: { stage: 'start', pct: 0 } }); - const hits = [request.params.query, request.params.query + '-result']; + const hits = [params.query, params.query + '-result']; await ctx.mcpReq.notify({ method: 'acme/searchProgress', params: { stage: 'done', pct: 100 } }); return { hits }; }); -server.setNotificationHandler(TickNotification, n => { - console.error('[server] acme/tick n=' + n.params.n); +server.setNotificationHandler('acme/tick', TickParams, p => { + console.error('[server] acme/tick n=' + p.n); }); await server.connect(new StdioServerTransport()); diff --git a/packages/client/src/client/client.ts b/packages/client/src/client/client.ts index e1ae948d7..f22a5ee4f 100644 --- a/packages/client/src/client/client.ts +++ b/packages/client/src/client/client.ts @@ -32,6 +32,7 @@ import type { Result, ResultTypeMap, ServerCapabilities, + StandardSchemaV1, SubscribeRequest, TaskManagerOptions, Tool, @@ -343,22 +344,34 @@ export class Client extends Protocol { method: M, handler: (request: RequestTypeMap[M], ctx: ClientContext) => ResultTypeMap[M] | Promise ): void; - /** For spec methods the method-string form is more concise; this overload is the supported call form for non-spec methods or when you want full-envelope validation. */ + public override setRequestHandler

( + method: string, + paramsSchema: P, + handler: (params: StandardSchemaV1.InferOutput

, ctx: ClientContext) => Result | Promise + ): void; + /** For spec methods the method-string form is more concise; this overload is a supported call form for non-spec methods (alongside the three-arg `(method, paramsSchema, handler)` form) or when you want full-envelope validation. */ public override setRequestHandler( requestSchema: T, handler: (request: ReturnType, ctx: ClientContext) => Result | Promise ): void; - public override setRequestHandler(methodOrSchema: string | ZodLikeRequestSchema, schemaHandler: unknown): void { + public override setRequestHandler( + methodOrSchema: string | ZodLikeRequestSchema, + schemaOrHandler: unknown, + maybeHandler?: (params: unknown, ctx: ClientContext) => unknown + ): void { let method: string; let handler: (request: Request, ctx: ClientContext) => ClientResult | Promise; if (isZodLikeSchema(methodOrSchema)) { const schema = methodOrSchema; - const userHandler = schemaHandler as (request: unknown, ctx: ClientContext) => Result | Promise; + const userHandler = schemaOrHandler as (request: unknown, ctx: ClientContext) => Result | Promise; method = extractMethodLiteral(schema); handler = (req, ctx) => userHandler(schema.parse(req), ctx); + } else if (maybeHandler === undefined) { + method = methodOrSchema; + handler = schemaOrHandler as (request: Request, ctx: ClientContext) => ClientResult | Promise; } else { method = methodOrSchema; - handler = schemaHandler as (request: Request, ctx: ClientContext) => ClientResult | Promise; + handler = this._wrapParamsSchemaHandler(method, schemaOrHandler as StandardSchemaV1, maybeHandler); } if (method === 'elicitation/create') { const wrappedHandler = async (request: Request, ctx: ClientContext): Promise => { diff --git a/packages/core/src/exports/public/index.ts b/packages/core/src/exports/public/index.ts index 227cd3e24..fd2cada0c 100644 --- a/packages/core/src/exports/public/index.ts +++ b/packages/core/src/exports/public/index.ts @@ -137,7 +137,7 @@ export { isTerminal } from '../../experimental/tasks/interfaces.js'; export { InMemoryTaskMessageQueue, InMemoryTaskStore } from '../../experimental/tasks/stores/inMemory.js'; // Validator types and classes -export type { StandardSchemaWithJSON } from '../../util/standardSchema.js'; +export type { StandardSchemaV1, StandardSchemaWithJSON } from '../../util/standardSchema.js'; export { AjvJsonSchemaValidator } from '../../validators/ajvProvider.js'; export type { CfWorkerSchemaDraft } from '../../validators/cfWorkerProvider.js'; // fromJsonSchema is intentionally NOT exported here — the server and client packages diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index 2508cb51e..8e7abbb12 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -48,6 +48,8 @@ import type { ZodLikeRequestSchema } from '../util/compatSchema.js'; import { extractMethodLiteral, isZodLikeSchema } from '../util/compatSchema.js'; import type { AnySchema, SchemaOutput } from '../util/schema.js'; import { parseSchema } from '../util/schema.js'; +import type { StandardSchemaV1 } from '../util/standardSchema.js'; +import { parseStandardSchema } from '../util/standardSchema.js'; import type { TaskContext, TaskManagerHost, TaskManagerOptions, TaskRequestOptions } from './taskManager.js'; import { NullTaskManager, TaskManager } from './taskManager.js'; import type { Transport, TransportSendOptions } from './transport.js'; @@ -1033,9 +1035,14 @@ export abstract class Protocol { * method. Replaces any previous handler for the same method. * * Call forms: - * - **Spec method** — `setRequestHandler('tools/call', (request, ctx) => …)`. + * - **Spec method, two args** — `setRequestHandler('tools/call', (request, ctx) => …)`. * The full `RequestTypeMap[M]` request object is validated by the SDK and passed to the * handler. This is the form `Client`/`Server` use and override. + * - **Three args** — `setRequestHandler('vendor/custom', paramsSchema, (params, ctx) => …)`. + * Any method string; the supplied schema validates incoming `params`. Absent or undefined + * `params` are normalized to `{}` (after stripping `_meta`) before validation, so for + * no-params methods use `z.object({})`. `paramsSchema` may be any Standard Schema (Zod, + * Valibot, ArkType, etc.). * - **Zod schema** — `setRequestHandler(RequestZodSchema, (request, ctx) => …)`. The method * name is read from the schema's `method` literal; the handler receives the parsed request. */ @@ -1043,22 +1050,62 @@ export abstract class Protocol { method: M, handler: (request: RequestTypeMap[M], ctx: ContextT) => Result | Promise ): void; - /** For spec methods the method-string form is more concise; this overload is the supported call form for non-spec methods or when you want full-envelope validation. */ + setRequestHandler

( + method: string, + paramsSchema: P, + handler: (params: StandardSchemaV1.InferOutput

, ctx: ContextT) => Result | Promise + ): void; + /** For spec methods the method-string form is more concise; this overload is a supported call form for non-spec methods (alongside the three-arg `(method, paramsSchema, handler)` form) or when you want full-envelope validation. */ setRequestHandler( requestSchema: T, handler: (request: ReturnType, ctx: ContextT) => Result | Promise ): void; - setRequestHandler(method: string | ZodLikeRequestSchema, handler: (request: Request, ctx: ContextT) => Result | Promise): void { + setRequestHandler( + method: string | ZodLikeRequestSchema, + schemaOrHandler: StandardSchemaV1 | ((request: Request, ctx: ContextT) => Result | Promise), + maybeHandler?: (params: unknown, ctx: ContextT) => unknown + ): void { if (isZodLikeSchema(method)) { const requestSchema = method; const methodStr = extractMethodLiteral(requestSchema); this.assertRequestHandlerCapability(methodStr); this._requestHandlers.set(methodStr, (request, ctx) => - Promise.resolve((handler as (req: unknown, ctx: ContextT) => Result | Promise)(requestSchema.parse(request), ctx)) + Promise.resolve( + (schemaOrHandler as (req: unknown, ctx: ContextT) => Result | Promise)(requestSchema.parse(request), ctx) + ) ); return; } - this._setRequestHandlerByMethod(method, handler); + if (maybeHandler === undefined) { + return this._setRequestHandlerByMethod( + method, + schemaOrHandler as (request: Request, ctx: ContextT) => Result | Promise + ); + } + + this._setRequestHandlerByMethod(method, this._wrapParamsSchemaHandler(method, schemaOrHandler as StandardSchemaV1, maybeHandler)); + } + + /** + * Builds a request handler from a `paramsSchema` + params-only user handler. Strips `_meta`, + * validates `params` against the schema, and invokes the user handler with the parsed params. + * Shared by {@linkcode setRequestHandler}'s 3-arg dispatch and `Client`/`Server` overrides + * so that per-method wrapping can be applied uniformly to the normalized handler. + */ + protected _wrapParamsSchemaHandler( + method: string, + paramsSchema: StandardSchemaV1, + userHandler: (params: unknown, ctx: ContextT) => unknown + ): (request: Request, ctx: ContextT) => Promise { + return async (request, ctx) => { + const { _meta, ...userParams } = (request.params ?? {}) as Record; + void _meta; + const parsed = await parseStandardSchema(paramsSchema, userParams); + if (!parsed.success) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid params for ${method}: ${parsed.error.message}`); + } + return userHandler(parsed.data, ctx) as Result; + }; } /** @@ -1094,31 +1141,55 @@ export abstract class Protocol { * Registers a handler to invoke when this protocol object receives a notification with the * given method. Replaces any previous handler for the same method. * - * Mirrors {@linkcode setRequestHandler}: a spec-method form (handler receives the full - * notification object) and a Zod-schema form (method read from the schema's `method` literal). + * Mirrors {@linkcode setRequestHandler}: a two-arg spec-method form (handler receives the full + * notification object), a three-arg form with a `paramsSchema` (handler receives validated + * `params`), and a Zod-schema form (method read from the schema's `method` literal). */ setNotificationHandler( method: M, handler: (notification: NotificationTypeMap[M]) => void | Promise ): void; - /** For spec methods the method-string form is more concise; this overload is the supported call form for non-spec methods or when you want full-envelope validation. */ + setNotificationHandler

( + method: string, + paramsSchema: P, + handler: (params: StandardSchemaV1.InferOutput

) => void | Promise + ): void; + /** For spec methods the method-string form is more concise; this overload is a supported call form for non-spec methods (alongside the three-arg `(method, paramsSchema, handler)` form) or when you want full-envelope validation. */ setNotificationHandler( notificationSchema: T, handler: (notification: ReturnType) => void | Promise ): void; - setNotificationHandler(method: string | ZodLikeRequestSchema, handler: (notification: Notification) => void | Promise): void { + setNotificationHandler( + method: string | ZodLikeRequestSchema, + schemaOrHandler: StandardSchemaV1 | ((notification: Notification) => void | Promise), + maybeHandler?: (params: unknown) => void | Promise + ): void { if (isZodLikeSchema(method)) { const notificationSchema = method; const methodStr = extractMethodLiteral(notificationSchema); - this._notificationHandlers.set(methodStr, n => - Promise.resolve((handler as (n: unknown) => void | Promise)(notificationSchema.parse(n))) - ); + const handler = schemaOrHandler as (notification: unknown) => void | Promise; + this._notificationHandlers.set(methodStr, n => Promise.resolve(handler(notificationSchema.parse(n)))); + return; + } + if (maybeHandler === undefined) { + const handler = schemaOrHandler as (notification: Notification) => void | Promise; + const schema = getNotificationSchema(method as NotificationMethod); + this._notificationHandlers.set(method, notification => { + const parsed = schema ? schema.parse(notification) : notification; + return Promise.resolve(handler(parsed)); + }); return; } - const schema = getNotificationSchema(method as NotificationMethod); - this._notificationHandlers.set(method, notification => { - const parsed = schema ? schema.parse(notification) : notification; - return Promise.resolve(handler(parsed)); + + const paramsSchema = schemaOrHandler as StandardSchemaV1; + this._notificationHandlers.set(method, async notification => { + const { _meta, ...userParams } = (notification.params ?? {}) as Record; + void _meta; + const parsed = await parseStandardSchema(paramsSchema, userParams); + if (!parsed.success) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid params for ${method}: ${parsed.error.message}`); + } + return maybeHandler(parsed.data); }); } diff --git a/packages/core/src/util/standardSchema.ts b/packages/core/src/util/standardSchema.ts index 9817dc39a..4af0dbff6 100644 --- a/packages/core/src/util/standardSchema.ts +++ b/packages/core/src/util/standardSchema.ts @@ -169,15 +169,34 @@ function formatIssue(issue: StandardSchemaV1.Issue): string { return `${path}: ${issue.message}`; } -export async function validateStandardSchema( +export async function validateStandardSchema( schema: T, data: unknown -): Promise>> { +): Promise>> { const result = await schema['~standard'].validate(data); if (result.issues && result.issues.length > 0) { return { success: false, error: result.issues.map(i => formatIssue(i)).join(', ') }; } - return { success: true, data: (result as StandardSchemaV1.SuccessResult).value as StandardSchemaWithJSON.InferOutput }; + return { success: true, data: (result as StandardSchemaV1.SuccessResult).value as StandardSchemaV1.InferOutput }; +} + +/** + * Parses data against any Standard Schema. Async because Standard Schema's `validate` may return + * a Promise. The error is wrapped as an `Error` whose `.message` is the formatted issues string + * from {@linkcode validateStandardSchema}, so callers can interpolate it directly. + * + * Use this for user-supplied schemas (e.g. the 3-arg `setRequestHandler(method, paramsSchema, h)` + * form). For internal SDK Zod schemas, prefer the synchronous `parseSchema` in `./schema.js`. + */ +export async function parseStandardSchema( + schema: T, + data: unknown +): Promise<{ success: true; data: StandardSchemaV1.InferOutput } | { success: false; error: Error }> { + const result = await validateStandardSchema(schema, data); + if (result.success) { + return result; + } + return { success: false, error: new Error(result.error) }; } // Prompt argument extraction diff --git a/packages/core/test/shared/threeArgHandlers.test.ts b/packages/core/test/shared/threeArgHandlers.test.ts new file mode 100644 index 000000000..dfb8a2015 --- /dev/null +++ b/packages/core/test/shared/threeArgHandlers.test.ts @@ -0,0 +1,102 @@ +import { describe, expect, it } from 'vitest'; +import { z } from 'zod'; + +import type { BaseContext } from '../../src/shared/protocol.js'; +import { Protocol } from '../../src/shared/protocol.js'; +import type { StandardSchemaV1 } from '../../src/util/standardSchema.js'; +import { InMemoryTransport } from '../../src/util/inMemory.js'; + +class TestProtocol extends Protocol { + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} + protected buildContext(ctx: BaseContext): BaseContext { + return ctx; + } +} + +async function makePair() { + const [t1, t2] = InMemoryTransport.createLinkedPair(); + const a = new TestProtocol(); + const b = new TestProtocol(); + await a.connect(t1); + await b.connect(t2); + return { a, b }; +} + +describe('setRequestHandler — three-arg paramsSchema form', () => { + it('round-trips a custom request with validated params', async () => { + const { a, b } = await makePair(); + b.setRequestHandler('acme/echo', z.object({ msg: z.string() }), params => ({ reply: params.msg.toUpperCase() })); + const result = await a.request({ method: 'acme/echo', params: { msg: 'hi' } }, z.object({ reply: z.string() })); + expect(result).toEqual({ reply: 'HI' }); + }); + + it('rejects invalid params with InvalidParams', async () => { + const { a, b } = await makePair(); + b.setRequestHandler('acme/echo', z.object({ msg: z.string() }), p => ({ reply: p.msg })); + await expect(a.request({ method: 'acme/echo', params: { msg: 42 } }, z.object({ reply: z.string() }))).rejects.toThrow( + /Invalid params for acme\/echo/ + ); + }); + + it('normalizes absent params to {}', async () => { + const { a, b } = await makePair(); + let seen: unknown; + b.setRequestHandler('acme/noop', z.object({}).strict(), p => { + seen = p; + return {}; + }); + await a.request({ method: 'acme/noop' }, z.object({})); + expect(seen).toEqual({}); + }); + + it('strips _meta before validating against paramsSchema', async () => { + const { a, b } = await makePair(); + let seen: unknown; + b.setRequestHandler('acme/noop', z.object({}).strict(), p => { + seen = p; + return {}; + }); + await a.request({ method: 'acme/noop', params: { _meta: { trace: 'x' } } }, z.object({})); + expect(seen).toEqual({}); + }); +}); + +describe('setNotificationHandler — three-arg paramsSchema form', () => { + it('receives a custom notification', async () => { + const { a, b } = await makePair(); + const received: unknown[] = []; + b.setNotificationHandler('acme/tick', z.object({ n: z.number() }), p => { + received.push(p); + }); + await a.notification({ method: 'acme/tick', params: { n: 1 } }); + await a.notification({ method: 'acme/tick', params: { n: 2 } }); + await new Promise(r => setTimeout(r, 0)); + expect(received).toEqual([{ n: 1 }, { n: 2 }]); + }); +}); + +describe('non-Zod StandardSchemaV1', () => { + function makeStandardSchema(check: (v: unknown) => v is T): StandardSchemaV1 { + return { + '~standard': { + version: 1 as const, + vendor: 'test', + types: undefined as unknown as { input: T; output: T }, + validate: (v: unknown) => (check(v) ? { value: v } : { issues: [{ message: 'invalid', path: [] }] }) + } + }; + } + + it('accepts a hand-rolled StandardSchemaV1 in 3-arg setRequestHandler', async () => { + const { a, b } = await makePair(); + type Params = { n: number }; + const Params = makeStandardSchema((v): v is Params => typeof (v as Params)?.n === 'number'); + b.setRequestHandler('acme/double', Params, (p: Params) => ({ doubled: p.n * 2 })); + const r = await a.request({ method: 'acme/double', params: { n: 21 } }, z.object({ doubled: z.number() })); + expect(r.doubled).toBe(42); + }); +}); diff --git a/packages/server/src/server/server.ts b/packages/server/src/server/server.ts index a4690b7f4..fd7b93e88 100644 --- a/packages/server/src/server/server.ts +++ b/packages/server/src/server/server.ts @@ -31,6 +31,7 @@ import type { ServerCapabilities, ServerContext, ServerResult, + StandardSchemaV1, TaskManagerOptions, ToolResultContent, ToolUseContent, @@ -231,22 +232,34 @@ export class Server extends Protocol { method: M, handler: (request: RequestTypeMap[M], ctx: ServerContext) => ResultTypeMap[M] | Promise ): void; - /** For spec methods the method-string form is more concise; this overload is the supported call form for non-spec methods or when you want full-envelope validation. */ + public override setRequestHandler

( + method: string, + paramsSchema: P, + handler: (params: StandardSchemaV1.InferOutput

, ctx: ServerContext) => Result | Promise + ): void; + /** For spec methods the method-string form is more concise; this overload is a supported call form for non-spec methods (alongside the three-arg `(method, paramsSchema, handler)` form) or when you want full-envelope validation. */ public override setRequestHandler( requestSchema: T, handler: (request: ReturnType, ctx: ServerContext) => Result | Promise ): void; - public override setRequestHandler(methodOrSchema: string | ZodLikeRequestSchema, schemaHandler: unknown): void { + public override setRequestHandler( + methodOrSchema: string | ZodLikeRequestSchema, + schemaOrHandler: unknown, + maybeHandler?: (params: unknown, ctx: ServerContext) => unknown + ): void { let method: string; let handler: (request: Request, ctx: ServerContext) => ServerResult | Promise; if (isZodLikeSchema(methodOrSchema)) { const schema = methodOrSchema; - const userHandler = schemaHandler as (request: unknown, ctx: ServerContext) => Result | Promise; + const userHandler = schemaOrHandler as (request: unknown, ctx: ServerContext) => Result | Promise; method = extractMethodLiteral(schema); handler = (req, ctx) => userHandler(schema.parse(req), ctx); + } else if (maybeHandler === undefined) { + method = methodOrSchema; + handler = schemaOrHandler as (request: Request, ctx: ServerContext) => ServerResult | Promise; } else { method = methodOrSchema; - handler = schemaHandler as (request: Request, ctx: ServerContext) => ServerResult | Promise; + handler = this._wrapParamsSchemaHandler(method, schemaOrHandler as StandardSchemaV1, maybeHandler); } if (method === 'tools/call') { const wrappedHandler = async (request: Request, ctx: ServerContext): Promise => {