Skip to content

Commit 5c25abd

Browse files
refactor: remove fixable casts (widen get*Schema/assert*, isStandardSchema guard, union impl sigs)
1 parent 0b24200 commit 5c25abd

5 files changed

Lines changed: 70 additions & 59 deletions

File tree

packages/client/src/client/client.ts

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ import type {
2424
NotificationMethod,
2525
ProtocolOptions,
2626
ReadResourceRequest,
27-
RequestMethod,
2827
RequestOptions,
2928
Result,
3029
ServerCapabilities,
@@ -565,7 +564,7 @@ export class Client extends Protocol<ClientContext> {
565564
return this._instructions;
566565
}
567566

568-
protected assertCapabilityForMethod(method: RequestMethod): void {
567+
protected assertCapabilityForMethod(method: string): void {
569568
switch (method as ClientRequest['method']) {
570569
case 'logging/setLevel': {
571570
if (!this._serverCapabilities?.logging) {
@@ -628,7 +627,7 @@ export class Client extends Protocol<ClientContext> {
628627
}
629628
}
630629

631-
protected assertNotificationCapability(method: NotificationMethod): void {
630+
protected assertNotificationCapability(method: string): void {
632631
switch (method as ClientNotification['method']) {
633632
case 'notifications/roots/list_changed': {
634633
if (!this._capabilities.roots?.listChanged) {

packages/core/src/shared/protocol.ts

Lines changed: 48 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ import {
4545
SUPPORTED_PROTOCOL_VERSIONS
4646
} from '../types/index.js';
4747
import type { StandardSchemaV1 } from '../util/standardSchema.js';
48-
import { validateStandardSchema } from '../util/standardSchema.js';
48+
import { isStandardSchema, validateStandardSchema } from '../util/standardSchema.js';
4949
import type { TaskContext, TaskManagerHost, TaskManagerOptions, TaskRequestOptions } from './taskManager.js';
5050
import { NullTaskManager, TaskManager } from './taskManager.js';
5151
import type { Transport, TransportSendOptions } from './transport.js';
@@ -609,17 +609,18 @@ export abstract class Protocol<ContextT extends BaseContext> {
609609
method: request.method,
610610
_meta: request.params?._meta,
611611
signal: abortController.signal,
612+
// Arrow literals can't carry overload signatures; cast asserts this impl matches the overloaded property type.
612613
send: ((r: Request, schemaOrOptions?: StandardSchemaV1 | TaskRequestOptions, maybeOptions?: TaskRequestOptions) => {
613-
if (schemaOrOptions != null && typeof schemaOrOptions === 'object' && '~standard' in schemaOrOptions) {
614+
if (isStandardSchema(schemaOrOptions)) {
614615
return sendRequest(r, schemaOrOptions, maybeOptions);
615616
}
616-
const resultSchema = getResultSchema(r.method as RequestMethod);
617+
const resultSchema = getResultSchema(r.method);
617618
if (!resultSchema) {
618619
throw new TypeError(
619620
`'${r.method}' is not a spec method; pass a result schema as the second argument to ctx.mcpReq.send().`
620621
);
621622
}
622-
return sendRequest(r, resultSchema, schemaOrOptions as TaskRequestOptions | undefined);
623+
return sendRequest(r, resultSchema, schemaOrOptions);
623624
}) as BaseContext['mcpReq']['send'],
624625
notify: sendNotification
625626
},
@@ -761,14 +762,14 @@ export abstract class Protocol<ContextT extends BaseContext> {
761762
*
762763
* This should be implemented by subclasses.
763764
*/
764-
protected abstract assertCapabilityForMethod(method: RequestMethod): void;
765+
protected abstract assertCapabilityForMethod(method: string): void;
765766

766767
/**
767768
* A method to check if a notification is supported by the local side, for the given method to be sent.
768769
*
769770
* This should be implemented by subclasses.
770771
*/
771-
protected abstract assertNotificationCapability(method: NotificationMethod): void;
772+
protected abstract assertNotificationCapability(method: string): void;
772773

773774
/**
774775
* A method to check if a request handler is supported by the local side, for the given method to be handled.
@@ -813,10 +814,10 @@ export abstract class Protocol<ContextT extends BaseContext> {
813814
options?: RequestOptions
814815
): Promise<StandardSchemaV1.InferOutput<T>>;
815816
request(request: Request, schemaOrOptions?: StandardSchemaV1 | RequestOptions, maybeOptions?: RequestOptions): Promise<unknown> {
816-
if (schemaOrOptions != null && typeof schemaOrOptions === 'object' && '~standard' in schemaOrOptions) {
817+
if (isStandardSchema(schemaOrOptions)) {
817818
return this._requestWithSchema(request, schemaOrOptions, maybeOptions);
818819
}
819-
const resultSchema = getResultSchema(request.method as RequestMethod);
820+
const resultSchema = getResultSchema(request.method);
820821
if (!resultSchema) {
821822
throw new TypeError(`'${request.method}' is not a spec method; pass a result schema as the second argument to request().`);
822823
}
@@ -852,7 +853,7 @@ export abstract class Protocol<ContextT extends BaseContext> {
852853

853854
if (this._options?.enforceStrictCapabilities === true) {
854855
try {
855-
this.assertCapabilityForMethod(request.method as RequestMethod);
856+
this.assertCapabilityForMethod(request.method);
856857
} catch (error) {
857858
earlyReject(error);
858859
return;
@@ -986,7 +987,7 @@ export abstract class Protocol<ContextT extends BaseContext> {
986987
throw new SdkError(SdkErrorCode.NotConnected, 'Not connected');
987988
}
988989

989-
this.assertNotificationCapability(notification.method as NotificationMethod);
990+
this.assertNotificationCapability(notification.method);
990991

991992
// Delegate task-related notification routing and JSONRPC building to TaskManager
992993
const taskResult = await this._taskManager.processOutboundNotification(notification, options);
@@ -1069,35 +1070,35 @@ export abstract class Protocol<ContextT extends BaseContext> {
10691070
ctx: ContextT
10701071
) => RequestHandlerSchemas.InferResult<R> | Promise<RequestHandlerSchemas.InferResult<R>>
10711072
): void;
1072-
setRequestHandler(method: string, schemasOrHandler: unknown, maybeHandler?: unknown): void {
1073+
setRequestHandler(
1074+
method: string,
1075+
schemasOrHandler: RequestHandlerSchemas | ((request: unknown, ctx: ContextT) => Result | Promise<Result>),
1076+
maybeHandler?: (params: unknown, ctx: ContextT) => Result | Promise<Result>
1077+
): void {
10731078
this.assertRequestHandlerCapability(method);
10741079

10751080
let stored: (request: JSONRPCRequest, ctx: ContextT) => Promise<Result>;
10761081

1077-
if (maybeHandler === undefined) {
1078-
const handler = schemasOrHandler as (request: unknown, ctx: ContextT) => Result | Promise<Result>;
1079-
const schema = getRequestSchema(method as RequestMethod);
1082+
if (typeof schemasOrHandler === 'function') {
1083+
const schema = getRequestSchema(method);
10801084
if (!schema) {
10811085
throw new TypeError(
10821086
`'${method}' is not a spec request method; pass schemas as the second argument to setRequestHandler().`
10831087
);
10841088
}
1085-
stored = (request, ctx) => {
1086-
const parsed = schema.parse(request);
1087-
return Promise.resolve(handler(parsed, ctx));
1088-
};
1089-
} else {
1090-
const schemas = schemasOrHandler as RequestHandlerSchemas;
1091-
const handler = maybeHandler as (params: unknown, ctx: ContextT) => Result | Promise<Result>;
1089+
stored = (request, ctx) => Promise.resolve(schemasOrHandler(schema.parse(request), ctx));
1090+
} else if (maybeHandler) {
10921091
stored = async (request, ctx) => {
1093-
const userParams = { ...((request.params ?? {}) as Record<string, unknown>) };
1092+
const userParams = { ...request.params };
10941093
delete userParams._meta;
1095-
const parsed = await validateStandardSchema(schemas.params, userParams);
1094+
const parsed = await validateStandardSchema(schemasOrHandler.params, userParams);
10961095
if (!parsed.success) {
10971096
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid params for ${method}: ${parsed.error}`);
10981097
}
1099-
return handler(parsed.data, ctx);
1098+
return maybeHandler(parsed.data, ctx);
11001099
};
1100+
} else {
1101+
throw new TypeError('setRequestHandler: handler is required');
11011102
}
11021103

11031104
this._requestHandlers.set(method, this._wrapHandler(method, stored));
@@ -1153,32 +1154,33 @@ export abstract class Protocol<ContextT extends BaseContext> {
11531154
schemas: { params: P },
11541155
handler: (params: StandardSchemaV1.InferOutput<P>) => void | Promise<void>
11551156
): void;
1156-
setNotificationHandler(method: string, schemasOrHandler: unknown, maybeHandler?: unknown): void {
1157-
if (maybeHandler !== undefined) {
1158-
const schemas = schemasOrHandler as { params: StandardSchemaV1 };
1159-
const handler = maybeHandler as (params: unknown) => void | Promise<void>;
1160-
this._notificationHandlers.set(method, async notification => {
1161-
const userParams = { ...((notification.params ?? {}) as Record<string, unknown>) };
1162-
delete userParams._meta;
1163-
const parsed = await validateStandardSchema(schemas.params, userParams);
1164-
if (!parsed.success) {
1165-
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid params for notification ${method}: ${parsed.error}`);
1166-
}
1167-
await handler(parsed.data);
1168-
});
1157+
setNotificationHandler(
1158+
method: string,
1159+
schemasOrHandler: { params: StandardSchemaV1 } | ((notification: unknown) => void | Promise<void>),
1160+
maybeHandler?: (params: unknown) => void | Promise<void>
1161+
): void {
1162+
if (typeof schemasOrHandler === 'function') {
1163+
const schema = getNotificationSchema(method);
1164+
if (!schema) {
1165+
throw new TypeError(
1166+
`'${method}' is not a spec notification method; pass schemas as the second argument to setNotificationHandler().`
1167+
);
1168+
}
1169+
this._notificationHandlers.set(method, notification => Promise.resolve(schemasOrHandler(schema.parse(notification))));
11691170
return;
11701171
}
11711172

1172-
const handler = schemasOrHandler as (notification: unknown) => void | Promise<void>;
1173-
const schema = getNotificationSchema(method as NotificationMethod);
1174-
if (!schema) {
1175-
throw new TypeError(
1176-
`'${method}' is not a spec notification method; pass schemas as the second argument to setNotificationHandler().`
1177-
);
1173+
if (!maybeHandler) {
1174+
throw new TypeError('setNotificationHandler: handler is required');
11781175
}
1179-
this._notificationHandlers.set(method, notification => {
1180-
const parsed = schema.parse(notification);
1181-
return Promise.resolve(handler(parsed));
1176+
this._notificationHandlers.set(method, async notification => {
1177+
const userParams = { ...notification.params };
1178+
delete userParams._meta;
1179+
const parsed = await validateStandardSchema(schemasOrHandler.params, userParams);
1180+
if (!parsed.success) {
1181+
throw new ProtocolError(ProtocolErrorCode.InvalidParams, `Invalid params for notification ${method}: ${parsed.error}`);
1182+
}
1183+
await maybeHandler(parsed.data);
11821184
});
11831185
}
11841186

packages/core/src/shared/taskManager.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,9 @@ export class TaskManager {
283283

284284
if (!task) {
285285
try {
286+
// TODO: SchemaOutput<T> (Zod) and StandardSchemaV1.InferOutput<T> (host.request's return)
287+
// resolve to the same type for Zod schemas, but TS can't unify them generically.
288+
// Removing this cast requires aligning ResponseMessage<T extends Result> with StandardSchema.
286289
const result = (await host.request(request, resultSchema, options)) as SchemaOutput<T>;
287290
yield { type: 'result', result };
288291
} catch (error) {
@@ -355,6 +358,7 @@ export class TaskManager {
355358
resultSchema: T,
356359
options?: RequestOptions
357360
): Promise<SchemaOutput<T>> {
361+
// TODO: same SchemaOutput<T> vs StandardSchemaV1.InferOutput<T> mismatch as requestStream above.
358362
return this._requireHost.request({ method: 'tasks/result', params }, resultSchema, options) as Promise<SchemaOutput<T>>;
359363
}
360364

packages/core/src/types/schemas.ts

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2181,10 +2181,13 @@ const resultSchemas: Record<string, z.core.$ZodType> = {
21812181

21822182
/**
21832183
* Gets the Zod schema for validating results of a given request method.
2184+
* Returns `undefined` for non-spec methods.
21842185
* @see getRequestSchema for explanation of the internal type assertion.
21852186
*/
2186-
export function getResultSchema<M extends RequestMethod>(method: M): z.ZodType<ResultTypeMap[M]> {
2187-
return resultSchemas[method] as unknown as z.ZodType<ResultTypeMap[M]>;
2187+
export function getResultSchema<M extends RequestMethod>(method: M): z.ZodType<ResultTypeMap[M]>;
2188+
export function getResultSchema(method: string): z.ZodType | undefined;
2189+
export function getResultSchema(method: string): z.ZodType | undefined {
2190+
return resultSchemas[method as RequestMethod] as unknown as z.ZodType | undefined;
21882191
}
21892192

21902193
/* Runtime schema lookup — request schemas by method */
@@ -2219,14 +2222,19 @@ const notificationSchemas = buildSchemaMap([...ClientNotificationSchema.options,
22192222
* when M is a generic type parameter. Both compute to the same type at
22202223
* instantiation, but TypeScript can't prove this statically.
22212224
*/
2222-
export function getRequestSchema<M extends RequestMethod>(method: M): z.ZodType<RequestTypeMap[M]> {
2223-
return requestSchemas[method] as unknown as z.ZodType<RequestTypeMap[M]>;
2225+
export function getRequestSchema<M extends RequestMethod>(method: M): z.ZodType<RequestTypeMap[M]>;
2226+
export function getRequestSchema(method: string): z.ZodType | undefined;
2227+
export function getRequestSchema(method: string): z.ZodType | undefined {
2228+
return requestSchemas[method as RequestMethod] as unknown as z.ZodType | undefined;
22242229
}
22252230

22262231
/**
22272232
* Gets the Zod schema for a given notification method.
2233+
* Returns `undefined` for non-spec methods.
22282234
* @see getRequestSchema for explanation of the internal type assertion.
22292235
*/
2230-
export function getNotificationSchema<M extends NotificationMethod>(method: M): z.ZodType<NotificationTypeMap[M]> {
2231-
return notificationSchemas[method] as unknown as z.ZodType<NotificationTypeMap[M]>;
2236+
export function getNotificationSchema<M extends NotificationMethod>(method: M): z.ZodType<NotificationTypeMap[M]>;
2237+
export function getNotificationSchema(method: string): z.ZodType | undefined;
2238+
export function getNotificationSchema(method: string): z.ZodType | undefined {
2239+
return notificationSchemas[method as NotificationMethod] as unknown as z.ZodType | undefined;
22322240
}

packages/server/src/server/server.ts

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,8 @@ import type {
1919
LoggingLevel,
2020
LoggingMessageNotification,
2121
MessageExtraInfo,
22-
NotificationMethod,
2322
NotificationOptions,
2423
ProtocolOptions,
25-
RequestMethod,
2624
RequestOptions,
2725
ResourceUpdatedNotification,
2826
Result,
@@ -266,7 +264,7 @@ export class Server extends Protocol<ServerContext> {
266264
};
267265
}
268266

269-
protected assertCapabilityForMethod(method: RequestMethod): void {
267+
protected assertCapabilityForMethod(method: string): void {
270268
switch (method) {
271269
case 'sampling/createMessage': {
272270
if (!this._clientCapabilities?.sampling) {
@@ -299,7 +297,7 @@ export class Server extends Protocol<ServerContext> {
299297
}
300298
}
301299

302-
protected assertNotificationCapability(method: NotificationMethod): void {
300+
protected assertNotificationCapability(method: string): void {
303301
switch (method) {
304302
case 'notifications/message': {
305303
if (!this._capabilities.logging) {

0 commit comments

Comments
 (0)