diff --git a/.changeset/custom-methods-minimal.md b/.changeset/custom-methods-minimal.md new file mode 100644 index 000000000..f722f4504 --- /dev/null +++ b/.changeset/custom-methods-minimal.md @@ -0,0 +1,9 @@ +--- +'@modelcontextprotocol/core': minor +'@modelcontextprotocol/client': minor +'@modelcontextprotocol/server': minor +--- + +Add custom (non-spec) method support: a 3-arg `setRequestHandler(method, schemas, handler)` / `setNotificationHandler(method, schemas, handler)` form for vendor-prefixed methods, and a `request(req, resultSchema)` overload (also on `ctx.mcpReq.send`) for typed custom-method results. Spec-method calls are unchanged. + +Response result-schema validation failure now rejects with `SdkError(InvalidResult)` instead of a raw `ZodError`. Adds `SdkErrorCode.InvalidResult`. diff --git a/docs/migration-SKILL.md b/docs/migration-SKILL.md index 8dccc81f4..7f27356ee 100644 --- a/docs/migration-SKILL.md +++ b/docs/migration-SKILL.md @@ -120,6 +120,7 @@ Two error classes now exist: | 403 after upscoping | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpForbidden` | | Unexpected content type | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpUnexpectedContent` | | Session termination failed | `StreamableHTTPError` | `SdkError` with `SdkErrorCode.ClientHttpFailedToTerminateSession` | +| Response result fails schema | `ZodError` (raw) | `SdkError` with `SdkErrorCode.InvalidResult` | New `SdkErrorCode` enum values: @@ -130,6 +131,7 @@ New `SdkErrorCode` enum values: - `SdkErrorCode.RequestTimeout` = `'REQUEST_TIMEOUT'` - `SdkErrorCode.ConnectionClosed` = `'CONNECTION_CLOSED'` - `SdkErrorCode.SendFailed` = `'SEND_FAILED'` +- `SdkErrorCode.InvalidResult` = `'INVALID_RESULT'` - `SdkErrorCode.ClientHttpNotImplemented` = `'CLIENT_HTTP_NOT_IMPLEMENTED'` - `SdkErrorCode.ClientHttpAuthentication` = `'CLIENT_HTTP_AUTHENTICATION'` - `SdkErrorCode.ClientHttpForbidden` = `'CLIENT_HTTP_FORBIDDEN'` @@ -351,6 +353,28 @@ server.setRequestHandler('initialize', async (request) => { ... }); server.setNotificationHandler('notifications/message', (notification) => { ... }); ``` +For custom (non-spec) methods, use the 3-arg form `(method, schemas, handler)`: + +```typescript +// v1: Zod schema with method literal +server.setRequestHandler(z.object({ method: z.literal('acme/search'), params: P }), async req => { ... }); + +// v2: method string + schemas object; handler receives parsed params +server.setRequestHandler('acme/search', { params: P, result: R }, async (params, ctx) => { ... }); +client.setNotificationHandler('acme/progress', { params: P }, (params, notification) => { ... }); +``` + +The 3-arg notification handler receives the raw notification as its second argument, so `_meta` is recoverable via `notification.params?._meta`. + +To send a custom-method request, pass a result schema as the second argument to `request()` (and `ctx.mcpReq.send()`): + +```typescript +// v1 +await client.request({ method: 'acme/search', params }, ResultSchema); +// v2 (unchanged; now any Standard Schema, not Zod-only) +await client.request({ method: 'acme/search', params }, ResultSchema); +``` + Schema to method string mapping: | v1 Schema | v2 Method String | @@ -406,9 +430,9 @@ Request/notification params remain fully typed. Remove unused schema imports aft | `ctx.mcpReq.elicitInput(params, options?)` | Elicit user input (form or URL) | `server.elicitInput(...)` from within handler | | `ctx.mcpReq.requestSampling(params, options?)` | Request LLM sampling from client | `server.createMessage(...)` from within handler | -## 11. Schema parameter removed from `request()`, `send()`, and `callTool()` +## 11. Schema parameter removed from `request()`, `send()`, and `callTool()` (spec methods) -`Protocol.request()`, `BaseContext.mcpReq.send()`, and `Client.callTool()` no longer take a Zod result schema argument. The SDK resolves the schema internally from the method name. +For **spec** methods, `Protocol.request()`, `BaseContext.mcpReq.send()`, and `Client.callTool()` no longer require a Zod result schema argument. The SDK resolves the schema internally from the method name. ```typescript // v1: schema required @@ -432,6 +456,8 @@ const tool = await client.callTool({ name: 'my-tool', arguments: {} }); | `client.callTool(params, CompatibilityCallToolResultSchema)` | `client.callTool(params)` | | `client.callTool(params, schema, options)` | `client.callTool(params, options)` | +For **custom (non-spec)** methods, keep the result-schema argument — see §9. Only apply the rewrites above when `req.method` is a spec method. + Remove unused schema imports: `CallToolResultSchema`, `CompatibilityCallToolResultSchema`, `ElicitResultSchema`, `CreateMessageResultSchema`, etc., when they were only used in `request()`/`send()`/`callTool()` calls. If `CallToolResultSchema` was used for **runtime validation** (not just as a `request()` argument), replace with the `isCallToolResult` type guard: diff --git a/docs/migration.md b/docs/migration.md index a5b9a3b0d..02f673068 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -366,6 +366,48 @@ server.setNotificationHandler('notifications/message', notification => { The request and notification parameters remain fully typed via `RequestTypeMap` and `NotificationTypeMap`. You no longer need to import the individual `*RequestSchema` or `*NotificationSchema` constants for handler registration. +#### Custom (non-spec) methods + +For vendor-prefixed methods (anything not in the MCP spec), use the 3-arg form: pass the method string, a `{ params, result? }` schemas object, and the handler. Any [Standard Schema](https://standardschema.dev) library works (Zod, Valibot, ArkType). + +**Before (v1):** + +```typescript +const AcmeSearch = z.object({ + method: z.literal('acme/search'), + params: z.object({ query: z.string(), limit: z.number().int() }) +}); +server.setRequestHandler(AcmeSearch, async request => { + return { items: [/* ... */] }; +}); +``` + +**After (v2):** + +```typescript +const SearchParams = z.object({ query: z.string(), limit: z.number().int() }); +const SearchResult = z.object({ items: z.array(z.string()) }); + +server.setRequestHandler('acme/search', { params: SearchParams, result: SearchResult }, async (params, ctx) => { + return { items: [/* ... */] }; +}); +``` + +The handler receives the parsed `params` directly (not the full request envelope). `_meta` is stripped before validation and is available as `ctx.mcpReq._meta`. Supplying `result` types the handler's return value; omit it to return any `Result`. + +For `setNotificationHandler`, the 3-arg handler is `(params, notification) => void`. The raw notification is the second argument, so `_meta` is recoverable via `notification.params?._meta`. + +#### Sending custom-method requests + +`request()` and `ctx.mcpReq.send()` accept a result schema as the second argument; for custom methods this is required: + +```typescript +const result = await client.request({ method: 'acme/search', params: { query: 'mcp', limit: 3 } }, SearchResult); +result.items; // string[] +``` + +For spec methods the 1-arg form still works and the result type is inferred from the method name. + Common method string replacements: | Schema (v1) | Method string (v2) | @@ -384,10 +426,10 @@ Common method string replacements: | `ResourceListChangedNotificationSchema` | `'notifications/resources/list_changed'` | | `PromptListChangedNotificationSchema` | `'notifications/prompts/list_changed'` | -### `Protocol.request()`, `ctx.mcpReq.send()`, and `Client.callTool()` no longer take a schema parameter +### `Protocol.request()`, `ctx.mcpReq.send()`, and `Client.callTool()` no longer require a schema parameter for spec methods -The public `Protocol.request()`, `BaseContext.mcpReq.send()`, and `Client.callTool()` methods no longer accept a Zod result schema argument. The SDK now resolves the correct result schema internally based on the method name. This means you no longer need to import result schemas -like `CallToolResultSchema` or `ElicitResultSchema` when making requests. +For **spec** methods, the public `Protocol.request()`, `BaseContext.mcpReq.send()`, and `Client.callTool()` methods no longer require a Zod result schema argument. The SDK now resolves the correct result schema internally based on the method name. This means you no longer need to import result schemas +like `CallToolResultSchema` or `ElicitResultSchema` when making spec-method requests. **`client.request()` — Before (v1):** @@ -444,6 +486,8 @@ const result = await client.callTool({ name: 'my-tool', arguments: {} }); The return type is now inferred from the method name via `ResultTypeMap`. For example, `client.request({ method: 'tools/call', ... })` returns `Promise`. +For **custom (non-spec)** methods, keep the result-schema argument — see [Sending custom-method requests](#sending-custom-method-requests). Only drop the schema when calling a spec method. + If you were using `CallToolResultSchema` for **runtime validation** (not just in `request()`/`callTool()` calls), use the new `isCallToolResult` type guard instead: ```typescript @@ -658,6 +702,7 @@ The new `SdkErrorCode` enum contains string-valued codes for local SDK errors: | `SdkErrorCode.RequestTimeout` | Request timed out waiting for response | | `SdkErrorCode.ConnectionClosed` | Connection was closed | | `SdkErrorCode.SendFailed` | Failed to send message | +| `SdkErrorCode.InvalidResult` | Response result failed local schema validation | | `SdkErrorCode.ClientHttpNotImplemented` | HTTP POST request failed | | `SdkErrorCode.ClientHttpAuthentication` | Server returned 401 after re-authentication | | `SdkErrorCode.ClientHttpForbidden` | Server returned 403 after trying upscoping | diff --git a/examples/client/src/customMethodExample.ts b/examples/client/src/customMethodExample.ts new file mode 100644 index 000000000..a289af0a4 --- /dev/null +++ b/examples/client/src/customMethodExample.ts @@ -0,0 +1,25 @@ +/** + * Custom (non-spec) method example: a client that sends `acme/search` and + * listens for `acme/searchProgress` notifications. + * + * Build `examples/server` first; this client spawns the server via stdio. + */ +import { Client } from '@modelcontextprotocol/client'; +import { StdioClientTransport } from '@modelcontextprotocol/client/stdio'; +import { z } from 'zod/v4'; + +const SearchResult = z.object({ items: z.array(z.string()) }); +const SearchProgressParams = z.object({ stage: z.string(), pct: z.number() }); + +const client = new Client({ name: 'acme-search-client', version: '0.0.0' }); + +client.setNotificationHandler('acme/searchProgress', { params: SearchProgressParams }, params => { + console.log(`[progress] ${params.stage} ${Math.round(params.pct * 100)}%`); +}); + +await client.connect(new StdioClientTransport({ command: 'node', args: ['../server/dist/customMethodExample.js'] })); + +const result = await client.request({ method: 'acme/search', params: { query: 'mcp', limit: 3 } }, SearchResult); +console.log('items:', result.items); + +await client.close(); diff --git a/examples/server/src/customMethodExample.ts b/examples/server/src/customMethodExample.ts new file mode 100644 index 000000000..6968a26e6 --- /dev/null +++ b/examples/server/src/customMethodExample.ts @@ -0,0 +1,23 @@ +/** + * Custom (non-spec) method example: a server that handles a vendor-prefixed + * `acme/search` request and emits `acme/searchProgress` notifications. + * + * Spawned via stdio by `examples/client/src/customMethodExample.ts`; do not run standalone. + */ +import { McpServer } from '@modelcontextprotocol/server'; +import { StdioServerTransport } from '@modelcontextprotocol/server/stdio'; +import { z } from 'zod/v4'; + +const SearchParams = z.object({ query: z.string(), limit: z.number().int().default(10) }); +const SearchResult = z.object({ items: z.array(z.string()) }); + +const mcp = new McpServer({ name: 'acme-search', version: '0.0.0' }); + +mcp.server.setRequestHandler('acme/search', { params: SearchParams, result: SearchResult }, async (params, ctx) => { + await ctx.mcpReq.notify({ method: 'acme/searchProgress', params: { stage: 'start', pct: 0 } }); + const items = Array.from({ length: params.limit }, (_, i) => `${params.query}-${i}`); + await ctx.mcpReq.notify({ method: 'acme/searchProgress', params: { stage: 'done', pct: 1 } }); + return { items }; +}); + +await mcp.connect(new StdioServerTransport()); diff --git a/packages/client/src/client/client.ts b/packages/client/src/client/client.ts index 4a279e532..5fa2e14d9 100644 --- a/packages/client/src/client/client.ts +++ b/packages/client/src/client/client.ts @@ -570,7 +570,7 @@ export class Client extends Protocol { return this._instructions; } - protected assertCapabilityForMethod(method: RequestMethod): void { + protected assertCapabilityForMethod(method: RequestMethod | string): void { switch (method as ClientRequest['method']) { case 'logging/setLevel': { if (!this._serverCapabilities?.logging) { @@ -633,7 +633,7 @@ export class Client extends Protocol { } } - protected assertNotificationCapability(method: NotificationMethod): void { + protected assertNotificationCapability(method: NotificationMethod | string): void { switch (method as ClientNotification['method']) { case 'notifications/roots/list_changed': { if (!this._capabilities.roots?.listChanged) { diff --git a/packages/core/src/errors/sdkErrors.ts b/packages/core/src/errors/sdkErrors.ts index f53c07ccf..8d5e34c14 100644 --- a/packages/core/src/errors/sdkErrors.ts +++ b/packages/core/src/errors/sdkErrors.ts @@ -26,6 +26,8 @@ export enum SdkErrorCode { ConnectionClosed = 'CONNECTION_CLOSED', /** Failed to send message */ SendFailed = 'SEND_FAILED', + /** Response result failed local schema validation */ + InvalidResult = 'INVALID_RESULT', // Transport errors ClientHttpNotImplemented = 'CLIENT_HTTP_NOT_IMPLEMENTED', diff --git a/packages/core/src/exports/public/index.ts b/packages/core/src/exports/public/index.ts index 942bd2368..51a3f5618 100644 --- a/packages/core/src/exports/public/index.ts +++ b/packages/core/src/exports/public/index.ts @@ -45,6 +45,7 @@ export type { NotificationOptions, ProgressCallback, ProtocolOptions, + RequestHandlerSchemas, RequestOptions, ServerContext } from '../../shared/protocol.js'; @@ -137,7 +138,7 @@ export { isTerminal } from '../../experimental/tasks/interfaces.js'; export { InMemoryTaskMessageQueue, InMemoryTaskStore } from '../../experimental/tasks/stores/inMemory.js'; // Validator types and classes -export type { StandardSchemaWithJSON } from '../../util/standardSchema.js'; +export type { StandardSchemaV1, StandardSchemaWithJSON } from '../../util/standardSchema.js'; export { AjvJsonSchemaValidator } from '../../validators/ajvProvider.js'; export type { CfWorkerSchemaDraft } from '../../validators/cfWorkerProvider.js'; // fromJsonSchema is intentionally NOT exported here — the server and client packages diff --git a/packages/core/src/shared/protocol.examples.ts b/packages/core/src/shared/protocol.examples.ts new file mode 100644 index 000000000..ba3a701a2 --- /dev/null +++ b/packages/core/src/shared/protocol.examples.ts @@ -0,0 +1,29 @@ +/** + * Type-checked examples for `protocol.ts`. + * + * These examples are synced into JSDoc comments via the sync-snippets script. + * Each function's region markers define the code snippet that appears in the docs. + * + * @module + */ + +import * as z from 'zod/v4'; + +import type { BaseContext, Protocol } from './protocol.js'; + +/** + * Example: registering a handler for a custom (non-spec) request method. + */ +function Protocol_setRequestHandler_customMethod(protocol: Protocol) { + //#region Protocol_setRequestHandler_customMethod + const SearchParams = z.object({ query: z.string(), limit: z.number().optional() }); + const SearchResult = z.object({ hits: z.array(z.string()) }); + + protocol.setRequestHandler('acme/search', { params: SearchParams, result: SearchResult }, async (params, _ctx) => { + return { hits: [`result for ${params.query}`] }; + }); + //#endregion Protocol_setRequestHandler_customMethod + void protocol; +} + +void Protocol_setRequestHandler_customMethod; diff --git a/packages/core/src/shared/protocol.ts b/packages/core/src/shared/protocol.ts index 799518832..361bd6fc7 100644 --- a/packages/core/src/shared/protocol.ts +++ b/packages/core/src/shared/protocol.ts @@ -44,8 +44,8 @@ import { ProtocolErrorCode, SUPPORTED_PROTOCOL_VERSIONS } from '../types/index.js'; -import type { AnySchema, SchemaOutput } from '../util/schema.js'; -import { parseSchema } from '../util/schema.js'; +import type { StandardSchemaV1 } from '../util/standardSchema.js'; +import { isStandardSchema, validateStandardSchema } from '../util/standardSchema.js'; import type { TaskContext, TaskManagerHost, TaskManagerOptions, TaskRequestOptions } from './taskManager.js'; import { NullTaskManager, TaskManager } from './taskManager.js'; import type { Transport, TransportSendOptions } from './transport.js'; @@ -199,11 +199,21 @@ export type BaseContext = { * Sends a request that relates to the current request being handled. * * This is used by certain transports to correctly associate related messages. + * + * For spec methods the result type is inferred from the method name. + * For custom (non-spec) methods, pass a result schema as the second argument. */ - send: ( - request: { method: M; params?: Record }, - options?: TaskRequestOptions - ) => Promise; + send: { + ( + request: { method: M; params?: Record }, + options?: TaskRequestOptions + ): Promise; + ( + request: Request, + resultSchema: T, + options?: TaskRequestOptions + ): Promise>; + }; /** * Sends a notification that relates to the current request being handled. @@ -294,6 +304,9 @@ type TimeoutInfo = { /** * Implements MCP protocol framing on top of a pluggable transport, including * features like request/response linking, notifications, and progress. + * + * `Protocol` is abstract; `Client` and `Server` are the concrete role-specific + * implementations most code should use. */ export abstract class Protocol { private _transport?: Transport; @@ -550,7 +563,7 @@ export abstract class Protocol { sessionId: capturedTransport?.sessionId, sendNotification: (notification: Notification, options?: NotificationOptions) => this.notification(notification, { ...options, relatedRequestId: request.id }), - sendRequest: (r: Request, resultSchema: U, options?: RequestOptions) => + sendRequest: (r: Request, resultSchema: U, options?: RequestOptions) => this._requestWithSchema(r, resultSchema, { ...options, relatedRequestId: request.id }) }; @@ -596,10 +609,22 @@ export abstract class Protocol { method: request.method, _meta: request.params?._meta, signal: abortController.signal, - send: (r: { method: M; params?: Record }, options?: TaskRequestOptions) => { + // BaseContext.mcpReq.send is declared with two overloads (spec-method-keyed and explicit-schema). Arrow + // literals can't carry overload signatures, so the inferred single-signature type isn't assignable to + // that overloaded property type. The cast is sound: this impl dispatches both overload paths via the + // isStandardSchema guard, and sendRequest validates the result against the resolved schema either way. + send: ((r: Request, schemaOrOptions?: StandardSchemaV1 | TaskRequestOptions, maybeOptions?: TaskRequestOptions) => { + if (isStandardSchema(schemaOrOptions)) { + return sendRequest(r, schemaOrOptions, maybeOptions); + } const resultSchema = getResultSchema(r.method); - return sendRequest(r as Request, resultSchema, options) as Promise; - }, + if (!resultSchema) { + throw new TypeError( + `'${r.method}' is not a spec method; pass a result schema as the second argument to ctx.mcpReq.send().` + ); + } + return sendRequest(r, resultSchema, schemaOrOptions); + }) as BaseContext['mcpReq']['send'], notify: sendNotification }, http: extra?.authInfo ? { authInfo: extra.authInfo } : undefined, @@ -740,14 +765,14 @@ export abstract class Protocol { * * This should be implemented by subclasses. */ - protected abstract assertCapabilityForMethod(method: RequestMethod): void; + protected abstract assertCapabilityForMethod(method: RequestMethod | string): void; /** * A method to check if a notification is supported by the local side, for the given method to be sent. * * This should be implemented by subclasses. */ - protected abstract assertNotificationCapability(method: NotificationMethod): void; + protected abstract assertNotificationCapability(method: NotificationMethod | string): void; /** * A method to check if a request handler is supported by the local side, for the given method to be handled. @@ -773,17 +798,33 @@ export abstract class Protocol { protected abstract assertTaskHandlerCapability(method: string): void; /** - * Sends a request and waits for a response, resolving the result schema - * automatically from the method name. + * Sends a request and waits for a response. + * + * For spec methods the result schema is resolved automatically from the method name + * and the return type is method-keyed. For custom (non-spec) methods, pass a + * `resultSchema` as the second argument; the response is validated against it and + * the return type is inferred from the schema. * * Do not use this method to emit notifications! Use {@linkcode Protocol.notification | notification()} instead. */ request( request: { method: M; params?: Record }, options?: RequestOptions - ): Promise { + ): Promise; + request( + request: Request, + resultSchema: T, + options?: RequestOptions + ): Promise>; + request(request: Request, schemaOrOptions?: StandardSchemaV1 | RequestOptions, maybeOptions?: RequestOptions): Promise { + if (isStandardSchema(schemaOrOptions)) { + return this._requestWithSchema(request, schemaOrOptions, maybeOptions); + } const resultSchema = getResultSchema(request.method); - return this._requestWithSchema(request as Request, resultSchema, options) as Promise; + if (!resultSchema) { + throw new TypeError(`'${request.method}' is not a spec method; pass a result schema as the second argument to request().`); + } + return this._requestWithSchema(request, resultSchema, schemaOrOptions); } /** @@ -792,18 +833,18 @@ export abstract class Protocol { * This is the internal implementation used by SDK methods that need to specify * a particular result schema (e.g., for compatibility or task-specific schemas). */ - protected _requestWithSchema( + protected _requestWithSchema( request: Request, resultSchema: T, options?: RequestOptions - ): Promise> { + ): Promise> { const { relatedRequestId, resumptionToken, onresumptiontoken } = options ?? {}; let onAbort: (() => void) | undefined; let cleanupMessageId: number | undefined; // Send the request - return new Promise>((resolve, reject) => { + return new Promise>((resolve, reject) => { const earlyReject = (error: unknown) => { reject(error); }; @@ -815,7 +856,7 @@ export abstract class Protocol { if (this._options?.enforceStrictCapabilities === true) { try { - this.assertCapabilityForMethod(request.method as RequestMethod); + this.assertCapabilityForMethod(request.method); } catch (error) { earlyReject(error); return; @@ -843,7 +884,12 @@ export abstract class Protocol { }; } + let responseReceived = false; + const cancel = (reason: unknown) => { + if (responseReceived) { + return; + } this._progressHandlers.delete(messageId); this._transport @@ -869,21 +915,19 @@ export abstract class Protocol { if (options?.signal?.aborted) { return; } + responseReceived = true; if (response instanceof Error) { return reject(response); } - try { - const parseResult = parseSchema(resultSchema, response.result); + validateStandardSchema(resultSchema, response.result).then(parseResult => { if (parseResult.success) { - resolve(parseResult.data as SchemaOutput); + resolve(parseResult.data); } else { - reject(parseResult.error); + reject(new SdkError(SdkErrorCode.InvalidResult, `Invalid result for ${request.method}: ${parseResult.error}`)); } - } catch (error) { - reject(error); - } + }, reject); }); onAbort = () => cancel(options?.signal?.reason); @@ -950,7 +994,7 @@ export abstract class Protocol { throw new SdkError(SdkErrorCode.NotConnected, 'Not connected'); } - this.assertNotificationCapability(notification.method as NotificationMethod); + this.assertNotificationCapability(notification.method); // Delegate task-related notification routing and JSONRPC building to TaskManager const taskResult = await this._taskManager.processOutboundNotification(notification, options); @@ -1004,27 +1048,73 @@ export abstract class Protocol { * Registers a handler to invoke when this protocol object receives a request with the given method. * * Note that this will replace any previous request handler for the same method. + * + * For spec methods, pass `(method, handler)`; the request is parsed with the spec + * schema and the handler receives the typed `Request`. For custom (non-spec) + * methods, pass `(method, schemas, handler)`; `params` are validated against + * `schemas.params` and the handler receives the parsed params object directly. + * Supplying `schemas.result` types the handler's return value. + * + * @example Custom request method + * ```ts source="./protocol.examples.ts#Protocol_setRequestHandler_customMethod" + * const SearchParams = z.object({ query: z.string(), limit: z.number().optional() }); + * const SearchResult = z.object({ hits: z.array(z.string()) }); + * + * protocol.setRequestHandler('acme/search', { params: SearchParams, result: SearchResult }, async (params, _ctx) => { + * return { hits: [`result for ${params.query}`] }; + * }); + * ``` */ setRequestHandler( method: M, handler: (request: RequestTypeMap[M], ctx: ContextT) => ResultTypeMap[M] | Promise + ): void; + setRequestHandler

