Skip to content

Commit e4cb1a9

Browse files
feat(core): accept StandardSchemaV1 in setCustom*/sendCustom* (not just Zod)
AnySchema is now StandardSchemaV1 (which Zod schemas implement), and parseSchema routes through validateStandardSchema. This lets specTypeSchema() output and other non-Zod Standard Schemas be passed directly to setCustomRequestHandler / setCustomNotificationHandler / sendCustomRequest / sendCustomNotification. parseSchema is now async; all 16 callers were already in async contexts except _setupListChangedHandler, which is now async too (called from connect()). validateStandardSchema constraint relaxed from StandardSchemaWithJSON to StandardSchemaV1. isSchemaBundle was already StandardSchema-aware.
1 parent cdaca38 commit e4cb1a9

6 files changed

Lines changed: 103 additions & 43 deletions

File tree

packages/client/src/client/client.ts

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -255,23 +255,23 @@ export class Client extends Protocol<ClientContext> {
255255
* Handlers are silently skipped if the server doesn't advertise the corresponding listChanged capability.
256256
* @internal
257257
*/
258-
private _setupListChangedHandlers(config: ListChangedHandlers): void {
258+
private async _setupListChangedHandlers(config: ListChangedHandlers): Promise<void> {
259259
if (config.tools && this._serverCapabilities?.tools?.listChanged) {
260-
this._setupListChangedHandler('tools', 'notifications/tools/list_changed', config.tools, async () => {
260+
await this._setupListChangedHandler('tools', 'notifications/tools/list_changed', config.tools, async () => {
261261
const result = await this.listTools();
262262
return result.tools;
263263
});
264264
}
265265

266266
if (config.prompts && this._serverCapabilities?.prompts?.listChanged) {
267-
this._setupListChangedHandler('prompts', 'notifications/prompts/list_changed', config.prompts, async () => {
267+
await this._setupListChangedHandler('prompts', 'notifications/prompts/list_changed', config.prompts, async () => {
268268
const result = await this.listPrompts();
269269
return result.prompts;
270270
});
271271
}
272272

273273
if (config.resources && this._serverCapabilities?.resources?.listChanged) {
274-
this._setupListChangedHandler('resources', 'notifications/resources/list_changed', config.resources, async () => {
274+
await this._setupListChangedHandler('resources', 'notifications/resources/list_changed', config.resources, async () => {
275275
const result = await this.listResources();
276276
return result.resources;
277277
});
@@ -339,7 +339,7 @@ export class Client extends Protocol<ClientContext> {
339339
): void {
340340
if (method === 'elicitation/create') {
341341
const wrappedHandler = async (request: RequestTypeMap[M], ctx: ClientContext): Promise<ClientResult> => {
342-
const validatedRequest = parseSchema(ElicitRequestSchema, request);
342+
const validatedRequest = await parseSchema(ElicitRequestSchema, request);
343343
if (!validatedRequest.success) {
344344
// Type guard: if success is false, error is guaranteed to exist
345345
const errorMessage =
@@ -363,7 +363,7 @@ export class Client extends Protocol<ClientContext> {
363363

364364
// When task creation is requested, validate and return CreateTaskResult
365365
if (params.task) {
366-
const taskValidationResult = parseSchema(CreateTaskResultSchema, result);
366+
const taskValidationResult = await parseSchema(CreateTaskResultSchema, result);
367367
if (!taskValidationResult.success) {
368368
const errorMessage =
369369
taskValidationResult.error instanceof Error
@@ -375,7 +375,7 @@ export class Client extends Protocol<ClientContext> {
375375
}
376376

377377
// For non-task requests, validate against ElicitResultSchema
378-
const validationResult = parseSchema(ElicitResultSchema, result);
378+
const validationResult = await parseSchema(ElicitResultSchema, result);
379379
if (!validationResult.success) {
380380
// Type guard: if success is false, error is guaranteed to exist
381381
const errorMessage =
@@ -409,7 +409,7 @@ export class Client extends Protocol<ClientContext> {
409409

410410
if (method === 'sampling/createMessage') {
411411
const wrappedHandler = async (request: RequestTypeMap[M], ctx: ClientContext): Promise<ClientResult> => {
412-
const validatedRequest = parseSchema(CreateMessageRequestSchema, request);
412+
const validatedRequest = await parseSchema(CreateMessageRequestSchema, request);
413413
if (!validatedRequest.success) {
414414
const errorMessage =
415415
validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error);
@@ -422,7 +422,7 @@ export class Client extends Protocol<ClientContext> {
422422

423423
// When task creation is requested, validate and return CreateTaskResult
424424
if (params.task) {
425-
const taskValidationResult = parseSchema(CreateTaskResultSchema, result);
425+
const taskValidationResult = await parseSchema(CreateTaskResultSchema, result);
426426
if (!taskValidationResult.success) {
427427
const errorMessage =
428428
taskValidationResult.error instanceof Error
@@ -436,7 +436,7 @@ export class Client extends Protocol<ClientContext> {
436436
// For non-task requests, validate against appropriate schema based on tools presence
437437
const hasTools = params.tools || params.toolChoice;
438438
const resultSchema = hasTools ? CreateMessageResultWithToolsSchema : CreateMessageResultSchema;
439-
const validationResult = parseSchema(resultSchema, result);
439+
const validationResult = await parseSchema(resultSchema, result);
440440
if (!validationResult.success) {
441441
const errorMessage =
442442
validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error);
@@ -538,7 +538,7 @@ export class Client extends Protocol<ClientContext> {
538538

539539
// Set up list changed handlers now that we know server capabilities
540540
if (this._pendingListChangedConfig) {
541-
this._setupListChangedHandlers(this._pendingListChangedConfig);
541+
await this._setupListChangedHandlers(this._pendingListChangedConfig);
542542
this._pendingListChangedConfig = undefined;
543543
}
544544
} catch (error) {
@@ -1005,14 +1005,14 @@ export class Client extends Protocol<ClientContext> {
10051005
* Set up a single list changed handler.
10061006
* @internal
10071007
*/
1008-
private _setupListChangedHandler<T>(
1008+
private async _setupListChangedHandler<T>(
10091009
listType: string,
10101010
notificationMethod: NotificationMethod,
10111011
options: ListChangedOptions<T>,
10121012
fetcher: () => Promise<T[]>
1013-
): void {
1014-
// Validate options using Zod schema (validates autoRefresh and debounceMs)
1015-
const parseResult = parseSchema(ListChangedOptionsBaseSchema, options);
1013+
): Promise<void> {
1014+
// Validate options (autoRefresh and debounceMs)
1015+
const parseResult = await parseSchema(ListChangedOptionsBaseSchema, options);
10161016
if (!parseResult.success) {
10171017
throw new Error(`Invalid ${listType} listChanged options: ${parseResult.error.message}`);
10181018
}

packages/core/src/shared/protocol.ts

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -867,7 +867,7 @@ export abstract class Protocol<ContextT extends BaseContext> {
867867
reject(error);
868868
};
869869

870-
this._responseHandlers.set(messageId, response => {
870+
this._responseHandlers.set(messageId, async response => {
871871
if (options?.signal?.aborted) {
872872
return;
873873
}
@@ -877,7 +877,7 @@ export abstract class Protocol<ContextT extends BaseContext> {
877877
}
878878

879879
try {
880-
const parseResult = parseSchema(resultSchema, response.result);
880+
const parseResult = await parseSchema(resultSchema, response.result);
881881
if (parseResult.success) {
882882
resolve(parseResult.data as SchemaOutput<T>);
883883
} else {
@@ -1079,14 +1079,14 @@ export abstract class Protocol<ContextT extends BaseContext> {
10791079
if (isRequestMethod(method)) {
10801080
throw new Error(`"${method}" is a standard MCP request method. Use setRequestHandler() instead.`);
10811081
}
1082-
this._requestHandlers.set(method, (request, ctx) => {
1082+
this._requestHandlers.set(method, async (request, ctx) => {
10831083
const { _meta, ...userParams } = (request.params ?? {}) as Record<string, unknown>;
10841084
void _meta;
1085-
const parsed = parseSchema(paramsSchema, userParams);
1085+
const parsed = await parseSchema(paramsSchema, userParams);
10861086
if (!parsed.success) {
10871087
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid params for ${method}: ${parsed.error.message}`);
10881088
}
1089-
return Promise.resolve(handler(parsed.data, ctx));
1089+
return handler(parsed.data, ctx);
10901090
});
10911091
}
10921092

@@ -1119,14 +1119,14 @@ export abstract class Protocol<ContextT extends BaseContext> {
11191119
if (isNotificationMethod(method)) {
11201120
throw new Error(`"${method}" is a standard MCP notification method. Use setNotificationHandler() instead.`);
11211121
}
1122-
this._notificationHandlers.set(method, notification => {
1122+
this._notificationHandlers.set(method, async notification => {
11231123
const { _meta, ...userParams } = (notification.params ?? {}) as Record<string, unknown>;
11241124
void _meta;
1125-
const parsed = parseSchema(paramsSchema, userParams);
1125+
const parsed = await parseSchema(paramsSchema, userParams);
11261126
if (!parsed.success) {
11271127
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid params for ${method}: ${parsed.error.message}`);
11281128
}
1129-
return Promise.resolve(handler(parsed.data));
1129+
return handler(parsed.data);
11301130
});
11311131
}
11321132

@@ -1178,7 +1178,7 @@ export abstract class Protocol<ContextT extends BaseContext> {
11781178
): Promise<unknown> {
11791179
let resultSchema: AnySchema;
11801180
if (isSchemaBundle(schemaOrBundle)) {
1181-
const parsed = parseSchema(schemaOrBundle.params, params);
1181+
const parsed = await parseSchema(schemaOrBundle.params, params);
11821182
if (!parsed.success) {
11831183
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid params for ${method}: ${parsed.error.message}`);
11841184
}
@@ -1217,7 +1217,7 @@ export abstract class Protocol<ContextT extends BaseContext> {
12171217
): Promise<void> {
12181218
let options: NotificationOptions | undefined;
12191219
if (schemasOrOptions && 'params' in schemasOrOptions) {
1220-
const parsed = parseSchema(schemasOrOptions.params, params);
1220+
const parsed = await parseSchema(schemasOrOptions.params, params);
12211221
if (!parsed.success) {
12221222
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid params for ${method}: ${parsed.error.message}`);
12231223
}

packages/core/src/util/schema.ts

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,47 @@
11
/**
2-
* Internal Zod schema utilities for protocol handling.
2+
* Standard Schema utilities for protocol handling.
33
* These are used internally by the SDK for protocol message validation.
44
*/
55

6-
import * as z from 'zod/v4';
6+
import type * as z from 'zod/v4';
7+
8+
import type { StandardSchemaV1 } from './standardSchema.js';
9+
import { validateStandardSchema } from './standardSchema.js';
710

811
/**
9-
* Base type for any Zod schema.
12+
* Base type for any schema accepted by the SDK's user-facing schema parameters.
13+
*
14+
* This is the Standard Schema interface (https://standardschema.dev), which Zod, Valibot, ArkType
15+
* and others implement. Zod schemas satisfy this constraint natively.
1016
*/
11-
export type AnySchema = z.core.$ZodType;
17+
export type AnySchema = StandardSchemaV1;
1218

1319
/**
1420
* A Zod schema for objects specifically.
21+
*
22+
* Retained for internal use where the SDK needs Zod-specific introspection (e.g. converting a tool
23+
* input schema to JSON Schema). Not used for user-facing schema parameters.
1524
*/
1625
export type AnyObjectSchema = z.core.$ZodObject;
1726

1827
/**
19-
* Extracts the output type from a Zod schema.
28+
* Extracts the output type from a Standard Schema.
2029
*/
21-
export type SchemaOutput<T extends AnySchema> = z.output<T>;
30+
export type SchemaOutput<T extends AnySchema> = StandardSchemaV1.InferOutput<T>;
2231

2332
/**
24-
* Parses data against a Zod schema (synchronous).
25-
* Returns a discriminated union with success/error.
33+
* Parses data against a Standard Schema.
34+
*
35+
* Returns a discriminated union with success/error. The error is a plain `Error` whose `message`
36+
* is a comma-separated list of issues, so callers can interpolate it directly.
2637
*/
27-
export function parseSchema<T extends AnySchema>(
38+
export async function parseSchema<T extends AnySchema>(
2839
schema: T,
2940
data: unknown
30-
): { success: true; data: z.output<T> } | { success: false; error: z.core.$ZodError } {
31-
return z.safeParse(schema, data);
41+
): Promise<{ success: true; data: SchemaOutput<T> } | { success: false; error: Error }> {
42+
const result = await validateStandardSchema(schema, data);
43+
if (result.success) {
44+
return result;
45+
}
46+
return { success: false, error: new Error(result.error) };
3247
}

packages/core/src/util/standardSchema.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,15 +169,15 @@ function formatIssue(issue: StandardSchemaV1.Issue): string {
169169
return `${path}: ${issue.message}`;
170170
}
171171

172-
export async function validateStandardSchema<T extends StandardSchemaWithJSON>(
172+
export async function validateStandardSchema<T extends StandardSchemaV1>(
173173
schema: T,
174174
data: unknown
175-
): Promise<StandardSchemaValidationResult<StandardSchemaWithJSON.InferOutput<T>>> {
175+
): Promise<StandardSchemaValidationResult<StandardSchemaV1.InferOutput<T>>> {
176176
const result = await schema['~standard'].validate(data);
177177
if (result.issues && result.issues.length > 0) {
178178
return { success: false, error: result.issues.map(i => formatIssue(i)).join(', ') };
179179
}
180-
return { success: true, data: (result as StandardSchemaV1.SuccessResult<unknown>).value as StandardSchemaWithJSON.InferOutput<T> };
180+
return { success: true, data: (result as StandardSchemaV1.SuccessResult<unknown>).value as StandardSchemaV1.InferOutput<T> };
181181
}
182182

183183
// Prompt argument extraction

packages/core/test/shared/customMethods.test.ts

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import type { BaseContext } from '../../src/shared/protocol.js';
66
import { Protocol } from '../../src/shared/protocol.js';
77
import { ProtocolError, ProtocolErrorCode } from '../../src/types/index.js';
88
import { InMemoryTransport } from '../../src/util/inMemory.js';
9+
import type { StandardSchemaV1 } from '../../src/util/standardSchema.js';
910

1011
class TestProtocol extends Protocol<BaseContext> {
1112
protected assertCapabilityForMethod(): void {}
@@ -273,3 +274,47 @@ describe('sendCustomNotification', () => {
273274
expect(count).toBe(1);
274275
});
275276
});
277+
278+
describe('setCustom* — accepts non-Zod StandardSchemaV1', () => {
279+
function makeStandardSchema<T>(validate: (v: unknown) => T | { issues: ReadonlyArray<{ message: string }> }): StandardSchemaV1<T> {
280+
return {
281+
'~standard': {
282+
version: 1 as const,
283+
vendor: 'test',
284+
types: undefined as unknown as { input: T; output: T },
285+
validate: (v: unknown) => {
286+
const r = validate(v);
287+
return typeof r === 'object' && r !== null && 'issues' in r ? r : { value: r as T };
288+
}
289+
}
290+
};
291+
}
292+
293+
test('setCustomRequestHandler validates via ~standard.validate (no Zod)', async () => {
294+
const [client, server] = await linkedPair();
295+
296+
type Params = { n: number };
297+
const ParamsSchema = makeStandardSchema<Params>(v =>
298+
typeof v === 'object' && v !== null && typeof (v as Params).n === 'number'
299+
? (v as Params)
300+
: { issues: [{ message: 'n must be a number' }] }
301+
);
302+
const ResultSchema = makeStandardSchema<{ doubled: number }>(v =>
303+
typeof v === 'object' && v !== null && typeof (v as { doubled: number }).doubled === 'number'
304+
? (v as { doubled: number })
305+
: { issues: [{ message: 'doubled must be a number' }] }
306+
);
307+
308+
server.setCustomRequestHandler('test/double', ParamsSchema, async (params: Params) => ({ doubled: params.n * 2 }));
309+
310+
const result = await client.sendCustomRequest('test/double', { n: 21 }, ResultSchema);
311+
expect(result.doubled).toBe(42);
312+
313+
await expect(client.sendCustomRequest('test/double', { n: 'nope' }, ResultSchema)).rejects.toSatisfy(
314+
(e: unknown) => e instanceof ProtocolError && /n must be a number/.test(e.message)
315+
);
316+
317+
await client.close();
318+
await server.close();
319+
});
320+
});

packages/server/src/server/server.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ export class Server extends Protocol<ServerContext> {
145145
const transportSessionId: string | undefined =
146146
ctx.sessionId || (ctx.http?.req?.headers.get('mcp-session-id') as string) || undefined;
147147
const { level } = request.params;
148-
const parseResult = parseSchema(LoggingLevelSchema, level);
148+
const parseResult = await parseSchema(LoggingLevelSchema, level);
149149
if (parseResult.success) {
150150
this._loggingLevels.set(transportSessionId, parseResult.data);
151151
}
@@ -228,7 +228,7 @@ export class Server extends Protocol<ServerContext> {
228228
): void {
229229
if (method === 'tools/call') {
230230
const wrappedHandler = async (request: RequestTypeMap[M], ctx: ServerContext): Promise<ServerResult> => {
231-
const validatedRequest = parseSchema(CallToolRequestSchema, request);
231+
const validatedRequest = await parseSchema(CallToolRequestSchema, request);
232232
if (!validatedRequest.success) {
233233
const errorMessage =
234234
validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error);
@@ -241,7 +241,7 @@ export class Server extends Protocol<ServerContext> {
241241

242242
// When task creation is requested, validate and return CreateTaskResult
243243
if (params.task) {
244-
const taskValidationResult = parseSchema(CreateTaskResultSchema, result);
244+
const taskValidationResult = await parseSchema(CreateTaskResultSchema, result);
245245
if (!taskValidationResult.success) {
246246
const errorMessage =
247247
taskValidationResult.error instanceof Error
@@ -253,7 +253,7 @@ export class Server extends Protocol<ServerContext> {
253253
}
254254

255255
// For non-task requests, validate against CallToolResultSchema
256-
const validationResult = parseSchema(CallToolResultSchema, result);
256+
const validationResult = await parseSchema(CallToolResultSchema, result);
257257
if (!validationResult.success) {
258258
const errorMessage =
259259
validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error);

0 commit comments

Comments
 (0)