( + method: string, + schemas: { params: P; result?: R }, + handler: (params: StandardSchemaV1.InferOutput

, ctx: ContextT) => InferHandlerResult | Promise> + ): void; + setRequestHandler( + method: string, + schemasOrHandler: RequestHandlerSchemas | ((request: unknown, ctx: ContextT) => Result | Promise), + maybeHandler?: (params: unknown, ctx: ContextT) => Result | Promise ): void { this.assertRequestHandlerCapability(method); - const schema = getRequestSchema(method); - const stored = (request: JSONRPCRequest, ctx: ContextT): Promise => { - const parsed = schema.parse(request) as RequestTypeMap[M]; - return Promise.resolve(handler(parsed, ctx)); - }; + let stored: (request: JSONRPCRequest, ctx: ContextT) => Promise; + + if (typeof schemasOrHandler === 'function') { + const schema = getRequestSchema(method); + if (!schema) { + throw new TypeError( + `'${method}' is not a spec request method; pass schemas as the second argument to setRequestHandler().` + ); + } + stored = (request, ctx) => Promise.resolve(schemasOrHandler(schema.parse(request), ctx)); + } else if (maybeHandler) { + stored = async (request, ctx) => { + const userParams = { ...request.params }; + delete userParams._meta; + const parsed = await validateStandardSchema(schemasOrHandler.params, userParams); + if (!parsed.success) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid params for ${method}: ${parsed.error}`); + } + return maybeHandler(parsed.data, ctx); + }; + } else { + throw new TypeError('setRequestHandler: handler is required'); + } + this._requestHandlers.set(method, this._wrapHandler(method, stored)); } /** * Hook for subclasses to wrap a registered request handler with role-specific * validation or behavior (e.g. `Server` validates `tools/call` results, `Client` - * validates `elicitation/create` mode and result). The default implementation is identity. + * validates `elicitation/create` mode and result). Runs for both the 2-arg and + * 3-arg registration paths. The default implementation is identity. * - * Subclasses overriding this hook avoid redeclaring `setRequestHandler` and its JSDoc. + * Subclasses overriding this hook avoid redeclaring `setRequestHandler`'s overload set. */ protected _wrapHandler( _method: string, @@ -1036,14 +1126,14 @@ export abstract class Protocol { /** * Removes the request handler for the given method. */ - removeRequestHandler(method: RequestMethod): void { + removeRequestHandler(method: RequestMethod | string): void { this._requestHandlers.delete(method); } /** * Asserts that a request handler has not already been set for the given method, in preparation for a new one being automatically installed. */ - assertCanSetRequestHandler(method: RequestMethod): void { + assertCanSetRequestHandler(method: RequestMethod | string): void { if (this._requestHandlers.has(method)) { throw new Error(`A request handler for ${method} already exists, which would be overridden`); } @@ -1053,27 +1143,77 @@ export abstract class Protocol { * Registers a handler to invoke when this protocol object receives a notification with the given method. * * Note that this will replace any previous notification handler for the same method. + * + * For spec methods, pass `(method, handler)`; the notification is parsed with the + * spec schema. For custom (non-spec) methods, pass `(method, schemas, handler)`; + * `params` are validated against `schemas.params` and the handler receives the + * parsed params object directly. The raw notification is passed as the second + * argument; `_meta` is recoverable via `notification.params?._meta`. */ setNotificationHandler( method: M, handler: (notification: NotificationTypeMap[M]) => void | Promise + ): void; + setNotificationHandler

( + method: string, + schemas: { params: P }, + handler: (params: StandardSchemaV1.InferOutput

, notification: Notification) => void | Promise + ): void; + setNotificationHandler( + method: string, + schemasOrHandler: { params: StandardSchemaV1 } | ((notification: unknown) => void | Promise), + maybeHandler?: (params: unknown, notification: Notification) => void | Promise ): void { - const schema = getNotificationSchema(method); + if (typeof schemasOrHandler === 'function') { + const schema = getNotificationSchema(method); + if (!schema) { + throw new TypeError( + `'${method}' is not a spec notification method; pass schemas as the second argument to setNotificationHandler().` + ); + } + this._notificationHandlers.set(method, notification => Promise.resolve(schemasOrHandler(schema.parse(notification)))); + return; + } - this._notificationHandlers.set(method, notification => { - const parsed = schema.parse(notification); - return Promise.resolve(handler(parsed)); + if (!maybeHandler) { + throw new TypeError('setNotificationHandler: handler is required'); + } + this._notificationHandlers.set(method, async notification => { + const userParams = { ...notification.params }; + delete userParams._meta; + const parsed = await validateStandardSchema(schemasOrHandler.params, userParams); + if (!parsed.success) { + throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid params for notification ${method}: ${parsed.error}`); + } + await maybeHandler(parsed.data, notification); }); } /** * Removes the notification handler for the given method. */ - removeNotificationHandler(method: NotificationMethod): void { + removeNotificationHandler(method: NotificationMethod | string): void { this._notificationHandlers.delete(method); } } +/** + * Schema bundle accepted by {@linkcode Protocol.setRequestHandler | setRequestHandler}'s 3-arg form. + * + * `params` is required and validates the inbound `request.params`. `result` is optional; + * when supplied it types the handler's return value (no runtime validation is performed + * on the result). + */ +export interface RequestHandlerSchemas< + P extends StandardSchemaV1 = StandardSchemaV1, + R extends StandardSchemaV1 | undefined = StandardSchemaV1 | undefined +> { + params: P; + result?: R; +} + +type InferHandlerResult = R extends StandardSchemaV1 ? StandardSchemaV1.InferOutput : Result; + function isPlainObject(value: unknown): value is Record { return value !== null && typeof value === 'object' && !Array.isArray(value); } diff --git a/packages/core/src/shared/taskManager.ts b/packages/core/src/shared/taskManager.ts index 09d843db7..257dbec82 100644 --- a/packages/core/src/shared/taskManager.ts +++ b/packages/core/src/shared/taskManager.ts @@ -32,6 +32,7 @@ import { TaskStatusNotificationSchema } from '../types/index.js'; import type { AnyObjectSchema, AnySchema, SchemaOutput } from '../util/schema.js'; +import type { StandardSchemaV1 } from '../util/standardSchema.js'; import type { BaseContext, NotificationOptions, RequestOptions } from './protocol.js'; import type { ResponseMessage } from './responseMessage.js'; @@ -39,7 +40,11 @@ import type { ResponseMessage } from './responseMessage.js'; * Host interface for TaskManager to call back into Protocol. @internal */ export interface TaskManagerHost { - request(request: Request, resultSchema: T, options?: RequestOptions): Promise>; + request( + request: Request, + resultSchema: T, + options?: RequestOptions + ): Promise>; notification(notification: Notification, options?: NotificationOptions): Promise; reportError(error: Error): void; removeProgressHandler(token: number): void; @@ -57,7 +62,11 @@ export interface TaskManagerHost { export interface InboundContext { sessionId?: string; sendNotification: (notification: Notification, options?: NotificationOptions) => Promise; - sendRequest: (request: Request, resultSchema: U, options?: RequestOptions) => Promise>; + sendRequest: ( + request: Request, + resultSchema: U, + options?: RequestOptions + ) => Promise>; } /** @@ -67,11 +76,11 @@ export interface InboundContext { export interface InboundResult { taskContext?: BaseContext['task']; sendNotification: (notification: Notification) => Promise; - sendRequest: ( + sendRequest: ( request: Request, resultSchema: U, options?: Omit - ) => Promise>; + ) => Promise>; routeResponse: (message: JSONRPCResponse | JSONRPCErrorResponse) => Promise; hasTaskCreationParams: boolean; /** @@ -274,7 +283,10 @@ export class TaskManager { if (!task) { try { - const result = await host.request(request, resultSchema, options); + // TODO: SchemaOutput (Zod) and StandardSchemaV1.InferOutput (host.request's return) + // resolve to the same type for Zod schemas, but TS can't unify them generically. + // Removing this cast requires aligning ResponseMessage with StandardSchema. + const result = (await host.request(request, resultSchema, options)) as SchemaOutput; yield { type: 'result', result }; } catch (error) { yield { @@ -346,7 +358,8 @@ export class TaskManager { resultSchema: T, options?: RequestOptions ): Promise> { - return this._requireHost.request({ method: 'tasks/result', params }, resultSchema, options); + // TODO: same SchemaOutput vs StandardSchemaV1.InferOutput mismatch as requestStream above. + return this._requireHost.request({ method: 'tasks/result', params }, resultSchema, options) as Promise>; } async listTasks(params?: { cursor?: string }, options?: RequestOptions): Promise> { @@ -563,9 +576,17 @@ export class TaskManager { private wrapSendRequest( relatedTaskId: string, taskStore: RequestTaskStore | undefined, - originalSendRequest: (request: Request, resultSchema: V, options?: RequestOptions) => Promise> - ): (request: Request, resultSchema: V, options?: TaskRequestOptions) => Promise> { - return async (request: Request, resultSchema: V, options?: TaskRequestOptions) => { + originalSendRequest: ( + request: Request, + resultSchema: V, + options?: RequestOptions + ) => Promise> + ): ( + request: Request, + resultSchema: V, + options?: TaskRequestOptions + ) => Promise> { + return async (request: Request, resultSchema: V, options?: TaskRequestOptions) => { const requestOptions: RequestOptions = { ...options }; if (relatedTaskId && !requestOptions.relatedTask) { requestOptions.relatedTask = { taskId: relatedTaskId }; diff --git a/packages/core/src/types/schemas.ts b/packages/core/src/types/schemas.ts index 86acf11d7..246d36976 100644 --- a/packages/core/src/types/schemas.ts +++ b/packages/core/src/types/schemas.ts @@ -2181,10 +2181,13 @@ const resultSchemas: Record = { /** * Gets the Zod schema for validating results of a given request method. + * Returns `undefined` for non-spec methods. * @see getRequestSchema for explanation of the internal type assertion. */ -export function getResultSchema(method: M): z.ZodType { - return resultSchemas[method] as unknown as z.ZodType; +export function getResultSchema(method: M): z.ZodType; +export function getResultSchema(method: string): z.ZodType | undefined; +export function getResultSchema(method: string): z.ZodType | undefined { + return resultSchemas[method as RequestMethod] as unknown as z.ZodType | undefined; } /* Runtime schema lookup — request schemas by method */ @@ -2211,6 +2214,7 @@ const notificationSchemas = buildSchemaMap([...ClientNotificationSchema.options, /** * Gets the Zod schema for a given request method. + * Returns `undefined` for non-spec methods. * The return type is a ZodType that parses to RequestTypeMap[M], allowing callers * to use schema.parse() without needing additional type assertions. * @@ -2219,14 +2223,19 @@ const notificationSchemas = buildSchemaMap([...ClientNotificationSchema.options, * when M is a generic type parameter. Both compute to the same type at * instantiation, but TypeScript can't prove this statically. */ -export function getRequestSchema(method: M): z.ZodType { - return requestSchemas[method] as unknown as z.ZodType; +export function getRequestSchema(method: M): z.ZodType; +export function getRequestSchema(method: string): z.ZodType | undefined; +export function getRequestSchema(method: string): z.ZodType | undefined { + return requestSchemas[method as RequestMethod] as unknown as z.ZodType | undefined; } /** * Gets the Zod schema for a given notification method. + * Returns `undefined` for non-spec methods. * @see getRequestSchema for explanation of the internal type assertion. */ -export function getNotificationSchema(method: M): z.ZodType { - return notificationSchemas[method] as unknown as z.ZodType; +export function getNotificationSchema(method: M): z.ZodType; +export function getNotificationSchema(method: string): z.ZodType | undefined; +export function getNotificationSchema(method: string): z.ZodType | undefined { + return notificationSchemas[method as NotificationMethod] as unknown as z.ZodType | undefined; } diff --git a/packages/core/src/util/standardSchema.ts b/packages/core/src/util/standardSchema.ts index afbd05c68..ee1a63067 100644 --- a/packages/core/src/util/standardSchema.ts +++ b/packages/core/src/util/standardSchema.ts @@ -201,15 +201,15 @@ function formatIssue(issue: StandardSchemaV1.Issue): string { return `${path}: ${issue.message}`; } -export async function validateStandardSchema( +export async function validateStandardSchema( schema: T, data: unknown -): Promise>> { +): Promise>> { const result = await schema['~standard'].validate(data); if (result.issues && result.issues.length > 0) { return { success: false, error: result.issues.map(i => formatIssue(i)).join(', ') }; } - return { success: true, data: (result as StandardSchemaV1.SuccessResult).value as StandardSchemaWithJSON.InferOutput }; + return { success: true, data: (result as StandardSchemaV1.SuccessResult).value as StandardSchemaV1.InferOutput }; } // Prompt argument extraction diff --git a/packages/core/test/shared/customMethods.test.ts b/packages/core/test/shared/customMethods.test.ts new file mode 100644 index 000000000..47e02c9bc --- /dev/null +++ b/packages/core/test/shared/customMethods.test.ts @@ -0,0 +1,200 @@ +import { describe, expect, it } from 'vitest'; +import { z } from 'zod/v4'; + +import { Protocol } from '../../src/shared/protocol.js'; +import type { BaseContext, JSONRPCRequest, Result, StandardSchemaV1 } from '../../src/exports/public/index.js'; +import { ProtocolError } from '../../src/types/index.js'; +import { SdkErrorCode } from '../../src/errors/sdkErrors.js'; +import { InMemoryTransport } from '../../src/util/inMemory.js'; + +class TestProtocol extends Protocol { + protected buildContext(ctx: BaseContext): BaseContext { + return ctx; + } + protected assertCapabilityForMethod(): void {} + protected assertNotificationCapability(): void {} + protected assertRequestHandlerCapability(): void {} + protected assertTaskCapability(): void {} + protected assertTaskHandlerCapability(): void {} +} + +async function pair(): Promise<[TestProtocol, TestProtocol]> { + const [t1, t2] = InMemoryTransport.createLinkedPair(); + const a = new TestProtocol(); + const b = new TestProtocol(); + await a.connect(t1); + await b.connect(t2); + return [a, b]; +} + +describe('Protocol custom-method support', () => { + describe('setRequestHandler 3-arg form', () => { + const SearchParams = z.object({ query: z.string(), limit: z.number().int() }); + const SearchResult = z.object({ items: z.array(z.string()) }); + + it('registers, validates params, and handler receives parsed params', async () => { + const [a, b] = await pair(); + b.setRequestHandler('acme/search', { params: SearchParams, result: SearchResult }, async (params, _ctx) => { + expect(params.query).toBe('hello'); + expect(params.limit).toBe(5); + return { items: [`result for ${params.query}`] }; + }); + + const result = await a.request({ method: 'acme/search', params: { query: 'hello', limit: 5 } }, SearchResult); + expect(result.items).toEqual(['result for hello']); + }); + + it('strips _meta from params before validation', async () => { + const [a, b] = await pair(); + const Strict = z.strictObject({ x: z.number() }); + b.setRequestHandler('acme/strict', { params: Strict }, async params => { + expect(params).toEqual({ x: 1 }); + return {}; + }); + + const result = await a.request({ method: 'acme/strict', params: { x: 1, _meta: { progressToken: 't' } } }, z.object({})); + expect(result).toEqual({}); + }); + + it('rejects invalid params with ProtocolError(InvalidParams)', async () => { + const [a, b] = await pair(); + b.setRequestHandler('acme/search', { params: SearchParams }, async () => ({})); + + await expect(a.request({ method: 'acme/search', params: { query: 'q', limit: 'oops' } }, z.object({}))).rejects.toThrow( + ProtocolError + ); + }); + + it('types handler return from schemas.result', () => { + const p = new TestProtocol(); + p.setRequestHandler('acme/typed', { params: z.object({}), result: SearchResult }, async () => { + return { items: [] }; + }); + // @ts-expect-error wrong return shape when result schema supplied + p.setRequestHandler('acme/typed', { params: z.object({}), result: SearchResult }, async () => ({})); + // No result schema → handler may return any Result + p.setRequestHandler('acme/loose', { params: z.object({}) }, async () => ({}) as Result); + }); + + it('throws TypeError when 2-arg form is used with a non-spec method', () => { + const p = new TestProtocol(); + expect(() => p.setRequestHandler('acme/unknown' as never, () => ({}) as never)).toThrow(TypeError); + }); + + it('routes both 2-arg and 3-arg registration through _wrapHandler', () => { + const seen: string[] = []; + class SpyProtocol extends TestProtocol { + protected override _wrapHandler( + method: string, + handler: (request: JSONRPCRequest, ctx: BaseContext) => Promise + ): (request: JSONRPCRequest, ctx: BaseContext) => Promise { + seen.push(method); + return handler; + } + } + const p = new SpyProtocol(); + p.setRequestHandler('tools/list', () => ({ tools: [] })); + p.setRequestHandler('acme/custom', { params: z.object({}) }, () => ({})); + expect(seen).toContain('tools/list'); + expect(seen).toContain('acme/custom'); + }); + }); + + describe('setNotificationHandler 3-arg form', () => { + it('registers, validates params, handler receives parsed params', async () => { + const [a, b] = await pair(); + const Progress = z.object({ stage: z.string(), pct: z.number() }); + const seen: Array> = []; + b.setNotificationHandler('acme/searchProgress', { params: Progress }, params => { + seen.push(params); + }); + + await a.notification({ method: 'acme/searchProgress', params: { stage: 'fetch', pct: 0.5 } }); + await new Promise(r => setTimeout(r, 0)); + expect(seen).toEqual([{ stage: 'fetch', pct: 0.5 }]); + }); + + it('passes the raw notification (with _meta) as the second handler argument', async () => { + const [a, b] = await pair(); + const Strict = z.strictObject({ stage: z.string() }); + let seenMeta: unknown; + b.setNotificationHandler('acme/searchProgress', { params: Strict }, (params, notification) => { + expect(params).toEqual({ stage: 'fetch' }); + seenMeta = notification.params?._meta; + }); + + await a.notification({ method: 'acme/searchProgress', params: { stage: 'fetch', _meta: { traceId: 't1' } } }); + await new Promise(r => setTimeout(r, 0)); + expect(seenMeta).toEqual({ traceId: 't1' }); + }); + }); + + describe('request() schema overload', () => { + it('validates result against provided schema and types the return', async () => { + const [a, b] = await pair(); + b.setRequestHandler('acme/echo', { params: z.object({ v: z.string() }) }, async params => ({ echoed: params.v })); + + const result = await a.request({ method: 'acme/echo', params: { v: 'x' } }, z.object({ echoed: z.string() })); + expect(result.echoed).toBe('x'); + }); + + it('throws TypeError when 1-arg form is used with a non-spec method', async () => { + const [a] = await pair(); + expect(() => a.request({ method: 'acme/unknown' } as never)).toThrow(TypeError); + }); + + it('rejects with SdkError(InvalidResult) when the response fails the result schema', async () => { + const [a, b] = await pair(); + b.setRequestHandler('acme/bad', { params: z.object({}) }, async () => ({ wrong: 123 })); + + await expect(a.request({ method: 'acme/bad', params: {} }, z.object({ echoed: z.string() }))).rejects.toMatchObject({ + code: SdkErrorCode.InvalidResult + }); + }); + + it('returns the result (and sends no cancellation) if the signal aborts during async result-schema validation', async () => { + const [a, b] = await pair(); + b.setRequestHandler('acme/echo', { params: z.object({}) }, async () => ({ echoed: 'ok' })); + + const cancelled: unknown[] = []; + b.setNotificationHandler('notifications/cancelled', n => { + cancelled.push(n); + }); + + const ac = new AbortController(); + const AsyncEcho: StandardSchemaV1 = { + '~standard': { + version: 1, + vendor: 'test', + validate: value => + new Promise(r => { + ac.abort(); + setTimeout(() => r({ value: value as { echoed: string } }), 0); + }) + } + }; + + const result = await a.request({ method: 'acme/echo', params: {} }, AsyncEcho, { signal: ac.signal }); + expect(result).toEqual({ echoed: 'ok' }); + await new Promise(r => setTimeout(r, 0)); + expect(cancelled).toHaveLength(0); + }); + }); + + describe('ctx.mcpReq.send schema overload', () => { + it('sends a related custom-method request from within a handler', async () => { + const [a, b] = await pair(); + const Pong = z.object({ pong: z.literal(true) }); + + a.setRequestHandler('acme/pong', { params: z.object({}) }, async () => ({ pong: true as const })); + b.setRequestHandler('acme/ping', { params: z.object({}) }, async (_params, ctx) => { + const r = await ctx.mcpReq.send({ method: 'acme/pong', params: {} }, Pong); + expect(r.pong).toBe(true); + return { ok: true }; + }); + + const result = await a.request({ method: 'acme/ping', params: {} }, z.object({ ok: z.boolean() })); + expect(result.ok).toBe(true); + }); + }); +}); diff --git a/packages/server/src/server/server.ts b/packages/server/src/server/server.ts index 8324c8dc1..f6a34f02d 100644 --- a/packages/server/src/server/server.ts +++ b/packages/server/src/server/server.ts @@ -266,7 +266,7 @@ export class Server extends Protocol { }; } - protected assertCapabilityForMethod(method: RequestMethod): void { + protected assertCapabilityForMethod(method: RequestMethod | string): void { switch (method) { case 'sampling/createMessage': { if (!this._clientCapabilities?.sampling) { @@ -299,7 +299,7 @@ export class Server extends Protocol { } } - protected assertNotificationCapability(method: NotificationMethod): void { + protected assertNotificationCapability(method: NotificationMethod | string): void { switch (method) { case 'notifications/message': { if (!this._capabilities.logging) {