diff --git a/README.md b/README.md index 548fd2871..bcdf0d841 100644 --- a/README.md +++ b/README.md @@ -47,9 +47,11 @@ The Model Context Protocol allows applications to provide context for LLMs in a ## Installation ```bash -npm install @modelcontextprotocol/sdk +npm install @modelcontextprotocol/sdk zod ``` +This SDK has a **required peer dependency** on `zod` for schema validation. The SDK internally imports from `zod/v4`, but maintains backwards compatibility with projects using Zod v3.25 or later. You can use either API in your code by importing from `zod/v3` or `zod/v4`: + ## Quick Start Let's create a simple MCP server that exposes a calculator tool and some data. Save the following as `server.ts`: @@ -58,7 +60,7 @@ Let's create a simple MCP server that exposes a calculator tool and some data. S import { McpServer, ResourceTemplate } from '@modelcontextprotocol/sdk/server/mcp.js'; import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'; import express from 'express'; -import { z } from 'zod'; +import * as z from 'zod/v4'; // Create an MCP server const server = new McpServer({ @@ -130,7 +132,7 @@ app.listen(port, () => { }); ``` -Install the deps with `npm install @modelcontextprotocol/sdk express zod@3`, and run with `npx -y tsx server.ts`. +Install the deps with `npm install @modelcontextprotocol/sdk express zod`, and run with `npx -y tsx server.ts`. You can connect to it using any MCP client that supports streamable http, such as: @@ -477,7 +479,7 @@ MCP servers can request LLM completions from connected clients that support samp import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'; import express from 'express'; -import { z } from 'zod'; +import * as z from 'zod/v4'; const mcpServer = new McpServer({ name: 'tools-with-sample-server', @@ -561,7 +563,7 @@ For most use cases where session management isn't needed: import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'; import express from 'express'; -import { z } from 'zod'; +import * as z from 'zod/v4'; const app = express(); app.use(express.json()); @@ -796,7 +798,7 @@ A simple server demonstrating resources, tools, and prompts: ```typescript import { McpServer, ResourceTemplate } from '@modelcontextprotocol/sdk/server/mcp.js'; -import { z } from 'zod'; +import * as z from 'zod/v4'; const server = new McpServer({ name: 'echo-server', @@ -866,7 +868,7 @@ A more complex example showing database integration: import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; import sqlite3 from 'sqlite3'; import { promisify } from 'util'; -import { z } from 'zod'; +import * as z from 'zod/v4'; const server = new McpServer({ name: 'sqlite-explorer', @@ -961,7 +963,7 @@ If you want to offer an initial set of tools/prompts/resources, but later add ad import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js'; import express from 'express'; -import { z } from 'zod'; +import * as z from 'zod/v4'; const server = new McpServer({ name: 'Dynamic Example', diff --git a/package-lock.json b/package-lock.json index 56512fd8d..e788506d3 100644 --- a/package-lock.json +++ b/package-lock.json @@ -20,8 +20,8 @@ "express-rate-limit": "^7.5.0", "pkce-challenge": "^5.0.0", "raw-body": "^3.0.0", - "zod": "^3.23.8", - "zod-to-json-schema": "^3.24.1" + "zod": "^3.25 || ^4.0", + "zod-to-json-schema": "^3.25.0" }, "devDependencies": { "@cfworker/json-schema": "^4.1.1", @@ -49,11 +49,15 @@ "node": ">=18" }, "peerDependencies": { - "@cfworker/json-schema": "^4.1.1" + "@cfworker/json-schema": "^4.1.1", + "zod": "^3.25 || ^4.0" }, "peerDependenciesMeta": { "@cfworker/json-schema": { "optional": true + }, + "zod": { + "optional": false } } }, @@ -680,136 +684,6 @@ "url": "https://github.com/sponsors/nzakas" } }, - "node_modules/@inquirer/ansi": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/@inquirer/ansi/-/ansi-1.0.2.tgz", - "integrity": "sha512-S8qNSZiYzFd0wAcyG5AXCvUHC5Sr7xpZ9wZ2py9XR88jUz8wooStVx5M6dRzczbBWjic9NP7+rY0Xi7qqK/aMQ==", - "dev": true, - "license": "MIT", - "optional": true, - "peer": true, - "engines": { - "node": ">=18" - } - }, - "node_modules/@inquirer/confirm": { - "version": "5.1.20", - "resolved": "https://registry.npmjs.org/@inquirer/confirm/-/confirm-5.1.20.tgz", - "integrity": "sha512-HDGiWh2tyRZa0M1ZnEIUCQro25gW/mN8ODByicQrbR1yHx4hT+IOpozCMi5TgBtUdklLwRI2mv14eNpftDluEw==", - "dev": true, - "license": "MIT", - "optional": true, - "peer": true, - "dependencies": { - "@inquirer/core": "^10.3.1", - "@inquirer/type": "^3.0.10" - }, - "engines": { - "node": ">=18" - }, - "peerDependencies": { - "@types/node": ">=18" - }, - "peerDependenciesMeta": { - "@types/node": { - "optional": true - } - } - }, - "node_modules/@inquirer/core": { - "version": "10.3.1", - "resolved": "https://registry.npmjs.org/@inquirer/core/-/core-10.3.1.tgz", - "integrity": "sha512-hzGKIkfomGFPgxKmnKEKeA+uCYBqC+TKtRx5LgyHRCrF6S2MliwRIjp3sUaWwVzMp7ZXVs8elB0Tfe682Rpg4w==", - "dev": true, - "license": "MIT", - "optional": true, - "peer": true, - "dependencies": { - "@inquirer/ansi": "^1.0.2", - "@inquirer/figures": "^1.0.15", - "@inquirer/type": "^3.0.10", - "cli-width": "^4.1.0", - "mute-stream": "^3.0.0", - "signal-exit": "^4.1.0", - "wrap-ansi": "^6.2.0", - "yoctocolors-cjs": "^2.1.3" - }, - "engines": { - "node": ">=18" - }, - "peerDependencies": { - "@types/node": ">=18" - }, - "peerDependenciesMeta": { - "@types/node": { - "optional": true - } - } - }, - "node_modules/@inquirer/core/node_modules/signal-exit": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/signal-exit/-/signal-exit-4.1.0.tgz", - "integrity": "sha512-bzyZ1e88w9O1iNJbKnOlvYTrWPDl46O1bG0D3XInv+9tkPrxrN8jUUTiFlDkkmKWgn1M6CfIA13SuGqOa9Korw==", - "dev": true, - "license": "ISC", - "optional": true, - "peer": true, - "engines": { - "node": ">=14" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, - "node_modules/@inquirer/core/node_modules/wrap-ansi": { - "version": "6.2.0", - "resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-6.2.0.tgz", - "integrity": "sha512-r6lPcBGxZXlIcymEu7InxDMhdW0KDxpLgoFLcguasxCaJ/SOIZwINatK9KY/tf+ZrlywOKU0UDj3ATXUBfxJXA==", - "dev": true, - "license": "MIT", - "optional": true, - "peer": true, - "dependencies": { - "ansi-styles": "^4.0.0", - "string-width": "^4.1.0", - "strip-ansi": "^6.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/@inquirer/figures": { - "version": "1.0.15", - "resolved": "https://registry.npmjs.org/@inquirer/figures/-/figures-1.0.15.tgz", - "integrity": "sha512-t2IEY+unGHOzAaVM5Xx6DEWKeXlDDcNPeDyUpsRc6CUhBfU3VQOEl+Vssh7VNp1dR8MdUJBWhuObjXCsVpjN5g==", - "dev": true, - "license": "MIT", - "optional": true, - "peer": true, - "engines": { - "node": ">=18" - } - }, - "node_modules/@inquirer/type": { - "version": "3.0.10", - "resolved": "https://registry.npmjs.org/@inquirer/type/-/type-3.0.10.tgz", - "integrity": "sha512-BvziSRxfz5Ov8ch0z/n3oijRSEcEsHnhggm4xFZe93DHcUCTlutlq9Ox4SVENAfcRD22UQq7T/atg9Wr3k09eA==", - "dev": true, - "license": "MIT", - "optional": true, - "peer": true, - "engines": { - "node": ">=18" - }, - "peerDependencies": { - "@types/node": ">=18" - }, - "peerDependenciesMeta": { - "@types/node": { - "optional": true - } - } - }, "node_modules/@jridgewell/sourcemap-codec": { "version": "1.5.5", "resolved": "https://registry.npmjs.org/@jridgewell/sourcemap-codec/-/sourcemap-codec-1.5.5.tgz", @@ -817,26 +691,6 @@ "dev": true, "license": "MIT" }, - "node_modules/@mswjs/interceptors": { - "version": "0.40.0", - "resolved": "https://registry.npmjs.org/@mswjs/interceptors/-/interceptors-0.40.0.tgz", - "integrity": "sha512-EFd6cVbHsgLa6wa4RljGj6Wk75qoHxUSyc5asLyyPSyuhIcdS2Q3Phw6ImS1q+CkALthJRShiYfKANcQMuMqsQ==", - "dev": true, - "license": "MIT", - "optional": true, - "peer": true, - "dependencies": { - "@open-draft/deferred-promise": "^2.2.0", - "@open-draft/logger": "^0.3.0", - "@open-draft/until": "^2.0.0", - "is-node-process": "^1.2.0", - "outvariant": "^1.4.3", - "strict-event-emitter": "^0.5.1" - }, - "engines": { - "node": ">=18" - } - }, "node_modules/@noble/hashes": { "version": "1.8.0", "resolved": "https://registry.npmjs.org/@noble/hashes/-/hashes-1.8.0.tgz", @@ -885,37 +739,6 @@ "node": ">= 8" } }, - "node_modules/@open-draft/deferred-promise": { - "version": "2.2.0", - "resolved": "https://registry.npmjs.org/@open-draft/deferred-promise/-/deferred-promise-2.2.0.tgz", - "integrity": "sha512-CecwLWx3rhxVQF6V4bAgPS5t+So2sTbPgAzafKkVizyi7tlwpcFpdFqq+wqF2OwNBmqFuu6tOyouTuxgpMfzmA==", - "dev": true, - "license": "MIT", - "optional": true, - "peer": true - }, - "node_modules/@open-draft/logger": { - "version": "0.3.0", - "resolved": "https://registry.npmjs.org/@open-draft/logger/-/logger-0.3.0.tgz", - "integrity": "sha512-X2g45fzhxH238HKO4xbSr7+wBS8Fvw6ixhTDuvLd5mqh6bJJCFAPwU9mPDxbcrRtfxv4u5IHCEH77BmxvXmmxQ==", - "dev": true, - "license": "MIT", - "optional": true, - "peer": true, - "dependencies": { - "is-node-process": "^1.2.0", - "outvariant": "^1.4.0" - } - }, - "node_modules/@open-draft/until": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/@open-draft/until/-/until-2.1.0.tgz", - "integrity": "sha512-U69T3ItWHvLwGg5eJ0n3I62nWuE6ilHlmz7zM0npLBRvPRd7e6NYmg54vvRtP5mZG7kZqZCFVdsTWo7BPtBujg==", - "dev": true, - "license": "MIT", - "optional": true, - "peer": true - }, "node_modules/@paralleldrive/cuid2": { "version": "2.2.2", "resolved": "https://registry.npmjs.org/@paralleldrive/cuid2/-/cuid2-2.2.2.tgz", @@ -1421,15 +1244,6 @@ "@types/send": "*" } }, - "node_modules/@types/statuses": { - "version": "2.0.6", - "resolved": "https://registry.npmjs.org/@types/statuses/-/statuses-2.0.6.tgz", - "integrity": "sha512-xMAgYwceFhRA2zY+XbEA7mxYbA093wdiW8Vu6gZPGWy9cmOyU9XesH1tNcEWsKFd5Vzrqx5T3D38PWx1FIIXkA==", - "dev": true, - "license": "MIT", - "optional": true, - "peer": true - }, "node_modules/@types/superagent": { "version": "8.1.9", "resolved": "https://registry.npmjs.org/@types/superagent/-/superagent-8.1.9.tgz", @@ -1501,6 +1315,7 @@ "resolved": "https://registry.npmjs.org/@typescript-eslint/parser/-/parser-8.11.0.tgz", "integrity": "sha512-lmt73NeHdy1Q/2ul295Qy3uninSqi6wQI18XwSpm8w0ZbQXUpjCAWP1Vlv/obudoBiIjJVjlztjQ+d/Md98Yxg==", "dev": true, + "peer": true, "dependencies": { "@typescript-eslint/scope-manager": "8.11.0", "@typescript-eslint/types": "8.11.0", @@ -1912,6 +1727,7 @@ "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.14.0.tgz", "integrity": "sha512-cl669nCJTZBsL97OF4kUQm5g5hC2uihk0NxY3WENAC0TYdILVkAyHymAntgxGkl7K+t0cXIrH5siy5S4XkFycA==", "dev": true, + "peer": true, "bin": { "acorn": "bin/acorn" }, @@ -1961,17 +1777,6 @@ } } }, - "node_modules/ansi-regex": { - "version": "5.0.1", - "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", - "integrity": "sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==", - "dev": true, - "optional": true, - "peer": true, - "engines": { - "node": ">=8" - } - }, "node_modules/ansi-styles": { "version": "4.3.0", "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-4.3.0.tgz", @@ -2152,34 +1957,6 @@ "url": "https://github.com/chalk/chalk?sponsor=1" } }, - "node_modules/cli-width": { - "version": "4.1.0", - "resolved": "https://registry.npmjs.org/cli-width/-/cli-width-4.1.0.tgz", - "integrity": "sha512-ouuZd4/dm2Sw5Gmqy6bGyNNNe1qt9RpmxveLSO7KcgsTnU7RXfsw+/bukWGo1abgBiMAic068rclZsO4IWmmxQ==", - "dev": true, - "license": "ISC", - "optional": true, - "peer": true, - "engines": { - "node": ">= 12" - } - }, - "node_modules/cliui": { - "version": "8.0.1", - "resolved": "https://registry.npmjs.org/cliui/-/cliui-8.0.1.tgz", - "integrity": "sha512-BSeNnyus75C4//NQ9gQt1/csTXyo/8Sb+afLAkzAptFuMsod9HFokGNudZpi/oQV73hnVK+sR+5PVRMd+Dr7YQ==", - "dev": true, - "optional": true, - "peer": true, - "dependencies": { - "string-width": "^4.2.0", - "strip-ansi": "^6.0.1", - "wrap-ansi": "^7.0.0" - }, - "engines": { - "node": ">=12" - } - }, "node_modules/color-convert": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/color-convert/-/color-convert-2.0.1.tgz", @@ -2379,14 +2156,6 @@ "integrity": "sha512-WMwm9LhRUo+WUaRN+vRuETqG89IgZphVSNkdFgeb6sS/E4OrDIN7t48CAewSHXc6C8lefD8KKfr5vY61brQlow==", "license": "MIT" }, - "node_modules/emoji-regex": { - "version": "8.0.0", - "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", - "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", - "dev": true, - "optional": true, - "peer": true - }, "node_modules/encodeurl": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-2.0.0.tgz", @@ -2490,17 +2259,6 @@ "@esbuild/win32-x64": "0.25.0" } }, - "node_modules/escalade": { - "version": "3.2.0", - "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.2.0.tgz", - "integrity": "sha512-WUj2qlxaQtO4g6Pq5c29GTcWGDyd8itL8zTlipgECz3JesAiiOKotd8JU6otB3PACgG6xkJUyVhboMS+bje/jA==", - "dev": true, - "optional": true, - "peer": true, - "engines": { - "node": ">=6" - } - }, "node_modules/escape-html": { "version": "1.0.3", "resolved": "https://registry.npmjs.org/escape-html/-/escape-html-1.0.3.tgz", @@ -2524,6 +2282,7 @@ "resolved": "https://registry.npmjs.org/eslint/-/eslint-9.13.0.tgz", "integrity": "sha512-EYZK6SX6zjFHST/HRytOdA/zE72Cq/bfw45LSyuwrdvcclb/gqV8RRQxywOBEWO2+WDpva6UZa4CcDeJKzUCFA==", "dev": true, + "peer": true, "dependencies": { "@eslint-community/eslint-utils": "^4.2.0", "@eslint-community/regexpp": "^4.11.0", @@ -3122,17 +2881,6 @@ "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/get-caller-file": { - "version": "2.0.5", - "resolved": "https://registry.npmjs.org/get-caller-file/-/get-caller-file-2.0.5.tgz", - "integrity": "sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg==", - "dev": true, - "optional": true, - "peer": true, - "engines": { - "node": "6.* || 8.* || >= 10.*" - } - }, "node_modules/get-intrinsic": { "version": "1.2.7", "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.2.7.tgz", @@ -3224,18 +2972,6 @@ "integrity": "sha512-EtKwoO6kxCL9WO5xipiHTZlSzBm7WLT627TqC/uVRd0HKmq8NXyebnNYxDoBi7wt8eTWrUrKXCOVaFq9x1kgag==", "dev": true }, - "node_modules/graphql": { - "version": "16.12.0", - "resolved": "https://registry.npmjs.org/graphql/-/graphql-16.12.0.tgz", - "integrity": "sha512-DKKrynuQRne0PNpEbzuEdHlYOMksHSUI8Zc9Unei5gTsMNA2/vMpoMz/yKba50pejK56qj98qM0SjYxAKi13gQ==", - "dev": true, - "license": "MIT", - "optional": true, - "peer": true, - "engines": { - "node": "^12.22.0 || ^14.16.0 || ^16.0.0 || >=17.0.0" - } - }, "node_modules/has-flag": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", @@ -3284,15 +3020,6 @@ "node": ">= 0.4" } }, - "node_modules/headers-polyfill": { - "version": "4.0.3", - "resolved": "https://registry.npmjs.org/headers-polyfill/-/headers-polyfill-4.0.3.tgz", - "integrity": "sha512-IScLbePpkvO846sIwOtOTDjutRMWdXdJmXdMvk6gCBHxFO8d+QKOQedyZSxFTTFYRSmlgSTDtXqqq4pcenBXLQ==", - "dev": true, - "license": "MIT", - "optional": true, - "peer": true - }, "node_modules/http-errors": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/http-errors/-/http-errors-2.0.0.tgz", @@ -3376,17 +3103,6 @@ "node": ">=0.10.0" } }, - "node_modules/is-fullwidth-code-point": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz", - "integrity": "sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg==", - "dev": true, - "optional": true, - "peer": true, - "engines": { - "node": ">=8" - } - }, "node_modules/is-glob": { "version": "4.0.3", "resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.3.tgz", @@ -3399,15 +3115,6 @@ "node": ">=0.10.0" } }, - "node_modules/is-node-process": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/is-node-process/-/is-node-process-1.2.0.tgz", - "integrity": "sha512-Vg4o6/fqPxIjtxgUH5QLJhwZ7gW5diGCVlXpuUfELC62CuxM1iHcRe51f2W1FDy04Ai4KJkagKjx3XaqyfRKXw==", - "dev": true, - "license": "MIT", - "optional": true, - "peer": true - }, "node_modules/is-number": { "version": "7.0.0", "resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz", @@ -3622,113 +3329,6 @@ "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz", "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==" }, - "node_modules/msw": { - "version": "2.12.1", - "resolved": "https://registry.npmjs.org/msw/-/msw-2.12.1.tgz", - "integrity": "sha512-arzsi9IZjjByiEw21gSUP82qHM8zkV69nNpWV6W4z72KiLvsDWoOp678ORV6cNfU/JGhlX0SsnD4oXo9gI6I2A==", - "dev": true, - "hasInstallScript": true, - "license": "MIT", - "optional": true, - "peer": true, - "dependencies": { - "@inquirer/confirm": "^5.0.0", - "@mswjs/interceptors": "^0.40.0", - "@open-draft/deferred-promise": "^2.2.0", - "@types/statuses": "^2.0.4", - "cookie": "^1.0.2", - "graphql": "^16.8.1", - "headers-polyfill": "^4.0.2", - "is-node-process": "^1.2.0", - "outvariant": "^1.4.3", - "path-to-regexp": "^6.3.0", - "picocolors": "^1.1.1", - "rettime": "^0.7.0", - "statuses": "^2.0.2", - "strict-event-emitter": "^0.5.1", - "tough-cookie": "^6.0.0", - "type-fest": "^4.26.1", - "until-async": "^3.0.2", - "yargs": "^17.7.2" - }, - "bin": { - "msw": "cli/index.js" - }, - "engines": { - "node": ">=18" - }, - "funding": { - "url": "https://github.com/sponsors/mswjs" - }, - "peerDependencies": { - "typescript": ">= 4.8.x" - }, - "peerDependenciesMeta": { - "typescript": { - "optional": true - } - } - }, - "node_modules/msw/node_modules/cookie": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/cookie/-/cookie-1.0.2.tgz", - "integrity": "sha512-9Kr/j4O16ISv8zBBhJoi4bXOYNTkFLOqSL3UDB0njXxCXNezjeyVrJyGOWtgfs/q2km1gwBcfH8q1yEGoMYunA==", - "dev": true, - "license": "MIT", - "optional": true, - "peer": true, - "engines": { - "node": ">=18" - } - }, - "node_modules/msw/node_modules/path-to-regexp": { - "version": "6.3.0", - "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-6.3.0.tgz", - "integrity": "sha512-Yhpw4T9C6hPpgPeA28us07OJeqZ5EzQTkbfwuhsUg0c237RomFoETJgmp2sa3F/41gfLE6G5cqcYwznmeEeOlQ==", - "dev": true, - "license": "MIT", - "optional": true, - "peer": true - }, - "node_modules/msw/node_modules/statuses": { - "version": "2.0.2", - "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.2.tgz", - "integrity": "sha512-DvEy55V3DB7uknRo+4iOGT5fP1slR8wQohVdknigZPMpMstaKJQWhwiYBACJE3Ul2pTnATihhBYnRhZQHGBiRw==", - "dev": true, - "license": "MIT", - "optional": true, - "peer": true, - "engines": { - "node": ">= 0.8" - } - }, - "node_modules/msw/node_modules/type-fest": { - "version": "4.41.0", - "resolved": "https://registry.npmjs.org/type-fest/-/type-fest-4.41.0.tgz", - "integrity": "sha512-TeTSQ6H5YHvpqVwBRcnLDCBnDOHWYu7IvGbHT6N8AOymcr9PJGjc1GTtiWZTYg0NCgYwvnYWEkVChQAr9bjfwA==", - "dev": true, - "license": "(MIT OR CC0-1.0)", - "optional": true, - "peer": true, - "engines": { - "node": ">=16" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/mute-stream": { - "version": "3.0.0", - "resolved": "https://registry.npmjs.org/mute-stream/-/mute-stream-3.0.0.tgz", - "integrity": "sha512-dkEJPVvun4FryqBmZ5KhDo0K9iDXAwn08tMLDinNdRBNPcYEDiWYysLcc6k3mjTMlbP9KyylvRpd4wFtwrT9rw==", - "dev": true, - "license": "ISC", - "optional": true, - "peer": true, - "engines": { - "node": "^20.17.0 || >=22.9.0" - } - }, "node_modules/nanoid": { "version": "3.3.11", "resolved": "https://registry.npmjs.org/nanoid/-/nanoid-3.3.11.tgz", @@ -3821,15 +3421,6 @@ "node": ">= 0.8.0" } }, - "node_modules/outvariant": { - "version": "1.4.3", - "resolved": "https://registry.npmjs.org/outvariant/-/outvariant-1.4.3.tgz", - "integrity": "sha512-+Sl2UErvtsoajRDKCE5/dBz4DIvHXQQnAxtQTF04OJxY0+DyZXSo5P5Bb7XYWOh81syohlYL24hbDwxedPUJCA==", - "dev": true, - "license": "MIT", - "optional": true, - "peer": true - }, "node_modules/p-limit": { "version": "3.1.0", "resolved": "https://registry.npmjs.org/p-limit/-/p-limit-3.1.0.tgz", @@ -4086,17 +3677,6 @@ "node": ">=0.10.0" } }, - "node_modules/require-directory": { - "version": "2.1.1", - "resolved": "https://registry.npmjs.org/require-directory/-/require-directory-2.1.1.tgz", - "integrity": "sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q==", - "dev": true, - "optional": true, - "peer": true, - "engines": { - "node": ">=0.10.0" - } - }, "node_modules/require-from-string": { "version": "2.0.2", "resolved": "https://registry.npmjs.org/require-from-string/-/require-from-string-2.0.2.tgz", @@ -4124,15 +3704,6 @@ "url": "https://github.com/privatenumber/resolve-pkg-maps?sponsor=1" } }, - "node_modules/rettime": { - "version": "0.7.0", - "resolved": "https://registry.npmjs.org/rettime/-/rettime-0.7.0.tgz", - "integrity": "sha512-LPRKoHnLKd/r3dVxcwO7vhCW+orkOGj9ViueosEBK6ie89CijnfRlhaDhHq/3Hxu4CkWQtxwlBG0mzTQY6uQjw==", - "dev": true, - "license": "MIT", - "optional": true, - "peer": true - }, "node_modules/reusify": { "version": "1.0.4", "resolved": "https://registry.npmjs.org/reusify/-/reusify-1.0.4.tgz", @@ -4450,45 +4021,6 @@ "dev": true, "license": "MIT" }, - "node_modules/strict-event-emitter": { - "version": "0.5.1", - "resolved": "https://registry.npmjs.org/strict-event-emitter/-/strict-event-emitter-0.5.1.tgz", - "integrity": "sha512-vMgjE/GGEPEFnhFub6pa4FmJBRBVOLpIII2hvCZ8Kzb7K0hlHo7mQv6xYrBvCL2LtAIBwFUK8wvuJgTVSQ5MFQ==", - "dev": true, - "license": "MIT", - "optional": true, - "peer": true - }, - "node_modules/string-width": { - "version": "4.2.3", - "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz", - "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==", - "dev": true, - "optional": true, - "peer": true, - "dependencies": { - "emoji-regex": "^8.0.0", - "is-fullwidth-code-point": "^3.0.0", - "strip-ansi": "^6.0.1" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/strip-ansi": { - "version": "6.0.1", - "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", - "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", - "dev": true, - "optional": true, - "peer": true, - "dependencies": { - "ansi-regex": "^5.0.1" - }, - "engines": { - "node": ">=8" - } - }, "node_modules/strip-json-comments": { "version": "3.1.1", "resolved": "https://registry.npmjs.org/strip-json-comments/-/strip-json-comments-3.1.1.tgz", @@ -4609,6 +4141,7 @@ "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "dev": true, "license": "MIT", + "peer": true, "engines": { "node": ">=12" }, @@ -4626,30 +4159,6 @@ "node": ">=14.0.0" } }, - "node_modules/tldts": { - "version": "7.0.17", - "resolved": "https://registry.npmjs.org/tldts/-/tldts-7.0.17.tgz", - "integrity": "sha512-Y1KQBgDd/NUc+LfOtKS6mNsC9CCaH+m2P1RoIZy7RAPo3C3/t8X45+zgut31cRZtZ3xKPjfn3TkGTrctC2TQIQ==", - "dev": true, - "license": "MIT", - "optional": true, - "peer": true, - "dependencies": { - "tldts-core": "^7.0.17" - }, - "bin": { - "tldts": "bin/cli.js" - } - }, - "node_modules/tldts-core": { - "version": "7.0.17", - "resolved": "https://registry.npmjs.org/tldts-core/-/tldts-core-7.0.17.tgz", - "integrity": "sha512-DieYoGrP78PWKsrXr8MZwtQ7GLCUeLxihtjC1jZsW1DnvSMdKPitJSe8OSYDM2u5H6g3kWJZpePqkp43TfLh0g==", - "dev": true, - "license": "MIT", - "optional": true, - "peer": true - }, "node_modules/to-regex-range": { "version": "5.0.1", "resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz", @@ -4670,21 +4179,6 @@ "node": ">=0.6" } }, - "node_modules/tough-cookie": { - "version": "6.0.0", - "resolved": "https://registry.npmjs.org/tough-cookie/-/tough-cookie-6.0.0.tgz", - "integrity": "sha512-kXuRi1mtaKMrsLUxz3sQYvVl37B0Ns6MzfrtV5DvJceE9bPyspOqk9xxv7XbZWcfLWbFmm997vl83qUWVJA64w==", - "dev": true, - "license": "BSD-3-Clause", - "optional": true, - "peer": true, - "dependencies": { - "tldts": "^7.0.5" - }, - "engines": { - "node": ">=16" - } - }, "node_modules/ts-api-utils": { "version": "1.3.0", "resolved": "https://registry.npmjs.org/ts-api-utils/-/ts-api-utils-1.3.0.tgz", @@ -4703,6 +4197,7 @@ "integrity": "sha512-4H8vUNGNjQ4V2EOoGw005+c+dGuPSnhpPBPHBtsZdGZBk/iJb4kguGlPWaZTZ3q5nMtFOEsY0nRDlh9PJyd6SQ==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "esbuild": "~0.25.0", "get-tsconfig": "^4.7.5" @@ -4748,6 +4243,7 @@ "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.6.3.tgz", "integrity": "sha512-hjcS1mhfuyi4WW8IWtjP7brDrG2cuDZukyrYrSauoXGNgx0S7zceP07adYkJycEr56BOUTNPzbInooiN3fn1qw==", "dev": true, + "peer": true, "bin": { "tsc": "bin/tsc", "tsserver": "bin/tsserver" @@ -4794,18 +4290,6 @@ "node": ">= 0.8" } }, - "node_modules/until-async": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/until-async/-/until-async-3.0.2.tgz", - "integrity": "sha512-IiSk4HlzAMqTUseHHe3VhIGyuFmN90zMTpD3Z3y8jeQbzLIq500MVM7Jq2vUAnTKAFPJrqwkzr6PoTcPhGcOiw==", - "dev": true, - "license": "MIT", - "optional": true, - "peer": true, - "funding": { - "url": "https://github.com/sponsors/kettanaito" - } - }, "node_modules/uri-js": { "version": "4.4.1", "resolved": "https://registry.npmjs.org/uri-js/-/uri-js-4.4.1.tgz", @@ -4961,6 +4445,7 @@ "integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==", "dev": true, "license": "MIT", + "peer": true, "engines": { "node": ">=12" }, @@ -4974,6 +4459,7 @@ "integrity": "sha512-BxAKBWmIbrDgrokdGZH1IgkIk/5mMHDreLDmCJ0qpyJaAteP8NvMhkwr/ZCQNqNH97bw/dANTE9PDzqwJghfMQ==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "esbuild": "^0.25.0", "fdir": "^6.5.0", @@ -5083,25 +4569,6 @@ "node": ">=0.10.0" } }, - "node_modules/wrap-ansi": { - "version": "7.0.0", - "resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-7.0.0.tgz", - "integrity": "sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q==", - "dev": true, - "optional": true, - "peer": true, - "dependencies": { - "ansi-styles": "^4.0.0", - "string-width": "^4.1.0", - "strip-ansi": "^6.0.0" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/chalk/wrap-ansi?sponsor=1" - } - }, "node_modules/wrappy": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/wrappy/-/wrappy-1.0.2.tgz", @@ -5128,48 +4595,6 @@ } } }, - "node_modules/y18n": { - "version": "5.0.8", - "resolved": "https://registry.npmjs.org/y18n/-/y18n-5.0.8.tgz", - "integrity": "sha512-0pfFzegeDWJHJIAmTLRP2DwHjdF5s7jo9tuztdQxAhINCdvS+3nGINqPd00AphqJR/0LhANUS6/+7SCb98YOfA==", - "dev": true, - "optional": true, - "peer": true, - "engines": { - "node": ">=10" - } - }, - "node_modules/yargs": { - "version": "17.7.2", - "resolved": "https://registry.npmjs.org/yargs/-/yargs-17.7.2.tgz", - "integrity": "sha512-7dSzzRQ++CKnNI/krKnYRV7JKKPUXMEh61soaHKg9mrWEhzFWhFnxPxGl+69cD1Ou63C13NUPCnmIcrvqCuM6w==", - "dev": true, - "optional": true, - "peer": true, - "dependencies": { - "cliui": "^8.0.1", - "escalade": "^3.1.1", - "get-caller-file": "^2.0.5", - "require-directory": "^2.1.1", - "string-width": "^4.2.3", - "y18n": "^5.0.5", - "yargs-parser": "^21.1.1" - }, - "engines": { - "node": ">=12" - } - }, - "node_modules/yargs-parser": { - "version": "21.1.1", - "resolved": "https://registry.npmjs.org/yargs-parser/-/yargs-parser-21.1.1.tgz", - "integrity": "sha512-tVpsJW7DdjecAiFpbIB1e3qxIQsE6NoPc5/eTdrbbIC4h0LVsWhnoa3g+m2HclBIujHzsxZ4VJVA+GUuc2/LBw==", - "dev": true, - "optional": true, - "peer": true, - "engines": { - "node": ">=12" - } - }, "node_modules/yocto-queue": { "version": "0.1.0", "resolved": "https://registry.npmjs.org/yocto-queue/-/yocto-queue-0.1.0.tgz", @@ -5182,37 +4607,23 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/yoctocolors-cjs": { - "version": "2.1.3", - "resolved": "https://registry.npmjs.org/yoctocolors-cjs/-/yoctocolors-cjs-2.1.3.tgz", - "integrity": "sha512-U/PBtDf35ff0D8X8D0jfdzHYEPFxAI7jJlxZXwCSez5M3190m+QobIfh+sWDWSHMCWWJN2AWamkegn6vr6YBTw==", - "dev": true, - "license": "MIT", - "optional": true, - "peer": true, - "engines": { - "node": ">=18" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, "node_modules/zod": { - "version": "3.24.1", - "resolved": "https://registry.npmjs.org/zod/-/zod-3.24.1.tgz", - "integrity": "sha512-muH7gBL9sI1nciMZV67X5fTKKBLtwpZ5VBp1vsOQzj1MhrBZ4wlVCm3gedKZWLp0Oyel8sIGfeiz54Su+OVT+A==", + "version": "3.25.76", + "resolved": "https://registry.npmjs.org/zod/-/zod-3.25.76.tgz", + "integrity": "sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==", "license": "MIT", + "peer": true, "funding": { "url": "https://github.com/sponsors/colinhacks" } }, "node_modules/zod-to-json-schema": { - "version": "3.24.1", - "resolved": "https://registry.npmjs.org/zod-to-json-schema/-/zod-to-json-schema-3.24.1.tgz", - "integrity": "sha512-3h08nf3Vw3Wl3PK+q3ow/lIil81IT2Oa7YpQyUUDsEWbXveMesdfK1xBd2RhCkynwZndAxixji/7SYJJowr62w==", + "version": "3.25.0", + "resolved": "https://registry.npmjs.org/zod-to-json-schema/-/zod-to-json-schema-3.25.0.tgz", + "integrity": "sha512-HvWtU2UG41LALjajJrML6uQejQhNJx+JBO9IflpSja4R03iNWfKXrj6W2h7ljuLyc1nKS+9yDyL/9tD1U/yBnQ==", "license": "ISC", "peerDependencies": { - "zod": "^3.24.1" + "zod": "^3.25 || ^4" } } } diff --git a/package.json b/package.json index 0ea194384..b103f4a6e 100644 --- a/package.json +++ b/package.json @@ -89,15 +89,19 @@ "express-rate-limit": "^7.5.0", "pkce-challenge": "^5.0.0", "raw-body": "^3.0.0", - "zod": "^3.23.8", - "zod-to-json-schema": "^3.24.1" + "zod": "^3.25 || ^4.0", + "zod-to-json-schema": "^3.25.0" }, "peerDependencies": { - "@cfworker/json-schema": "^4.1.1" + "@cfworker/json-schema": "^4.1.1", + "zod": "^3.25 || ^4.0" }, "peerDependenciesMeta": { "@cfworker/json-schema": { "optional": true + }, + "zod": { + "optional": false } }, "devDependencies": { diff --git a/src/client/index.test.ts b/src/client/index.test.ts index 8534842ee..2143f603d 100644 --- a/src/client/index.test.ts +++ b/src/client/index.test.ts @@ -2,7 +2,7 @@ /* eslint-disable no-constant-binary-expression */ /* eslint-disable @typescript-eslint/no-unused-expressions */ import { Client, getSupportedElicitationModes } from './index.js'; -import { z } from 'zod'; +import * as z from 'zod/v4'; import { RequestSchema, NotificationSchema, diff --git a/src/client/index.ts b/src/client/index.ts index f2864982a..694ae4a1a 100644 --- a/src/client/index.ts +++ b/src/client/index.ts @@ -42,7 +42,15 @@ import { } from '../types.js'; import { AjvJsonSchemaValidator } from '../validation/ajv-provider.js'; import type { JsonSchemaType, JsonSchemaValidator, jsonSchemaValidator } from '../validation/types.js'; -import { ZodLiteral, ZodObject, z } from 'zod'; +import { + AnyObjectSchema, + SchemaOutput, + getObjectShape, + isZ4Schema, + safeParse, + type ZodV3Internal, + type ZodV4Internal +} from '../server/zod-compat.js'; import type { RequestHandlerExtra } from '../shared/protocol.js'; /** @@ -216,26 +224,46 @@ export class Client< /** * Override request handler registration to enforce client-side validation for elicitation. */ - public override setRequestHandler< - T extends ZodObject<{ - method: ZodLiteral; - }> - >( + public override setRequestHandler( requestSchema: T, handler: ( - request: z.infer, + request: SchemaOutput, extra: RequestHandlerExtra ) => ClientResult | ResultT | Promise ): void { - const method = requestSchema.shape.method.value; + const shape = getObjectShape(requestSchema); + const methodSchema = shape?.method; + if (!methodSchema) { + throw new Error('Schema is missing a method literal'); + } + + // Extract literal value using type-safe property access + let methodValue: unknown; + if (isZ4Schema(methodSchema)) { + const v4Schema = methodSchema as unknown as ZodV4Internal; + const v4Def = v4Schema._zod?.def; + methodValue = v4Def?.value ?? v4Schema.value; + } else { + const v3Schema = methodSchema as unknown as ZodV3Internal; + const legacyDef = v3Schema._def; + methodValue = legacyDef?.value ?? v3Schema.value; + } + + if (typeof methodValue !== 'string') { + throw new Error('Schema method literal must be a string'); + } + const method = methodValue; if (method === 'elicitation/create') { const wrappedHandler = async ( - request: z.infer, + request: SchemaOutput, extra: RequestHandlerExtra ): Promise => { - const validatedRequest = ElicitRequestSchema.safeParse(request); + const validatedRequest = safeParse(ElicitRequestSchema, request); if (!validatedRequest.success) { - throw new McpError(ErrorCode.InvalidParams, `Invalid elicitation request: ${validatedRequest.error.message}`); + // Type guard: if success is false, error is guaranteed to exist + const errorMessage = + validatedRequest.error instanceof Error ? validatedRequest.error.message : String(validatedRequest.error); + throw new McpError(ErrorCode.InvalidParams, `Invalid elicitation request: ${errorMessage}`); } const { params } = validatedRequest.data; @@ -251,9 +279,12 @@ export class Client< const result = await Promise.resolve(handler(request, extra)); - const validationResult = ElicitResultSchema.safeParse(result); + const validationResult = safeParse(ElicitResultSchema, result); if (!validationResult.success) { - throw new McpError(ErrorCode.InvalidParams, `Invalid elicitation result: ${validationResult.error.message}`); + // Type guard: if success is false, error is guaranteed to exist + const errorMessage = + validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error); + throw new McpError(ErrorCode.InvalidParams, `Invalid elicitation result: ${errorMessage}`); } const validatedResult = validationResult.data; diff --git a/src/client/v3/index.v3.test.ts b/src/client/v3/index.v3.test.ts new file mode 100644 index 000000000..78a53eea0 --- /dev/null +++ b/src/client/v3/index.v3.test.ts @@ -0,0 +1,1591 @@ +/* eslint-disable @typescript-eslint/no-unused-vars */ +/* eslint-disable no-constant-binary-expression */ +/* eslint-disable @typescript-eslint/no-unused-expressions */ +import { Client, getSupportedElicitationModes } from '../index.js'; +import * as z from 'zod/v3'; +import { + RequestSchema, + NotificationSchema, + ResultSchema, + LATEST_PROTOCOL_VERSION, + SUPPORTED_PROTOCOL_VERSIONS, + InitializeRequestSchema, + ListResourcesRequestSchema, + ListToolsRequestSchema, + CallToolRequestSchema, + CreateMessageRequestSchema, + ElicitRequestSchema, + ListRootsRequestSchema, + ErrorCode +} from '../../types.js'; +import { Transport } from '../../shared/transport.js'; +import { Server } from '../../server/index.js'; +import { InMemoryTransport } from '../../inMemory.js'; + +/*** + * Test: Initialize with Matching Protocol Version + */ +test('should initialize with matching protocol version', async () => { + const clientTransport: Transport = { + start: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + send: vi.fn().mockImplementation(message => { + if (message.method === 'initialize') { + clientTransport.onmessage?.({ + jsonrpc: '2.0', + id: message.id, + result: { + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: {}, + serverInfo: { + name: 'test', + version: '1.0' + }, + instructions: 'test instructions' + } + }); + } + return Promise.resolve(); + }) + }; + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + sampling: {} + } + } + ); + + await client.connect(clientTransport); + + // Should have sent initialize with latest version + expect(clientTransport.send).toHaveBeenCalledWith( + expect.objectContaining({ + method: 'initialize', + params: expect.objectContaining({ + protocolVersion: LATEST_PROTOCOL_VERSION + }) + }), + expect.objectContaining({ + relatedRequestId: undefined + }) + ); + + // Should have the instructions returned + expect(client.getInstructions()).toEqual('test instructions'); +}); + +/*** + * Test: Initialize with Supported Older Protocol Version + */ +test('should initialize with supported older protocol version', async () => { + const OLD_VERSION = SUPPORTED_PROTOCOL_VERSIONS[1]; + const clientTransport: Transport = { + start: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + send: vi.fn().mockImplementation(message => { + if (message.method === 'initialize') { + clientTransport.onmessage?.({ + jsonrpc: '2.0', + id: message.id, + result: { + protocolVersion: OLD_VERSION, + capabilities: {}, + serverInfo: { + name: 'test', + version: '1.0' + } + } + }); + } + return Promise.resolve(); + }) + }; + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + sampling: {} + } + } + ); + + await client.connect(clientTransport); + + // Connection should succeed with the older version + expect(client.getServerVersion()).toEqual({ + name: 'test', + version: '1.0' + }); + + // Expect no instructions + expect(client.getInstructions()).toBeUndefined(); +}); + +/*** + * Test: Reject Unsupported Protocol Version + */ +test('should reject unsupported protocol version', async () => { + const clientTransport: Transport = { + start: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + send: vi.fn().mockImplementation(message => { + if (message.method === 'initialize') { + clientTransport.onmessage?.({ + jsonrpc: '2.0', + id: message.id, + result: { + protocolVersion: 'invalid-version', + capabilities: {}, + serverInfo: { + name: 'test', + version: '1.0' + } + } + }); + } + return Promise.resolve(); + }) + }; + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + sampling: {} + } + } + ); + + await expect(client.connect(clientTransport)).rejects.toThrow("Server's protocol version is not supported: invalid-version"); + + expect(clientTransport.close).toHaveBeenCalled(); +}); + +/*** + * Test: Connect New Client to Old Supported Server Version + */ +test('should connect new client to old, supported server version', async () => { + const OLD_VERSION = SUPPORTED_PROTOCOL_VERSIONS[1]; + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + resources: {}, + tools: {} + } + } + ); + + server.setRequestHandler(InitializeRequestSchema, _request => ({ + protocolVersion: OLD_VERSION, + capabilities: { + resources: {}, + tools: {} + }, + serverInfo: { + name: 'old server', + version: '1.0' + } + })); + + server.setRequestHandler(ListResourcesRequestSchema, () => ({ + resources: [] + })); + + server.setRequestHandler(ListToolsRequestSchema, () => ({ + tools: [] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const client = new Client( + { + name: 'new client', + version: '1.0' + }, + { + capabilities: { + sampling: {} + }, + enforceStrictCapabilities: true + } + ); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + expect(client.getServerVersion()).toEqual({ + name: 'old server', + version: '1.0' + }); +}); + +/*** + * Test: Version Negotiation with Old Client and Newer Server + */ +test('should negotiate version when client is old, and newer server supports its version', async () => { + const OLD_VERSION = SUPPORTED_PROTOCOL_VERSIONS[1]; + const server = new Server( + { + name: 'new server', + version: '1.0' + }, + { + capabilities: { + resources: {}, + tools: {} + } + } + ); + + server.setRequestHandler(InitializeRequestSchema, _request => ({ + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: { + resources: {}, + tools: {} + }, + serverInfo: { + name: 'new server', + version: '1.0' + } + })); + + server.setRequestHandler(ListResourcesRequestSchema, () => ({ + resources: [] + })); + + server.setRequestHandler(ListToolsRequestSchema, () => ({ + tools: [] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const client = new Client( + { + name: 'old client', + version: '1.0' + }, + { + capabilities: { + sampling: {} + }, + enforceStrictCapabilities: true + } + ); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + expect(client.getServerVersion()).toEqual({ + name: 'new server', + version: '1.0' + }); +}); + +/*** + * Test: Throw when Old Client and Server Version Mismatch + */ +test("should throw when client is old, and server doesn't support its version", async () => { + const OLD_VERSION = SUPPORTED_PROTOCOL_VERSIONS[1]; + const FUTURE_VERSION = 'FUTURE_VERSION'; + const server = new Server( + { + name: 'new server', + version: '1.0' + }, + { + capabilities: { + resources: {}, + tools: {} + } + } + ); + + server.setRequestHandler(InitializeRequestSchema, _request => ({ + protocolVersion: FUTURE_VERSION, + capabilities: { + resources: {}, + tools: {} + }, + serverInfo: { + name: 'new server', + version: '1.0' + } + })); + + server.setRequestHandler(ListResourcesRequestSchema, () => ({ + resources: [] + })); + + server.setRequestHandler(ListToolsRequestSchema, () => ({ + tools: [] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const client = new Client( + { + name: 'old client', + version: '1.0' + }, + { + capabilities: { + sampling: {} + }, + enforceStrictCapabilities: true + } + ); + + await Promise.all([ + expect(client.connect(clientTransport)).rejects.toThrow("Server's protocol version is not supported: FUTURE_VERSION"), + server.connect(serverTransport) + ]); +}); + +/*** + * Test: Respect Server Capabilities + */ +test('should respect server capabilities', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + resources: {}, + tools: {} + } + } + ); + + server.setRequestHandler(InitializeRequestSchema, _request => ({ + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: { + resources: {}, + tools: {} + }, + serverInfo: { + name: 'test', + version: '1.0' + } + })); + + server.setRequestHandler(ListResourcesRequestSchema, () => ({ + resources: [] + })); + + server.setRequestHandler(ListToolsRequestSchema, () => ({ + tools: [] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + sampling: {} + }, + enforceStrictCapabilities: true + } + ); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Server supports resources and tools, but not prompts + expect(client.getServerCapabilities()).toEqual({ + resources: {}, + tools: {} + }); + + // These should work + await expect(client.listResources()).resolves.not.toThrow(); + await expect(client.listTools()).resolves.not.toThrow(); + + // These should throw because prompts, logging, and completions are not supported + await expect(client.listPrompts()).rejects.toThrow('Server does not support prompts'); + await expect(client.setLoggingLevel('error')).rejects.toThrow('Server does not support logging'); + await expect( + client.complete({ + ref: { type: 'ref/prompt', name: 'test' }, + argument: { name: 'test', value: 'test' } + }) + ).rejects.toThrow('Server does not support completions'); +}); + +/*** + * Test: Respect Client Notification Capabilities + */ +test('should respect client notification capabilities', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: {} + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + roots: { + listChanged: true + } + } + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // This should work because the client has the roots.listChanged capability + await expect(client.sendRootsListChanged()).resolves.not.toThrow(); + + // Create a new client without the roots.listChanged capability + const clientWithoutCapability = new Client( + { + name: 'test client without capability', + version: '1.0' + }, + { + capabilities: {}, + enforceStrictCapabilities: true + } + ); + + await clientWithoutCapability.connect(clientTransport); + + // This should throw because the client doesn't have the roots.listChanged capability + await expect(clientWithoutCapability.sendRootsListChanged()).rejects.toThrow(/^Client does not support/); +}); + +/*** + * Test: Respect Server Notification Capabilities + */ +test('should respect server notification capabilities', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + logging: {}, + resources: { + listChanged: true + } + } + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: {} + } + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // These should work because the server has the corresponding capabilities + await expect(server.sendLoggingMessage({ level: 'info', data: 'Test' })).resolves.not.toThrow(); + await expect(server.sendResourceListChanged()).resolves.not.toThrow(); + + // This should throw because the server doesn't have the tools capability + await expect(server.sendToolListChanged()).rejects.toThrow('Server does not support notifying of tool list changes'); +}); + +/*** + * Test: Only Allow setRequestHandler for Declared Capabilities + */ +test('should only allow setRequestHandler for declared capabilities', () => { + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + sampling: {} + } + } + ); + + // This should work because sampling is a declared capability + expect(() => { + client.setRequestHandler(CreateMessageRequestSchema, () => ({ + model: 'test-model', + role: 'assistant', + content: { + type: 'text', + text: 'Test response' + } + })); + }).not.toThrow(); + + // This should throw because roots listing is not a declared capability + expect(() => { + client.setRequestHandler(ListRootsRequestSchema, () => ({})); + }).toThrow('Client does not support roots capability'); +}); + +test('should allow setRequestHandler for declared elicitation capability', () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + } + } + ); + + // This should work because elicitation is a declared capability + expect(() => { + client.setRequestHandler(ElicitRequestSchema, () => ({ + action: 'accept', + content: { + username: 'test-user', + confirmed: true + } + })); + }).not.toThrow(); + + // This should throw because sampling is not a declared capability + expect(() => { + client.setRequestHandler(CreateMessageRequestSchema, () => ({ + model: 'test-model', + role: 'assistant', + content: { + type: 'text', + text: 'Test response' + } + })); + }).toThrow('Client does not support sampling capability'); +}); + +test('should accept form-mode elicitation request when client advertises empty elicitation object (back-compat)', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {} + } + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + elicitation: {} + } + } + ); + + // Set up client handler for form-mode elicitation + client.setRequestHandler(ElicitRequestSchema, request => { + expect(request.params.mode).toBe('form'); + return { + action: 'accept', + content: { + username: 'test-user', + confirmed: true + } + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Server should be able to send form-mode elicitation request + // This works because getSupportedElicitationModes defaults to form mode + // when neither form nor url are explicitly declared + const result = await server.elicitInput({ + mode: 'form', + message: 'Please provide your username', + requestedSchema: { + type: 'object', + properties: { + username: { + type: 'string', + title: 'Username', + description: 'Your username' + }, + confirmed: { + type: 'boolean', + title: 'Confirm', + description: 'Please confirm', + default: false + } + }, + required: ['username'] + } + }); + + expect(result.action).toBe('accept'); + expect(result.content).toEqual({ + username: 'test-user', + confirmed: true + }); +}); + +test('should reject form-mode elicitation when client only supports URL mode', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: { + url: {} + } + } + } + ); + + const handler = vi.fn().mockResolvedValue({ + action: 'cancel' + }); + client.setRequestHandler(ElicitRequestSchema, handler); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + let resolveResponse: ((message: unknown) => void) | undefined; + const responsePromise = new Promise(resolve => { + resolveResponse = resolve; + }); + + serverTransport.onmessage = async message => { + if ('method' in message) { + if (message.method === 'initialize') { + if (!('id' in message) || message.id === undefined) { + throw new Error('Expected initialize request to include an id'); + } + const messageId = message.id; + await serverTransport.send({ + jsonrpc: '2.0', + id: messageId, + result: { + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: {}, + serverInfo: { + name: 'test-server', + version: '1.0.0' + } + } + }); + } else if (message.method === 'notifications/initialized') { + // ignore + } + } else { + resolveResponse?.(message); + } + }; + + await client.connect(clientTransport); + + // Server shouldn't send this, because the client capabilities + // only advertised URL mode. Test that it's rejected by the client: + const requestId = 1; + await serverTransport.send({ + jsonrpc: '2.0', + id: requestId, + method: 'elicitation/create', + params: { + mode: 'form', + message: 'Provide your username', + requestedSchema: { + type: 'object', + properties: { + username: { + type: 'string' + } + } + } + } + }); + + const response = (await responsePromise) as { id: number; error: { code: number; message: string } }; + + expect(response.id).toBe(requestId); + expect(response.error.code).toBe(ErrorCode.InvalidParams); + expect(response.error.message).toContain('Client does not support form-mode elicitation requests'); + expect(handler).not.toHaveBeenCalled(); + + await client.close(); +}); + +test('should reject URL-mode elicitation when client only supports form mode', async () => { + const client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: { + form: {} + } + } + } + ); + + const handler = vi.fn().mockResolvedValue({ + action: 'cancel' + }); + client.setRequestHandler(ElicitRequestSchema, handler); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + let resolveResponse: ((message: unknown) => void) | undefined; + const responsePromise = new Promise(resolve => { + resolveResponse = resolve; + }); + + serverTransport.onmessage = async message => { + if ('method' in message) { + if (message.method === 'initialize') { + if (!('id' in message) || message.id === undefined) { + throw new Error('Expected initialize request to include an id'); + } + const messageId = message.id; + await serverTransport.send({ + jsonrpc: '2.0', + id: messageId, + result: { + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: {}, + serverInfo: { + name: 'test-server', + version: '1.0.0' + } + } + }); + } else if (message.method === 'notifications/initialized') { + // ignore + } + } else { + resolveResponse?.(message); + } + }; + + await client.connect(clientTransport); + + // Server shouldn't send this, because the client capabilities + // only advertised form mode. Test that it's rejected by the client: + const requestId = 2; + await serverTransport.send({ + jsonrpc: '2.0', + id: requestId, + method: 'elicitation/create', + params: { + mode: 'url', + message: 'Open the authorization page', + elicitationId: 'elicitation-123', + url: 'https://example.com/authorize' + } + }); + + const response = (await responsePromise) as { id: number; error: { code: number; message: string } }; + + expect(response.id).toBe(requestId); + expect(response.error.code).toBe(ErrorCode.InvalidParams); + expect(response.error.message).toContain('Client does not support URL-mode elicitation requests'); + expect(handler).not.toHaveBeenCalled(); + + await client.close(); +}); + +test('should apply defaults for form-mode elicitation when applyDefaults is enabled', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {} + } + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + elicitation: { + form: { + applyDefaults: true + } + } + } + } + ); + + client.setRequestHandler(ElicitRequestSchema, request => { + expect(request.params.mode).toBe('form'); + return { + action: 'accept', + content: {} + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + const result = await server.elicitInput({ + mode: 'form', + message: 'Please confirm your preferences', + requestedSchema: { + type: 'object', + properties: { + confirmed: { + type: 'boolean', + default: true + } + } + } + }); + + expect(result.action).toBe('accept'); + expect(result.content).toEqual({ + confirmed: true + }); + + await client.close(); +}); + +/*** + * Test: Type Checking + * Test that custom request/notification/result schemas can be used with the Client class. + */ +test('should typecheck', () => { + const GetWeatherRequestSchema = z.object({ + ...RequestSchema.shape, + method: z.literal('weather/get'), + params: z.object({ + city: z.string() + }) + }); + + const GetForecastRequestSchema = z.object({ + ...RequestSchema.shape, + method: z.literal('weather/forecast'), + params: z.object({ + city: z.string(), + days: z.number() + }) + }); + + const WeatherForecastNotificationSchema = z.object({ + ...NotificationSchema.shape, + method: z.literal('weather/alert'), + params: z.object({ + severity: z.enum(['warning', 'watch']), + message: z.string() + }) + }); + + const WeatherRequestSchema = GetWeatherRequestSchema.or(GetForecastRequestSchema); + const WeatherNotificationSchema = WeatherForecastNotificationSchema; + const WeatherResultSchema = z.object({ + ...ResultSchema.shape, + _meta: z.record(z.string(), z.unknown()).optional(), + temperature: z.number(), + conditions: z.string() + }); + + type WeatherRequest = z.infer; + type WeatherNotification = z.infer; + type WeatherResult = z.infer; + + // Create a typed Client for weather data + const weatherClient = new Client( + { + name: 'WeatherClient', + version: '1.0.0' + }, + { + capabilities: { + sampling: {} + } + } + ); + + // Typecheck that only valid weather requests/notifications/results are allowed + false && + weatherClient.request( + { + method: 'weather/get', + params: { + city: 'Seattle' + } + }, + WeatherResultSchema + ); + + false && + weatherClient.notification({ + method: 'weather/alert', + params: { + severity: 'warning', + message: 'Storm approaching' + } + }); +}); + +/*** + * Test: Handle Client Cancelling a Request + */ +test('should handle client cancelling a request', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + resources: {} + } + } + ); + + // Set up server to delay responding to listResources + server.setRequestHandler(ListResourcesRequestSchema, async (request, extra) => { + await new Promise(resolve => setTimeout(resolve, 1000)); + return { + resources: [] + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: {} + } + ); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Set up abort controller + const controller = new AbortController(); + + // Issue request but cancel it immediately + const listResourcesPromise = client.listResources(undefined, { + signal: controller.signal + }); + controller.abort('Cancelled by test'); + + // Request should be rejected + await expect(listResourcesPromise).rejects.toBe('Cancelled by test'); +}); + +/*** + * Test: Handle Request Timeout + */ +test('should handle request timeout', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + resources: {} + } + } + ); + + // Set up server with a delayed response + server.setRequestHandler(ListResourcesRequestSchema, async (_request, extra) => { + const timer = new Promise(resolve => { + const timeout = setTimeout(resolve, 100); + extra.signal.addEventListener('abort', () => clearTimeout(timeout)); + }); + + await timer; + return { + resources: [] + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: {} + } + ); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Request with 0 msec timeout should fail immediately + await expect(client.listResources(undefined, { timeout: 0 })).rejects.toMatchObject({ + code: ErrorCode.RequestTimeout + }); +}); + +describe('outputSchema validation', () => { + /*** + * Test: Validate structuredContent Against outputSchema + */ + test('should validate structuredContent against outputSchema', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + } + } + ); + + // Set up server handlers + server.setRequestHandler(InitializeRequestSchema, async request => ({ + protocolVersion: request.params.protocolVersion, + capabilities: {}, + serverInfo: { + name: 'test-server', + version: '1.0.0' + } + })); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'test-tool', + description: 'A test tool', + inputSchema: { + type: 'object', + properties: {} + }, + outputSchema: { + type: 'object', + properties: { + result: { type: 'string' }, + count: { type: 'number' } + }, + required: ['result', 'count'], + additionalProperties: false + } + } + ] + })); + + server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'test-tool') { + return { + structuredContent: { result: 'success', count: 42 } + }; + } + throw new Error('Unknown tool'); + }); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // List tools to cache the schemas + await client.listTools(); + + // Call the tool - should validate successfully + const result = await client.callTool({ name: 'test-tool' }); + expect(result.structuredContent).toEqual({ result: 'success', count: 42 }); + }); + + /*** + * Test: Throw Error when structuredContent Does Not Match Schema + */ + test('should throw error when structuredContent does not match schema', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + } + } + ); + + // Set up server handlers + server.setRequestHandler(InitializeRequestSchema, async request => ({ + protocolVersion: request.params.protocolVersion, + capabilities: {}, + serverInfo: { + name: 'test-server', + version: '1.0.0' + } + })); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'test-tool', + description: 'A test tool', + inputSchema: { + type: 'object', + properties: {} + }, + outputSchema: { + type: 'object', + properties: { + result: { type: 'string' }, + count: { type: 'number' } + }, + required: ['result', 'count'], + additionalProperties: false + } + } + ] + })); + + server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'test-tool') { + // Return invalid structured content (count is string instead of number) + return { + structuredContent: { result: 'success', count: 'not a number' } + }; + } + throw new Error('Unknown tool'); + }); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // List tools to cache the schemas + await client.listTools(); + + // Call the tool - should throw validation error + await expect(client.callTool({ name: 'test-tool' })).rejects.toThrow(/Structured content does not match the tool's output schema/); + }); + + /*** + * Test: Throw Error when Tool with outputSchema Returns No structuredContent + */ + test('should throw error when tool with outputSchema returns no structuredContent', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + } + } + ); + + // Set up server handlers + server.setRequestHandler(InitializeRequestSchema, async request => ({ + protocolVersion: request.params.protocolVersion, + capabilities: {}, + serverInfo: { + name: 'test-server', + version: '1.0.0' + } + })); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'test-tool', + description: 'A test tool', + inputSchema: { + type: 'object', + properties: {} + }, + outputSchema: { + type: 'object', + properties: { + result: { type: 'string' } + }, + required: ['result'] + } + } + ] + })); + + server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'test-tool') { + // Return content instead of structuredContent + return { + content: [{ type: 'text', text: 'This should be structured content' }] + }; + } + throw new Error('Unknown tool'); + }); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // List tools to cache the schemas + await client.listTools(); + + // Call the tool - should throw error + await expect(client.callTool({ name: 'test-tool' })).rejects.toThrow( + /Tool test-tool has an output schema but did not return structured content/ + ); + }); + + /*** + * Test: Handle Tools Without outputSchema Normally + */ + test('should handle tools without outputSchema normally', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + } + } + ); + + // Set up server handlers + server.setRequestHandler(InitializeRequestSchema, async request => ({ + protocolVersion: request.params.protocolVersion, + capabilities: {}, + serverInfo: { + name: 'test-server', + version: '1.0.0' + } + })); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'test-tool', + description: 'A test tool', + inputSchema: { + type: 'object', + properties: {} + } + // No outputSchema + } + ] + })); + + server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'test-tool') { + // Return regular content + return { + content: [{ type: 'text', text: 'Normal response' }] + }; + } + throw new Error('Unknown tool'); + }); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // List tools to cache the schemas + await client.listTools(); + + // Call the tool - should work normally without validation + const result = await client.callTool({ name: 'test-tool' }); + expect(result.content).toEqual([{ type: 'text', text: 'Normal response' }]); + }); + + /*** + * Test: Handle Complex JSON Schema Validation + */ + test('should handle complex JSON schema validation', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + } + } + ); + + // Set up server handlers + server.setRequestHandler(InitializeRequestSchema, async request => ({ + protocolVersion: request.params.protocolVersion, + capabilities: {}, + serverInfo: { + name: 'test-server', + version: '1.0.0' + } + })); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'complex-tool', + description: 'A tool with complex schema', + inputSchema: { + type: 'object', + properties: {} + }, + outputSchema: { + type: 'object', + properties: { + name: { type: 'string', minLength: 3 }, + age: { type: 'integer', minimum: 0, maximum: 120 }, + active: { type: 'boolean' }, + tags: { + type: 'array', + items: { type: 'string' }, + minItems: 1 + }, + metadata: { + type: 'object', + properties: { + created: { type: 'string' } + }, + required: ['created'] + } + }, + required: ['name', 'age', 'active', 'tags', 'metadata'], + additionalProperties: false + } + } + ] + })); + + server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'complex-tool') { + return { + structuredContent: { + name: 'John Doe', + age: 30, + active: true, + tags: ['user', 'admin'], + metadata: { + created: '2023-01-01T00:00:00Z' + } + } + }; + } + throw new Error('Unknown tool'); + }); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // List tools to cache the schemas + await client.listTools(); + + // Call the tool - should validate successfully + const result = await client.callTool({ name: 'complex-tool' }); + expect(result.structuredContent).toBeDefined(); + const structuredContent = result.structuredContent as { name: string; age: number }; + expect(structuredContent.name).toBe('John Doe'); + expect(structuredContent.age).toBe(30); + }); + + /*** + * Test: Fail Validation with Additional Properties When Not Allowed + */ + test('should fail validation with additional properties when not allowed', async () => { + const server = new Server( + { + name: 'test-server', + version: '1.0.0' + }, + { + capabilities: { + tools: {} + } + } + ); + + // Set up server handlers + server.setRequestHandler(InitializeRequestSchema, async request => ({ + protocolVersion: request.params.protocolVersion, + capabilities: {}, + serverInfo: { + name: 'test-server', + version: '1.0.0' + } + })); + + server.setRequestHandler(ListToolsRequestSchema, async () => ({ + tools: [ + { + name: 'strict-tool', + description: 'A tool with strict schema', + inputSchema: { + type: 'object', + properties: {} + }, + outputSchema: { + type: 'object', + properties: { + name: { type: 'string' } + }, + required: ['name'], + additionalProperties: false + } + } + ] + })); + + server.setRequestHandler(CallToolRequestSchema, async request => { + if (request.params.name === 'strict-tool') { + // Return structured content with extra property + return { + structuredContent: { + name: 'John', + extraField: 'not allowed' + } + }; + } + throw new Error('Unknown tool'); + }); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // List tools to cache the schemas + await client.listTools(); + + // Call the tool - should throw validation error due to additional property + await expect(client.callTool({ name: 'strict-tool' })).rejects.toThrow( + /Structured content does not match the tool's output schema/ + ); + }); +}); + +describe('getSupportedElicitationModes', () => { + test('should support nothing when capabilities are undefined', () => { + const result = getSupportedElicitationModes(undefined); + expect(result.supportsFormMode).toBe(false); + expect(result.supportsUrlMode).toBe(false); + }); + + test('should default to form mode when capabilities are an empty object', () => { + const result = getSupportedElicitationModes({}); + expect(result.supportsFormMode).toBe(true); + expect(result.supportsUrlMode).toBe(false); + }); + + test('should support form mode when form is explicitly declared', () => { + const result = getSupportedElicitationModes({ form: {} }); + expect(result.supportsFormMode).toBe(true); + expect(result.supportsUrlMode).toBe(false); + }); + + test('should support url mode when url is explicitly declared', () => { + const result = getSupportedElicitationModes({ url: {} }); + expect(result.supportsFormMode).toBe(false); + expect(result.supportsUrlMode).toBe(true); + }); + + test('should support both modes when both are explicitly declared', () => { + const result = getSupportedElicitationModes({ form: {}, url: {} }); + expect(result.supportsFormMode).toBe(true); + expect(result.supportsUrlMode).toBe(true); + }); + + test('should support form mode when form declares applyDefaults', () => { + const result = getSupportedElicitationModes({ form: { applyDefaults: true } }); + expect(result.supportsFormMode).toBe(true); + expect(result.supportsUrlMode).toBe(false); + }); +}); diff --git a/src/examples/server/jsonResponseStreamableHttp.ts b/src/examples/server/jsonResponseStreamableHttp.ts index 8b640777d..c1206d8cd 100644 --- a/src/examples/server/jsonResponseStreamableHttp.ts +++ b/src/examples/server/jsonResponseStreamableHttp.ts @@ -2,7 +2,7 @@ import express, { Request, Response } from 'express'; import { randomUUID } from 'node:crypto'; import { McpServer } from '../../server/mcp.js'; import { StreamableHTTPServerTransport } from '../../server/streamableHttp.js'; -import { z } from 'zod'; +import * as z from 'zod/v4'; import { CallToolResult, isInitializeRequest } from '../../types.js'; import cors from 'cors'; diff --git a/src/examples/server/mcpServerOutputSchema.ts b/src/examples/server/mcpServerOutputSchema.ts index 5d1cab0bd..7ef9f6227 100644 --- a/src/examples/server/mcpServerOutputSchema.ts +++ b/src/examples/server/mcpServerOutputSchema.ts @@ -6,7 +6,7 @@ import { McpServer } from '../../server/mcp.js'; import { StdioServerTransport } from '../../server/stdio.js'; -import { z } from 'zod'; +import * as z from 'zod/v4'; const server = new McpServer({ name: 'mcp-output-schema-high-level-example', diff --git a/src/examples/server/simpleSseServer.ts b/src/examples/server/simpleSseServer.ts index b99334369..e07f36010 100644 --- a/src/examples/server/simpleSseServer.ts +++ b/src/examples/server/simpleSseServer.ts @@ -1,7 +1,7 @@ import express, { Request, Response } from 'express'; import { McpServer } from '../../server/mcp.js'; import { SSEServerTransport } from '../../server/sse.js'; -import { z } from 'zod'; +import * as z from 'zod/v4'; import { CallToolResult } from '../../types.js'; /** diff --git a/src/examples/server/simpleStatelessStreamableHttp.ts b/src/examples/server/simpleStatelessStreamableHttp.ts index f71e5db6c..464ea2623 100644 --- a/src/examples/server/simpleStatelessStreamableHttp.ts +++ b/src/examples/server/simpleStatelessStreamableHttp.ts @@ -1,7 +1,7 @@ import express, { Request, Response } from 'express'; import { McpServer } from '../../server/mcp.js'; import { StreamableHTTPServerTransport } from '../../server/streamableHttp.js'; -import { z } from 'zod'; +import * as z from 'zod/v4'; import { CallToolResult, GetPromptResult, ReadResourceResult } from '../../types.js'; import cors from 'cors'; diff --git a/src/examples/server/simpleStreamableHttp.ts b/src/examples/server/simpleStreamableHttp.ts index 1765414fa..33568bc82 100644 --- a/src/examples/server/simpleStreamableHttp.ts +++ b/src/examples/server/simpleStreamableHttp.ts @@ -1,6 +1,6 @@ import express, { Request, Response } from 'express'; import { randomUUID } from 'node:crypto'; -import { z } from 'zod'; +import * as z from 'zod/v4'; import { McpServer } from '../../server/mcp.js'; import { StreamableHTTPServerTransport } from '../../server/streamableHttp.js'; import { getOAuthProtectedResourceMetadataUrl, mcpAuthMetadataRouter } from '../../server/auth/router.js'; diff --git a/src/examples/server/sseAndStreamableHttpCompatibleServer.ts b/src/examples/server/sseAndStreamableHttpCompatibleServer.ts index 50e2e5125..8eb3724c3 100644 --- a/src/examples/server/sseAndStreamableHttpCompatibleServer.ts +++ b/src/examples/server/sseAndStreamableHttpCompatibleServer.ts @@ -3,7 +3,7 @@ import { randomUUID } from 'node:crypto'; import { McpServer } from '../../server/mcp.js'; import { StreamableHTTPServerTransport } from '../../server/streamableHttp.js'; import { SSEServerTransport } from '../../server/sse.js'; -import { z } from 'zod'; +import * as z from 'zod/v4'; import { CallToolResult, isInitializeRequest } from '../../types.js'; import { InMemoryEventStore } from '../shared/inMemoryEventStore.js'; import cors from 'cors'; diff --git a/src/examples/server/toolWithSampleServer.ts b/src/examples/server/toolWithSampleServer.ts index ad5a01bdc..c198dc0ec 100644 --- a/src/examples/server/toolWithSampleServer.ts +++ b/src/examples/server/toolWithSampleServer.ts @@ -2,7 +2,7 @@ import { McpServer } from '../../server/mcp.js'; import { StdioServerTransport } from '../../server/stdio.js'; -import { z } from 'zod'; +import * as z from 'zod/v4'; const mcpServer = new McpServer({ name: 'tools-with-sample-server', @@ -33,13 +33,12 @@ mcpServer.registerTool( maxTokens: 500 }); + const contents = Array.isArray(response.content) ? response.content : [response.content]; return { - content: [ - { - type: 'text', - text: response.content.type === 'text' ? response.content.text : 'Unable to generate summary' - } - ] + content: contents.map(content => ({ + type: 'text', + text: content.type === 'text' ? content.text : 'Unable to generate summary' + })) }; } ); diff --git a/src/integration-tests/stateManagementStreamableHttp.test.ts b/src/integration-tests/stateManagementStreamableHttp.test.ts index 629b01519..bd61e6104 100644 --- a/src/integration-tests/stateManagementStreamableHttp.test.ts +++ b/src/integration-tests/stateManagementStreamableHttp.test.ts @@ -12,7 +12,7 @@ import { ListPromptsResultSchema, LATEST_PROTOCOL_VERSION } from '../types.js'; -import { z } from 'zod'; +import * as z from 'zod/v4'; describe('Streamable HTTP Transport Session Management', () => { // Function to set up the server with optional session management diff --git a/src/integration-tests/taskResumability.test.ts b/src/integration-tests/taskResumability.test.ts index c8393dfe1..d3f54c9d5 100644 --- a/src/integration-tests/taskResumability.test.ts +++ b/src/integration-tests/taskResumability.test.ts @@ -6,7 +6,7 @@ import { StreamableHTTPClientTransport } from '../client/streamableHttp.js'; import { McpServer } from '../server/mcp.js'; import { StreamableHTTPServerTransport } from '../server/streamableHttp.js'; import { CallToolResultSchema, LoggingMessageNotificationSchema } from '../types.js'; -import { z } from 'zod'; +import * as z from 'zod/v4'; import { InMemoryEventStore } from '../examples/shared/inMemoryEventStore.js'; describe('Transport resumability', () => { @@ -193,8 +193,14 @@ describe('Transport resumability', () => { } ); - // Wait for some notifications to arrive (not all) - shorter wait time - await new Promise(resolve => setTimeout(resolve, 20)); + // Fix for node 18 test failures, allow some time for notifications to arrive + const maxWaitTime = 2000; // 2 seconds max wait + const pollInterval = 10; // Check every 10ms + const startTime = Date.now(); + while (notifications.length === 0 && Date.now() - startTime < maxWaitTime) { + // Wait for some notifications to arrive (not all) - shorter wait time + await new Promise(resolve => setTimeout(resolve, pollInterval)); + } // Verify we received some notifications and lastEventId was updated expect(notifications.length).toBeGreaterThan(0); diff --git a/src/integration-tests/v3/stateManagementStreamableHttp.v3.test.ts b/src/integration-tests/v3/stateManagementStreamableHttp.v3.test.ts new file mode 100644 index 000000000..b47306142 --- /dev/null +++ b/src/integration-tests/v3/stateManagementStreamableHttp.v3.test.ts @@ -0,0 +1,357 @@ +import { createServer, type Server } from 'node:http'; +import { AddressInfo } from 'node:net'; +import { randomUUID } from 'node:crypto'; +import { Client } from '../../client/index.js'; +import { StreamableHTTPClientTransport } from '../../client/streamableHttp.js'; +import { McpServer } from '../../server/mcp.js'; +import { StreamableHTTPServerTransport } from '../../server/streamableHttp.js'; +import { + CallToolResultSchema, + ListToolsResultSchema, + ListResourcesResultSchema, + ListPromptsResultSchema, + LATEST_PROTOCOL_VERSION +} from '../../types.js'; +import * as z from 'zod/v3'; + +describe('Streamable HTTP Transport Session Management', () => { + // Function to set up the server with optional session management + async function setupServer(withSessionManagement: boolean) { + const server: Server = createServer(); + const mcpServer = new McpServer( + { name: 'test-server', version: '1.0.0' }, + { + capabilities: { + logging: {}, + tools: {}, + resources: {}, + prompts: {} + } + } + ); + + // Add a simple resource + mcpServer.resource('test-resource', '/test', { description: 'A test resource' }, async () => ({ + contents: [ + { + uri: '/test', + text: 'This is a test resource content' + } + ] + })); + + mcpServer.prompt('test-prompt', 'A test prompt', async () => ({ + messages: [ + { + role: 'user', + content: { + type: 'text', + text: 'This is a test prompt' + } + } + ] + })); + + mcpServer.tool( + 'greet', + 'A simple greeting tool', + { + name: z.string().describe('Name to greet').default('World') + }, + async ({ name }) => { + return { + content: [{ type: 'text', text: `Hello, ${name}!` }] + }; + } + ); + + // Create transport with or without session management + const serverTransport = new StreamableHTTPServerTransport({ + sessionIdGenerator: withSessionManagement + ? () => randomUUID() // With session management, generate UUID + : undefined // Without session management, return undefined + }); + + await mcpServer.connect(serverTransport); + + server.on('request', async (req, res) => { + await serverTransport.handleRequest(req, res); + }); + + // Start the server on a random port + const baseUrl = await new Promise(resolve => { + server.listen(0, '127.0.0.1', () => { + const addr = server.address() as AddressInfo; + resolve(new URL(`http://127.0.0.1:${addr.port}`)); + }); + }); + + return { server, mcpServer, serverTransport, baseUrl }; + } + + describe('Stateless Mode', () => { + let server: Server; + let mcpServer: McpServer; + let serverTransport: StreamableHTTPServerTransport; + let baseUrl: URL; + + beforeEach(async () => { + const setup = await setupServer(false); + server = setup.server; + mcpServer = setup.mcpServer; + serverTransport = setup.serverTransport; + baseUrl = setup.baseUrl; + }); + + afterEach(async () => { + // Clean up resources + await mcpServer.close().catch(() => {}); + await serverTransport.close().catch(() => {}); + server.close(); + }); + + it('should support multiple client connections', async () => { + // Create and connect a client + const client1 = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport1 = new StreamableHTTPClientTransport(baseUrl); + await client1.connect(transport1); + + // Verify that no session ID was set + expect(transport1.sessionId).toBeUndefined(); + + // List available tools + await client1.request( + { + method: 'tools/list', + params: {} + }, + ListToolsResultSchema + ); + + const client2 = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport2 = new StreamableHTTPClientTransport(baseUrl); + await client2.connect(transport2); + + // Verify that no session ID was set + expect(transport2.sessionId).toBeUndefined(); + + // List available tools + await client2.request( + { + method: 'tools/list', + params: {} + }, + ListToolsResultSchema + ); + }); + it('should operate without session management', async () => { + // Create and connect a client + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Verify that no session ID was set + expect(transport.sessionId).toBeUndefined(); + + // List available tools + const toolsResult = await client.request( + { + method: 'tools/list', + params: {} + }, + ListToolsResultSchema + ); + + // Verify tools are accessible + expect(toolsResult.tools).toContainEqual( + expect.objectContaining({ + name: 'greet' + }) + ); + + // List available resources + const resourcesResult = await client.request( + { + method: 'resources/list', + params: {} + }, + ListResourcesResultSchema + ); + + // Verify resources result structure + expect(resourcesResult).toHaveProperty('resources'); + + // List available prompts + const promptsResult = await client.request( + { + method: 'prompts/list', + params: {} + }, + ListPromptsResultSchema + ); + + // Verify prompts result structure + expect(promptsResult).toHaveProperty('prompts'); + expect(promptsResult.prompts).toContainEqual( + expect.objectContaining({ + name: 'test-prompt' + }) + ); + + // Call the greeting tool + const greetingResult = await client.request( + { + method: 'tools/call', + params: { + name: 'greet', + arguments: { + name: 'Stateless Transport' + } + } + }, + CallToolResultSchema + ); + + // Verify tool result + expect(greetingResult.content).toEqual([{ type: 'text', text: 'Hello, Stateless Transport!' }]); + + // Clean up + await transport.close(); + }); + + it('should set protocol version after connecting', async () => { + // Create and connect a client + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + + // Verify protocol version is not set before connecting + expect(transport.protocolVersion).toBeUndefined(); + + await client.connect(transport); + + // Verify protocol version is set after connecting + expect(transport.protocolVersion).toBe(LATEST_PROTOCOL_VERSION); + + // Clean up + await transport.close(); + }); + }); + + describe('Stateful Mode', () => { + let server: Server; + let mcpServer: McpServer; + let serverTransport: StreamableHTTPServerTransport; + let baseUrl: URL; + + beforeEach(async () => { + const setup = await setupServer(true); + server = setup.server; + mcpServer = setup.mcpServer; + serverTransport = setup.serverTransport; + baseUrl = setup.baseUrl; + }); + + afterEach(async () => { + // Clean up resources + await mcpServer.close().catch(() => {}); + await serverTransport.close().catch(() => {}); + server.close(); + }); + + it('should operate with session management', async () => { + // Create and connect a client + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Verify that a session ID was set + expect(transport.sessionId).toBeDefined(); + expect(typeof transport.sessionId).toBe('string'); + + // List available tools + const toolsResult = await client.request( + { + method: 'tools/list', + params: {} + }, + ListToolsResultSchema + ); + + // Verify tools are accessible + expect(toolsResult.tools).toContainEqual( + expect.objectContaining({ + name: 'greet' + }) + ); + + // List available resources + const resourcesResult = await client.request( + { + method: 'resources/list', + params: {} + }, + ListResourcesResultSchema + ); + + // Verify resources result structure + expect(resourcesResult).toHaveProperty('resources'); + + // List available prompts + const promptsResult = await client.request( + { + method: 'prompts/list', + params: {} + }, + ListPromptsResultSchema + ); + + // Verify prompts result structure + expect(promptsResult).toHaveProperty('prompts'); + expect(promptsResult.prompts).toContainEqual( + expect.objectContaining({ + name: 'test-prompt' + }) + ); + + // Call the greeting tool + const greetingResult = await client.request( + { + method: 'tools/call', + params: { + name: 'greet', + arguments: { + name: 'Stateful Transport' + } + } + }, + CallToolResultSchema + ); + + // Verify tool result + expect(greetingResult.content).toEqual([{ type: 'text', text: 'Hello, Stateful Transport!' }]); + + // Clean up + await transport.close(); + }); + }); +}); diff --git a/src/integration-tests/v3/taskResumability.v3.test.ts b/src/integration-tests/v3/taskResumability.v3.test.ts new file mode 100644 index 000000000..7c7ea927e --- /dev/null +++ b/src/integration-tests/v3/taskResumability.v3.test.ts @@ -0,0 +1,270 @@ +import { createServer, type Server } from 'node:http'; +import { AddressInfo } from 'node:net'; +import { randomUUID } from 'node:crypto'; +import { Client } from '../../client/index.js'; +import { StreamableHTTPClientTransport } from '../../client/streamableHttp.js'; +import { McpServer } from '../../server/mcp.js'; +import { StreamableHTTPServerTransport } from '../../server/streamableHttp.js'; +import { CallToolResultSchema, LoggingMessageNotificationSchema } from '../../types.js'; +import * as z from 'zod/v3'; +import { InMemoryEventStore } from '../../examples/shared/inMemoryEventStore.js'; + +describe('Transport resumability', () => { + let server: Server; + let mcpServer: McpServer; + let serverTransport: StreamableHTTPServerTransport; + let baseUrl: URL; + let eventStore: InMemoryEventStore; + + beforeEach(async () => { + // Create event store for resumability + eventStore = new InMemoryEventStore(); + + // Create a simple MCP server + mcpServer = new McpServer({ name: 'test-server', version: '1.0.0' }, { capabilities: { logging: {} } }); + + // Add a simple notification tool that completes quickly + mcpServer.tool( + 'send-notification', + 'Sends a single notification', + { + message: z.string().describe('Message to send').default('Test notification') + }, + async ({ message }, { sendNotification }) => { + // Send notification immediately + await sendNotification({ + method: 'notifications/message', + params: { + level: 'info', + data: message + } + }); + + return { + content: [{ type: 'text', text: 'Notification sent' }] + }; + } + ); + + // Add a long-running tool that sends multiple notifications + mcpServer.tool( + 'run-notifications', + 'Sends multiple notifications over time', + { + count: z.number().describe('Number of notifications to send').default(10), + interval: z.number().describe('Interval between notifications in ms').default(50) + }, + async ({ count, interval }, { sendNotification }) => { + // Send notifications at specified intervals + for (let i = 0; i < count; i++) { + await sendNotification({ + method: 'notifications/message', + params: { + level: 'info', + data: `Notification ${i + 1} of ${count}` + } + }); + + // Wait for the specified interval before sending next notification + if (i < count - 1) { + await new Promise(resolve => setTimeout(resolve, interval)); + } + } + + return { + content: [{ type: 'text', text: `Sent ${count} notifications` }] + }; + } + ); + + // Create a transport with the event store + serverTransport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID(), + eventStore + }); + + // Connect the transport to the MCP server + await mcpServer.connect(serverTransport); + + // Create and start an HTTP server + server = createServer(async (req, res) => { + await serverTransport.handleRequest(req, res); + }); + + // Start the server on a random port + baseUrl = await new Promise(resolve => { + server.listen(0, '127.0.0.1', () => { + const addr = server.address() as AddressInfo; + resolve(new URL(`http://127.0.0.1:${addr.port}`)); + }); + }); + }); + + afterEach(async () => { + // Clean up resources + await mcpServer.close().catch(() => {}); + await serverTransport.close().catch(() => {}); + server.close(); + }); + + it('should store session ID when client connects', async () => { + // Create and connect a client + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + await client.connect(transport); + + // Verify session ID was generated + expect(transport.sessionId).toBeDefined(); + + // Clean up + await transport.close(); + }); + + it('should have session ID functionality', async () => { + // The ability to store a session ID when connecting + const client = new Client({ + name: 'test-client-reconnection', + version: '1.0.0' + }); + + const transport = new StreamableHTTPClientTransport(baseUrl); + + // Make sure the client can connect and get a session ID + await client.connect(transport); + expect(transport.sessionId).toBeDefined(); + + // Clean up + await transport.close(); + }); + + // This test demonstrates the capability to resume long-running tools + // across client disconnection/reconnection + it('should resume long-running notifications with lastEventId', async () => { + // Create unique client ID for this test + const clientTitle = 'test-client-long-running'; + const notifications = []; + let lastEventId: string | undefined; + + // Create first client + const client1 = new Client({ + title: clientTitle, + name: 'test-client', + version: '1.0.0' + }); + + // Set up notification handler for first client + client1.setNotificationHandler(LoggingMessageNotificationSchema, notification => { + if (notification.method === 'notifications/message') { + notifications.push(notification.params); + } + }); + + // Connect first client + const transport1 = new StreamableHTTPClientTransport(baseUrl); + await client1.connect(transport1); + const sessionId = transport1.sessionId; + expect(sessionId).toBeDefined(); + + // Start a long-running notification stream with tracking of lastEventId + const onLastEventIdUpdate = vi.fn((eventId: string) => { + lastEventId = eventId; + }); + expect(lastEventId).toBeUndefined(); + // Start the notification tool with event tracking using request + const toolPromise = client1.request( + { + method: 'tools/call', + params: { + name: 'run-notifications', + arguments: { + count: 3, + interval: 10 + } + } + }, + CallToolResultSchema, + { + resumptionToken: lastEventId, + onresumptiontoken: onLastEventIdUpdate + } + ); + + // Wait for some notifications to arrive (not all) - shorter wait time + await new Promise(resolve => setTimeout(resolve, 20)); + + // Verify we received some notifications and lastEventId was updated + expect(notifications.length).toBeGreaterThan(0); + expect(notifications.length).toBeLessThan(4); + expect(onLastEventIdUpdate).toHaveBeenCalled(); + expect(lastEventId).toBeDefined(); + + // Disconnect first client without waiting for completion + // When we close the connection, it will cause a ConnectionClosed error for + // any in-progress requests, which is expected behavior + await transport1.close(); + // Save the promise so we can catch it after closing + const catchPromise = toolPromise.catch(err => { + // This error is expected - the connection was intentionally closed + if (err?.code !== -32000) { + // ConnectionClosed error code + console.error('Unexpected error type during transport close:', err); + } + }); + + // Add a short delay to ensure clean disconnect before reconnecting + await new Promise(resolve => setTimeout(resolve, 10)); + + // Wait for the rejection to be handled + await catchPromise; + + // Create second client with same client ID + const client2 = new Client({ + title: clientTitle, + name: 'test-client', + version: '1.0.0' + }); + + // Set up notification handler for second client + client2.setNotificationHandler(LoggingMessageNotificationSchema, notification => { + if (notification.method === 'notifications/message') { + notifications.push(notification.params); + } + }); + + // Connect second client with same session ID + const transport2 = new StreamableHTTPClientTransport(baseUrl, { + sessionId + }); + await client2.connect(transport2); + + // Resume the notification stream using lastEventId + // This is the key part - we're resuming the same long-running tool using lastEventId + await client2.request( + { + method: 'tools/call', + params: { + name: 'run-notifications', + arguments: { + count: 1, + interval: 5 + } + } + }, + CallToolResultSchema, + { + resumptionToken: lastEventId, // Pass the lastEventId from the previous session + onresumptiontoken: onLastEventIdUpdate + } + ); + + // Verify we eventually received at leaset a few motifications + expect(notifications.length).toBeGreaterThan(1); + + // Clean up + await transport2.close(); + }); +}); diff --git a/src/server/auth/handlers/authorize.ts b/src/server/auth/handlers/authorize.ts index ef15770b9..dcb6c03ec 100644 --- a/src/server/auth/handlers/authorize.ts +++ b/src/server/auth/handlers/authorize.ts @@ -1,5 +1,5 @@ import { RequestHandler } from 'express'; -import { z } from 'zod'; +import * as z from 'zod/v4'; import express from 'express'; import { OAuthServerProvider } from '../provider.js'; import { rateLimit, Options as RateLimitOptions } from 'express-rate-limit'; diff --git a/src/server/auth/handlers/token.ts b/src/server/auth/handlers/token.ts index c387ff7bf..75a20329d 100644 --- a/src/server/auth/handlers/token.ts +++ b/src/server/auth/handlers/token.ts @@ -1,4 +1,4 @@ -import { z } from 'zod'; +import * as z from 'zod/v4'; import express, { RequestHandler } from 'express'; import { OAuthServerProvider } from '../provider.js'; import cors from 'cors'; diff --git a/src/server/auth/middleware/clientAuth.ts b/src/server/auth/middleware/clientAuth.ts index 9969b8724..52611a660 100644 --- a/src/server/auth/middleware/clientAuth.ts +++ b/src/server/auth/middleware/clientAuth.ts @@ -1,4 +1,4 @@ -import { z } from 'zod'; +import * as z from 'zod/v4'; import { RequestHandler } from 'express'; import { OAuthRegisteredClientsStore } from '../clients.js'; import { OAuthClientInformationFull } from '../../../shared/auth.js'; diff --git a/src/server/completable.test.ts b/src/server/completable.test.ts index b5effc272..fa836fec5 100644 --- a/src/server/completable.test.ts +++ b/src/server/completable.test.ts @@ -1,5 +1,5 @@ -import { z } from 'zod'; -import { completable } from './completable.js'; +import * as z from 'zod/v4'; +import { completable, getCompleter } from './completable.js'; describe('completable', () => { it('preserves types and values of underlying schema', () => { @@ -14,27 +14,35 @@ describe('completable', () => { const completions = ['foo', 'bar', 'baz']; const schema = completable(z.string(), () => completions); - expect(await schema._def.complete('')).toEqual(completions); + const completer = getCompleter(schema); + expect(completer).toBeDefined(); + expect(await completer!('')).toEqual(completions); }); it('allows async completion functions', async () => { const completions = ['foo', 'bar', 'baz']; const schema = completable(z.string(), async () => completions); - expect(await schema._def.complete('')).toEqual(completions); + const completer = getCompleter(schema); + expect(completer).toBeDefined(); + expect(await completer!('')).toEqual(completions); }); it('passes current value to completion function', async () => { const schema = completable(z.string(), value => [value + '!']); - expect(await schema._def.complete('test')).toEqual(['test!']); + const completer = getCompleter(schema); + expect(completer).toBeDefined(); + expect(await completer!('test')).toEqual(['test!']); }); it('works with number schemas', async () => { const schema = completable(z.number(), () => [1, 2, 3]); expect(schema.parse(1)).toBe(1); - expect(await schema._def.complete(0)).toEqual([1, 2, 3]); + const completer = getCompleter(schema); + expect(completer).toBeDefined(); + expect(await completer!(0)).toEqual([1, 2, 3]); }); it('preserves schema description', () => { diff --git a/src/server/completable.ts b/src/server/completable.ts index 67d91c383..be067ac55 100644 --- a/src/server/completable.ts +++ b/src/server/completable.ts @@ -1,79 +1,67 @@ -import { ZodTypeAny, ZodTypeDef, ZodType, ParseInput, ParseReturnType, RawCreateParams, ZodErrorMap, ProcessedCreateParams } from 'zod'; +import { AnySchema, SchemaInput } from './zod-compat.js'; -export enum McpZodTypeKind { - Completable = 'McpCompletable' -} +export const COMPLETABLE_SYMBOL: unique symbol = Symbol.for('mcp.completable'); -export type CompleteCallback = ( - value: T['_input'], +export type CompleteCallback = ( + value: SchemaInput, context?: { arguments?: Record; } -) => T['_input'][] | Promise; +) => SchemaInput[] | Promise[]>; -export interface CompletableDef extends ZodTypeDef { - type: T; +export type CompletableMeta = { complete: CompleteCallback; - typeName: McpZodTypeKind.Completable; -} +}; -export class Completable extends ZodType, T['_input']> { - _parse(input: ParseInput): ParseReturnType { - const { ctx } = this._processInputParams(input); - const data = ctx.data; - return this._def.type._parse({ - data, - path: ctx.path, - parent: ctx - }); - } +export type CompletableSchema = T & { + [COMPLETABLE_SYMBOL]: CompletableMeta; +}; - unwrap() { - return this._def.type; - } +/** + * Wraps a Zod type to provide autocompletion capabilities. Useful for, e.g., prompt arguments in MCP. + * Works with both Zod v3 and v4 schemas. + */ +export function completable(schema: T, complete: CompleteCallback): CompletableSchema { + Object.defineProperty(schema as object, COMPLETABLE_SYMBOL, { + value: { complete } as CompletableMeta, + enumerable: false, + writable: false, + configurable: false + }); + return schema as CompletableSchema; +} + +/** + * Checks if a schema is completable (has completion metadata). + */ +export function isCompletable(schema: unknown): schema is CompletableSchema { + return !!schema && typeof schema === 'object' && COMPLETABLE_SYMBOL in (schema as object); +} - static create = ( - type: T, - params: RawCreateParams & { - complete: CompleteCallback; - } - ): Completable => { - return new Completable({ - type, - typeName: McpZodTypeKind.Completable, - complete: params.complete, - ...processCreateParams(params) - }); - }; +/** + * Gets the completer callback from a completable schema, if it exists. + */ +export function getCompleter(schema: T): CompleteCallback | undefined { + const meta = (schema as unknown as { [COMPLETABLE_SYMBOL]?: CompletableMeta })[COMPLETABLE_SYMBOL]; + return meta?.complete as CompleteCallback | undefined; } /** - * Wraps a Zod type to provide autocompletion capabilities. Useful for, e.g., prompt arguments in MCP. + * Unwraps a completable schema to get the underlying schema. + * For backward compatibility with code that called `.unwrap()`. */ -export function completable(schema: T, complete: CompleteCallback): Completable { - return Completable.create(schema, { ...schema._def, complete }); +export function unwrapCompletable(schema: CompletableSchema): T { + return schema; } -// Not sure why this isn't exported from Zod: -// https://github.com/colinhacks/zod/blob/f7ad26147ba291cb3fb257545972a8e00e767470/src/types.ts#L130 -function processCreateParams(params: RawCreateParams): ProcessedCreateParams { - if (!params) return {}; - const { errorMap, invalid_type_error, required_error, description } = params; - if (errorMap && (invalid_type_error || required_error)) { - throw new Error(`Can't use "invalid_type_error" or "required_error" in conjunction with custom error map.`); - } - if (errorMap) return { errorMap: errorMap, description }; - const customMap: ZodErrorMap = (iss, ctx) => { - const { message } = params; +// Legacy exports for backward compatibility +// These types are deprecated but kept for existing code +export enum McpZodTypeKind { + Completable = 'McpCompletable' +} - if (iss.code === 'invalid_enum_value') { - return { message: message ?? ctx.defaultError }; - } - if (typeof ctx.data === 'undefined') { - return { message: message ?? required_error ?? ctx.defaultError }; - } - if (iss.code !== 'invalid_type') return { message: ctx.defaultError }; - return { message: message ?? invalid_type_error ?? ctx.defaultError }; - }; - return { errorMap: customMap, description }; +export interface CompletableDef { + type: T; + complete: CompleteCallback; + typeName: McpZodTypeKind.Completable; } diff --git a/src/server/index.test.ts b/src/server/index.test.ts index 16a1d94bd..36665095e 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -1,5 +1,5 @@ /* eslint-disable @typescript-eslint/no-unused-vars */ -import { z } from 'zod'; +import * as z from 'zod/v4'; import { Client } from '../client/index.js'; import { InMemoryTransport } from '../inMemory.js'; import type { Transport } from '../shared/transport.js'; diff --git a/src/server/index.ts b/src/server/index.ts index 60751cc38..8de1a3cc4 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -354,7 +354,7 @@ export class Server< params: LegacyElicitRequestFormParams | ElicitRequestFormParams | ElicitRequestURLParams, options?: RequestOptions ): Promise { - const mode = 'mode' in params ? params.mode : 'form'; + const mode = ('mode' in params ? params.mode : 'form') as 'form' | 'url'; switch (mode) { case 'url': { @@ -370,7 +370,9 @@ export class Server< throw new Error('Client does not support form elicitation.'); } const formParams: ElicitRequestFormParams = - 'mode' in params ? (params as ElicitRequestFormParams) : { ...(params as LegacyElicitRequestFormParams), mode: 'form' }; + 'mode' in params + ? (params as ElicitRequestFormParams) + : ({ ...(params as LegacyElicitRequestFormParams), mode: 'form' } as ElicitRequestFormParams); const result = await this.request({ method: 'elicitation/create', params: formParams }, ElicitResultSchema, options); diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts index a6310173f..23798c138 100644 --- a/src/server/mcp.test.ts +++ b/src/server/mcp.test.ts @@ -1,4 +1,4 @@ -import { z } from 'zod'; +import * as z from 'zod/v4'; import { Client } from '../client/index.js'; import { InMemoryTransport } from '../inMemory.js'; import { getDisplayName } from '../shared/metadataUtils.js'; diff --git a/src/server/mcp.ts b/src/server/mcp.ts index 3348d57e1..b9b6d5596 100644 --- a/src/server/mcp.ts +++ b/src/server/mcp.ts @@ -1,6 +1,20 @@ import { Server, ServerOptions } from './index.js'; -import { zodToJsonSchema } from 'zod-to-json-schema'; -import { z, ZodRawShape, ZodObject, ZodString, ZodTypeAny, ZodType, ZodTypeDef, ZodOptional } from 'zod'; +import { + AnySchema, + AnyObjectSchema, + ZodRawShapeCompat, + SchemaOutput, + ShapeOutput, + normalizeObjectSchema, + safeParseAsync, + getObjectShape, + objectFromShape, + getParseErrorMessage, + getSchemaDescription, + isSchemaOptional, + getLiteralValue +} from './zod-compat.js'; +import { toJsonSchemaCompat } from './zod-json-schema-compat.js'; import { Implementation, Tool, @@ -36,7 +50,7 @@ import { assertCompleteRequestPrompt, assertCompleteRequestResourceTemplate } from '../types.js'; -import { Completable, CompletableDef } from './completable.js'; +import { isCompletable, getCompleter } from './completable.js'; import { UriTemplate, Variables } from '../shared/uriTemplate.js'; import { RequestHandlerExtra } from '../shared/protocol.js'; import { Transport } from '../shared/transport.js'; @@ -87,8 +101,8 @@ export class McpServer { return; } - this.server.assertCanSetRequestHandler(ListToolsRequestSchema.shape.method.value); - this.server.assertCanSetRequestHandler(CallToolRequestSchema.shape.method.value); + this.server.assertCanSetRequestHandler(getMethodValue(ListToolsRequestSchema)); + this.server.assertCanSetRequestHandler(getMethodValue(CallToolRequestSchema)); this.server.registerCapabilities({ tools: { @@ -106,21 +120,27 @@ export class McpServer { name, title: tool.title, description: tool.description, - inputSchema: tool.inputSchema - ? (zodToJsonSchema(tool.inputSchema, { - strictUnions: true, - pipeStrategy: 'input' - }) as Tool['inputSchema']) - : EMPTY_OBJECT_JSON_SCHEMA, + inputSchema: (() => { + const obj = normalizeObjectSchema(tool.inputSchema); + return obj + ? (toJsonSchemaCompat(obj, { + strictUnions: true, + pipeStrategy: 'input' + }) as Tool['inputSchema']) + : EMPTY_OBJECT_JSON_SCHEMA; + })(), annotations: tool.annotations, _meta: tool._meta }; if (tool.outputSchema) { - toolDefinition.outputSchema = zodToJsonSchema(tool.outputSchema, { - strictUnions: true, - pipeStrategy: 'output' - }) as Tool['outputSchema']; + const obj = normalizeObjectSchema(tool.outputSchema); + if (obj) { + toolDefinition.outputSchema = toJsonSchemaCompat(obj, { + strictUnions: true, + pipeStrategy: 'output' + }) as Tool['outputSchema']; + } } return toolDefinition; @@ -143,12 +163,16 @@ export class McpServer { } if (tool.inputSchema) { - const cb = tool.callback as ToolCallback; - const parseResult = await tool.inputSchema.safeParseAsync(request.params.arguments); + const cb = tool.callback as ToolCallback; + // Try to normalize to object schema first (for raw shapes and object schemas) + // If that fails, use the schema directly (for union/intersection/etc) + const inputObj = normalizeObjectSchema(tool.inputSchema); + const schemaToParse = inputObj ?? (tool.inputSchema as AnySchema); + const parseResult = await safeParseAsync(schemaToParse, request.params.arguments); if (!parseResult.success) { throw new McpError( ErrorCode.InvalidParams, - `Input validation error: Invalid arguments for tool ${request.params.name}: ${parseResult.error.message}` + `Input validation error: Invalid arguments for tool ${request.params.name}: ${getParseErrorMessage(parseResult.error)}` ); } @@ -169,11 +193,12 @@ export class McpServer { } // if the tool has an output schema, validate structured content - const parseResult = await tool.outputSchema.safeParseAsync(result.structuredContent); + const outputObj = normalizeObjectSchema(tool.outputSchema) as AnyObjectSchema; + const parseResult = await safeParseAsync(outputObj, result.structuredContent); if (!parseResult.success) { throw new McpError( ErrorCode.InvalidParams, - `Output validation error: Invalid structured content for tool ${request.params.name}: ${parseResult.error.message}` + `Output validation error: Invalid structured content for tool ${request.params.name}: ${getParseErrorMessage(parseResult.error)}` ); } } @@ -217,7 +242,7 @@ export class McpServer { return; } - this.server.assertCanSetRequestHandler(CompleteRequestSchema.shape.method.value); + this.server.assertCanSetRequestHandler(getMethodValue(CompleteRequestSchema)); this.server.registerCapabilities({ completions: {} @@ -255,13 +280,17 @@ export class McpServer { return EMPTY_COMPLETION_RESULT; } - const field = prompt.argsSchema.shape[request.params.argument.name]; - if (!(field instanceof Completable)) { + const promptShape = getObjectShape(prompt.argsSchema); + const field = promptShape?.[request.params.argument.name]; + if (!isCompletable(field)) { return EMPTY_COMPLETION_RESULT; } - const def: CompletableDef = field._def; - const suggestions = await def.complete(request.params.argument.value, request.params.context); + const completer = getCompleter(field); + if (!completer) { + return EMPTY_COMPLETION_RESULT; + } + const suggestions = await completer(request.params.argument.value, request.params.context); return createCompletionResult(suggestions); } @@ -296,9 +325,9 @@ export class McpServer { return; } - this.server.assertCanSetRequestHandler(ListResourcesRequestSchema.shape.method.value); - this.server.assertCanSetRequestHandler(ListResourceTemplatesRequestSchema.shape.method.value); - this.server.assertCanSetRequestHandler(ReadResourceRequestSchema.shape.method.value); + this.server.assertCanSetRequestHandler(getMethodValue(ListResourcesRequestSchema)); + this.server.assertCanSetRequestHandler(getMethodValue(ListResourceTemplatesRequestSchema)); + this.server.assertCanSetRequestHandler(getMethodValue(ReadResourceRequestSchema)); this.server.registerCapabilities({ resources: { @@ -379,8 +408,8 @@ export class McpServer { return; } - this.server.assertCanSetRequestHandler(ListPromptsRequestSchema.shape.method.value); - this.server.assertCanSetRequestHandler(GetPromptRequestSchema.shape.method.value); + this.server.assertCanSetRequestHandler(getMethodValue(ListPromptsRequestSchema)); + this.server.assertCanSetRequestHandler(getMethodValue(GetPromptRequestSchema)); this.server.registerCapabilities({ prompts: { @@ -415,11 +444,12 @@ export class McpServer { } if (prompt.argsSchema) { - const parseResult = await prompt.argsSchema.safeParseAsync(request.params.arguments); + const argsObj = normalizeObjectSchema(prompt.argsSchema) as AnyObjectSchema; + const parseResult = await safeParseAsync(argsObj, request.params.arguments); if (!parseResult.success) { throw new McpError( ErrorCode.InvalidParams, - `Invalid arguments for prompt ${request.params.name}: ${parseResult.error.message}` + `Invalid arguments for prompt ${request.params.name}: ${getParseErrorMessage(parseResult.error)}` ); } @@ -637,7 +667,7 @@ export class McpServer { const registeredPrompt: RegisteredPrompt = { title, description, - argsSchema: argsSchema === undefined ? undefined : z.object(argsSchema), + argsSchema: argsSchema === undefined ? undefined : objectFromShape(argsSchema), callback, enabled: true, disable: () => registeredPrompt.update({ enabled: false }), @@ -650,7 +680,7 @@ export class McpServer { } if (typeof updates.title !== 'undefined') registeredPrompt.title = updates.title; if (typeof updates.description !== 'undefined') registeredPrompt.description = updates.description; - if (typeof updates.argsSchema !== 'undefined') registeredPrompt.argsSchema = z.object(updates.argsSchema); + if (typeof updates.argsSchema !== 'undefined') registeredPrompt.argsSchema = objectFromShape(updates.argsSchema); if (typeof updates.callback !== 'undefined') registeredPrompt.callback = updates.callback; if (typeof updates.enabled !== 'undefined') registeredPrompt.enabled = updates.enabled; this.sendPromptListChanged(); @@ -664,11 +694,11 @@ export class McpServer { name: string, title: string | undefined, description: string | undefined, - inputSchema: ZodRawShape | ZodType | undefined, - outputSchema: ZodRawShape | ZodType | undefined, + inputSchema: ZodRawShapeCompat | AnySchema | undefined, + outputSchema: ZodRawShapeCompat | AnySchema | undefined, annotations: ToolAnnotations | undefined, _meta: Record | undefined, - callback: ToolCallback + callback: ToolCallback ): RegisteredTool { // Validate tool name according to SEP specification validateAndWarnToolName(name); @@ -695,7 +725,7 @@ export class McpServer { } if (typeof updates.title !== 'undefined') registeredTool.title = updates.title; if (typeof updates.description !== 'undefined') registeredTool.description = updates.description; - if (typeof updates.paramsSchema !== 'undefined') registeredTool.inputSchema = z.object(updates.paramsSchema); + if (typeof updates.paramsSchema !== 'undefined') registeredTool.inputSchema = objectFromShape(updates.paramsSchema); if (typeof updates.callback !== 'undefined') registeredTool.callback = updates.callback; if (typeof updates.annotations !== 'undefined') registeredTool.annotations = updates.annotations; if (typeof updates._meta !== 'undefined') registeredTool._meta = updates._meta; @@ -731,7 +761,11 @@ export class McpServer { * between ToolAnnotations and ZodRawShape during overload resolution, as both are plain object types. * @deprecated Use `registerTool` instead. */ - tool(name: string, paramsSchemaOrAnnotations: Args | ToolAnnotations, cb: ToolCallback): RegisteredTool; + tool( + name: string, + paramsSchemaOrAnnotations: Args | ToolAnnotations, + cb: ToolCallback + ): RegisteredTool; /** * Registers a tool `name` (with a description) taking either parameter schema or annotations. @@ -742,7 +776,7 @@ export class McpServer { * between ToolAnnotations and ZodRawShape during overload resolution, as both are plain object types. * @deprecated Use `registerTool` instead. */ - tool( + tool( name: string, description: string, paramsSchemaOrAnnotations: Args | ToolAnnotations, @@ -753,13 +787,18 @@ export class McpServer { * Registers a tool with both parameter schema and annotations. * @deprecated Use `registerTool` instead. */ - tool(name: string, paramsSchema: Args, annotations: ToolAnnotations, cb: ToolCallback): RegisteredTool; + tool( + name: string, + paramsSchema: Args, + annotations: ToolAnnotations, + cb: ToolCallback + ): RegisteredTool; /** * Registers a tool with description, parameter schema, and annotations. * @deprecated Use `registerTool` instead. */ - tool( + tool( name: string, description: string, paramsSchema: Args, @@ -776,8 +815,8 @@ export class McpServer { } let description: string | undefined; - let inputSchema: ZodRawShape | undefined; - let outputSchema: ZodRawShape | undefined; + let inputSchema: ZodRawShapeCompat | undefined; + let outputSchema: ZodRawShapeCompat | undefined; let annotations: ToolAnnotations | undefined; // Tool properties are passed as separate arguments, with omissions allowed. @@ -795,7 +834,7 @@ export class McpServer { if (isZodRawShape(firstArg)) { // We have a params schema as the first arg - inputSchema = rest.shift() as ZodRawShape; + inputSchema = rest.shift() as ZodRawShapeCompat; // Check if the next arg is potentially annotations if (rest.length > 1 && typeof rest[0] === 'object' && rest[0] !== null && !isZodRawShape(rest[0])) { @@ -810,7 +849,7 @@ export class McpServer { annotations = rest.shift() as ToolAnnotations; } } - const callback = rest[0] as ToolCallback; + const callback = rest[0] as ToolCallback; return this._createRegisteredTool(name, undefined, description, inputSchema, outputSchema, annotations, undefined, callback); } @@ -818,7 +857,7 @@ export class McpServer { /** * Registers a tool with a config object and callback. */ - registerTool, OutputArgs extends ZodRawShape | ZodType>( + registerTool( name: string, config: { title?: string; @@ -844,7 +883,7 @@ export class McpServer { outputSchema, annotations, _meta, - cb as ToolCallback + cb as ToolCallback ); } @@ -1047,27 +1086,27 @@ export class ResourceTemplate { * - `content` if the tool does not have an outputSchema * - Both fields are optional but typically one should be provided */ -export type ToolCallback = undefined> = Args extends ZodRawShape - ? ( - args: z.objectOutputType, - extra: RequestHandlerExtra - ) => CallToolResult | Promise - : Args extends ZodType - ? (args: T, extra: RequestHandlerExtra) => CallToolResult | Promise +export type ToolCallback = Args extends ZodRawShapeCompat + ? (args: ShapeOutput, extra: RequestHandlerExtra) => CallToolResult | Promise + : Args extends AnySchema + ? ( + args: SchemaOutput, + extra: RequestHandlerExtra + ) => CallToolResult | Promise : (extra: RequestHandlerExtra) => CallToolResult | Promise; export type RegisteredTool = { title?: string; description?: string; - inputSchema?: ZodType; - outputSchema?: ZodType; + inputSchema?: AnySchema; + outputSchema?: AnySchema; annotations?: ToolAnnotations; _meta?: Record; - callback: ToolCallback; + callback: ToolCallback; enabled: boolean; enable(): void; disable(): void; - update(updates: { + update(updates: { name?: string | null; title?: string; description?: string; @@ -1086,8 +1125,8 @@ const EMPTY_OBJECT_JSON_SCHEMA = { properties: {} }; -// Helper to check if an object is a Zod schema (ZodRawShape) -function isZodRawShape(obj: unknown): obj is ZodRawShape { +// Helper to check if an object is a Zod schema (ZodRawShapeCompat) +function isZodRawShape(obj: unknown): obj is ZodRawShapeCompat { if (typeof obj !== 'object' || obj === null) return false; const isEmptyObject = Object.keys(obj).length === 0; @@ -1097,7 +1136,7 @@ function isZodRawShape(obj: unknown): obj is ZodRawShape { return isEmptyObject || Object.values(obj as object).some(isZodTypeLike); } -function isZodTypeLike(value: unknown): value is ZodType { +function isZodTypeLike(value: unknown): value is AnySchema { return ( value !== null && typeof value === 'object' && @@ -1112,13 +1151,13 @@ function isZodTypeLike(value: unknown): value is ZodType { * Converts a provided Zod schema to a Zod object if it is a ZodRawShape, * otherwise returns the schema as is. */ -function getZodSchemaObject(schema: ZodRawShape | ZodType | undefined): ZodType | undefined { +function getZodSchemaObject(schema: ZodRawShapeCompat | AnySchema | undefined): AnySchema | undefined { if (!schema) { return undefined; } if (isZodRawShape(schema)) { - return z.object(schema); + return objectFromShape(schema); } return schema; @@ -1191,21 +1230,16 @@ export type RegisteredResourceTemplate = { remove(): void; }; -type PromptArgsRawShape = { - [k: string]: ZodType | ZodOptional>; -}; +type PromptArgsRawShape = ZodRawShapeCompat; export type PromptCallback = Args extends PromptArgsRawShape - ? ( - args: z.objectOutputType, - extra: RequestHandlerExtra - ) => GetPromptResult | Promise + ? (args: ShapeOutput, extra: RequestHandlerExtra) => GetPromptResult | Promise : (extra: RequestHandlerExtra) => GetPromptResult | Promise; export type RegisteredPrompt = { title?: string; description?: string; - argsSchema?: ZodObject; + argsSchema?: AnyObjectSchema; callback: PromptCallback; enabled: boolean; enable(): void; @@ -1221,14 +1255,36 @@ export type RegisteredPrompt = { remove(): void; }; -function promptArgumentsFromSchema(schema: ZodObject): PromptArgument[] { - return Object.entries(schema.shape).map( - ([name, field]): PromptArgument => ({ +function promptArgumentsFromSchema(schema: AnyObjectSchema): PromptArgument[] { + const shape = getObjectShape(schema); + if (!shape) return []; + return Object.entries(shape).map(([name, field]): PromptArgument => { + // Get description - works for both v3 and v4 + const description = getSchemaDescription(field); + // Check if optional - works for both v3 and v4 + const isOptional = isSchemaOptional(field); + return { name, - description: field.description, - required: !field.isOptional() - }) - ); + description, + required: !isOptional + }; + }); +} + +function getMethodValue(schema: AnyObjectSchema): string { + const shape = getObjectShape(schema); + const methodSchema = shape?.method as AnySchema | undefined; + if (!methodSchema) { + throw new Error('Schema is missing a method literal'); + } + + // Extract literal value - works for both v3 and v4 + const value = getLiteralValue(methodSchema); + if (typeof value === 'string') { + return value; + } + + throw new Error('Schema method literal must be a string'); } function createCompletionResult(suggestions: string[]): CompleteResult { diff --git a/src/server/sse.test.ts b/src/server/sse.test.ts index 7dae26083..34ac071fe 100644 --- a/src/server/sse.test.ts +++ b/src/server/sse.test.ts @@ -5,7 +5,7 @@ import { SSEServerTransport } from './sse.js'; import { McpServer } from './mcp.js'; import { createServer, type Server } from 'node:http'; import { AddressInfo } from 'node:net'; -import { z } from 'zod'; +import * as z from 'zod/v4'; import { CallToolResult, JSONRPCMessage } from '../types.js'; const createMockResponse = () => { diff --git a/src/server/streamableHttp.test.ts b/src/server/streamableHttp.test.ts index 8d78aad67..b5b169951 100644 --- a/src/server/streamableHttp.test.ts +++ b/src/server/streamableHttp.test.ts @@ -4,7 +4,7 @@ import { randomUUID } from 'node:crypto'; import { EventStore, StreamableHTTPServerTransport, EventId, StreamId } from './streamableHttp.js'; import { McpServer } from './mcp.js'; import { CallToolResult, JSONRPCMessage } from '../types.js'; -import { z } from 'zod'; +import * as z from 'zod/v4'; import { AuthInfo } from './auth/types.js'; async function getFreePort() { diff --git a/src/server/title.test.ts b/src/server/title.test.ts index 7f0feedc8..0f588514d 100644 --- a/src/server/title.test.ts +++ b/src/server/title.test.ts @@ -1,7 +1,7 @@ import { Server } from './index.js'; import { Client } from '../client/index.js'; import { InMemoryTransport } from '../inMemory.js'; -import { z } from 'zod'; +import * as z from 'zod/v4'; import { McpServer, ResourceTemplate } from './mcp.js'; describe('Title field backwards compatibility', () => { diff --git a/src/server/v3/completable.v3.test.ts b/src/server/v3/completable.v3.test.ts new file mode 100644 index 000000000..111874e1e --- /dev/null +++ b/src/server/v3/completable.v3.test.ts @@ -0,0 +1,54 @@ +import * as z from 'zod/v3'; +import { completable, getCompleter } from '../completable.js'; + +describe('completable', () => { + it('preserves types and values of underlying schema', () => { + const baseSchema = z.string(); + const schema = completable(baseSchema, () => []); + + expect(schema.parse('test')).toBe('test'); + expect(() => schema.parse(123)).toThrow(); + }); + + it('provides access to completion function', async () => { + const completions = ['foo', 'bar', 'baz']; + const schema = completable(z.string(), () => completions); + + const completer = getCompleter(schema); + expect(completer).toBeDefined(); + expect(await completer!('')).toEqual(completions); + }); + + it('allows async completion functions', async () => { + const completions = ['foo', 'bar', 'baz']; + const schema = completable(z.string(), async () => completions); + + const completer = getCompleter(schema); + expect(completer).toBeDefined(); + expect(await completer!('')).toEqual(completions); + }); + + it('passes current value to completion function', async () => { + const schema = completable(z.string(), value => [value + '!']); + + const completer = getCompleter(schema); + expect(completer).toBeDefined(); + expect(await completer!('test')).toEqual(['test!']); + }); + + it('works with number schemas', async () => { + const schema = completable(z.number(), () => [1, 2, 3]); + + expect(schema.parse(1)).toBe(1); + const completer = getCompleter(schema); + expect(completer).toBeDefined(); + expect(await completer!(0)).toEqual([1, 2, 3]); + }); + + it('preserves schema description', () => { + const desc = 'test description'; + const schema = completable(z.string().describe(desc), () => []); + + expect(schema.description).toBe(desc); + }); +}); diff --git a/src/server/v3/index.v3.test.ts b/src/server/v3/index.v3.test.ts new file mode 100644 index 000000000..bcd05b588 --- /dev/null +++ b/src/server/v3/index.v3.test.ts @@ -0,0 +1,964 @@ +/* eslint-disable @typescript-eslint/no-unused-vars */ +import * as z from 'zod/v3'; +import { Client } from '../../client/index.js'; +import { InMemoryTransport } from '../../inMemory.js'; +import type { Transport } from '../../shared/transport.js'; +import { + CreateMessageRequestSchema, + ElicitRequestSchema, + ErrorCode, + LATEST_PROTOCOL_VERSION, + ListPromptsRequestSchema, + ListResourcesRequestSchema, + ListToolsRequestSchema, + type LoggingMessageNotification, + NotificationSchema, + RequestSchema, + ResultSchema, + SetLevelRequestSchema, + SUPPORTED_PROTOCOL_VERSIONS +} from '../../types.js'; +import { Server } from '../index.js'; +import { AnyObjectSchema } from '../zod-compat.js'; + +test('should accept latest protocol version', async () => { + let sendPromiseResolve: (value: unknown) => void; + const sendPromise = new Promise(resolve => { + sendPromiseResolve = resolve; + }); + + const serverTransport: Transport = { + start: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + send: vi.fn().mockImplementation(message => { + if (message.id === 1 && message.result) { + expect(message.result).toEqual({ + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: expect.any(Object), + serverInfo: { + name: 'test server', + version: '1.0' + }, + instructions: 'Test instructions' + }); + sendPromiseResolve(undefined); + } + return Promise.resolve(); + }) + }; + + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {} + }, + instructions: 'Test instructions' + } + ); + + await server.connect(serverTransport); + + // Simulate initialize request with latest version + serverTransport.onmessage?.({ + jsonrpc: '2.0', + id: 1, + method: 'initialize', + params: { + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: {}, + clientInfo: { + name: 'test client', + version: '1.0' + } + } + }); + + await expect(sendPromise).resolves.toBeUndefined(); +}); + +test('should accept supported older protocol version', async () => { + const OLD_VERSION = SUPPORTED_PROTOCOL_VERSIONS[1]; + let sendPromiseResolve: (value: unknown) => void; + const sendPromise = new Promise(resolve => { + sendPromiseResolve = resolve; + }); + + const serverTransport: Transport = { + start: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + send: vi.fn().mockImplementation(message => { + if (message.id === 1 && message.result) { + expect(message.result).toEqual({ + protocolVersion: OLD_VERSION, + capabilities: expect.any(Object), + serverInfo: { + name: 'test server', + version: '1.0' + } + }); + sendPromiseResolve(undefined); + } + return Promise.resolve(); + }) + }; + + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {} + } + } + ); + + await server.connect(serverTransport); + + // Simulate initialize request with older version + serverTransport.onmessage?.({ + jsonrpc: '2.0', + id: 1, + method: 'initialize', + params: { + protocolVersion: OLD_VERSION, + capabilities: {}, + clientInfo: { + name: 'test client', + version: '1.0' + } + } + }); + + await expect(sendPromise).resolves.toBeUndefined(); +}); + +test('should handle unsupported protocol version', async () => { + let sendPromiseResolve: (value: unknown) => void; + const sendPromise = new Promise(resolve => { + sendPromiseResolve = resolve; + }); + + const serverTransport: Transport = { + start: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + send: vi.fn().mockImplementation(message => { + if (message.id === 1 && message.result) { + expect(message.result).toEqual({ + protocolVersion: LATEST_PROTOCOL_VERSION, + capabilities: expect.any(Object), + serverInfo: { + name: 'test server', + version: '1.0' + } + }); + sendPromiseResolve(undefined); + } + return Promise.resolve(); + }) + }; + + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {} + } + } + ); + + await server.connect(serverTransport); + + // Simulate initialize request with unsupported version + serverTransport.onmessage?.({ + jsonrpc: '2.0', + id: 1, + method: 'initialize', + params: { + protocolVersion: 'invalid-version', + capabilities: {}, + clientInfo: { + name: 'test client', + version: '1.0' + } + } + }); + + await expect(sendPromise).resolves.toBeUndefined(); +}); + +test('should respect client capabilities', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {} + }, + enforceStrictCapabilities: true + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + sampling: {} + } + } + ); + + // Implement request handler for sampling/createMessage + client.setRequestHandler(CreateMessageRequestSchema, async _request => { + // Mock implementation of createMessage + return { + model: 'test-model', + role: 'assistant', + content: { + type: 'text', + text: 'This is a test response' + } + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + expect(server.getClientCapabilities()).toEqual({ sampling: {} }); + + // This should work because sampling is supported by the client + await expect( + server.createMessage({ + messages: [], + maxTokens: 10 + }) + ).resolves.not.toThrow(); + + // This should still throw because roots are not supported by the client + await expect(server.listRoots()).rejects.toThrow(/^Client does not support/); +}); + +test('should respect client elicitation capabilities', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {} + }, + enforceStrictCapabilities: true + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + elicitation: {} + } + } + ); + + client.setRequestHandler(ElicitRequestSchema, params => ({ + action: 'accept', + content: { + username: params.params.message.includes('username') ? 'test-user' : undefined, + confirmed: true + } + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + expect(server.getClientCapabilities()).toEqual({ elicitation: { form: {} } }); + + // This should work because elicitation is supported by the client + await expect( + server.elicitInput({ + message: 'Please provide your username', + requestedSchema: { + type: 'object', + properties: { + username: { + type: 'string', + title: 'Username', + description: 'Your username' + }, + confirmed: { + type: 'boolean', + title: 'Confirm', + description: 'Please confirm', + default: false + } + }, + required: ['username'] + } + }) + ).resolves.toEqual({ + action: 'accept', + content: { + username: 'test-user', + confirmed: true + } + }); + + // This should still throw because sampling is not supported by the client + await expect( + server.createMessage({ + messages: [], + maxTokens: 10 + }) + ).rejects.toThrow(/^Client does not support/); +}); + +test('should validate elicitation response against requested schema', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {} + }, + enforceStrictCapabilities: true + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + elicitation: {} + } + } + ); + + // Set up client to return valid response + client.setRequestHandler(ElicitRequestSchema, _request => ({ + action: 'accept', + content: { + name: 'John Doe', + email: 'john@example.com', + age: 30 + } + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Test with valid response + await expect( + server.elicitInput({ + message: 'Please provide your information', + requestedSchema: { + type: 'object', + properties: { + name: { + type: 'string', + minLength: 1 + }, + email: { + type: 'string', + minLength: 1 + }, + age: { + type: 'integer', + minimum: 0, + maximum: 150 + } + }, + required: ['name', 'email'] + } + }) + ).resolves.toEqual({ + action: 'accept', + content: { + name: 'John Doe', + email: 'john@example.com', + age: 30 + } + }); +}); + +test('should reject elicitation response with invalid data', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {} + }, + enforceStrictCapabilities: true + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + elicitation: {} + } + } + ); + + // Set up client to return invalid response (missing required field, invalid age) + client.setRequestHandler(ElicitRequestSchema, _request => ({ + action: 'accept', + content: { + email: '', // Invalid - too short + age: -5 // Invalid age + } + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Test with invalid response + await expect( + server.elicitInput({ + message: 'Please provide your information', + requestedSchema: { + type: 'object', + properties: { + name: { + type: 'string', + minLength: 1 + }, + email: { + type: 'string', + minLength: 1 + }, + age: { + type: 'integer', + minimum: 0, + maximum: 150 + } + }, + required: ['name', 'email'] + } + }) + ).rejects.toThrow(/does not match requested schema/); +}); + +test('should allow elicitation reject and cancel without validation', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {} + }, + enforceStrictCapabilities: true + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + elicitation: {} + } + } + ); + + let requestCount = 0; + client.setRequestHandler(ElicitRequestSchema, _request => { + requestCount++; + if (requestCount === 1) { + return { action: 'decline' }; + } else { + return { action: 'cancel' }; + } + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + const schema = { + type: 'object' as const, + properties: { + name: { type: 'string' as const } + }, + required: ['name'] + }; + + // Test reject - should not validate + await expect( + server.elicitInput({ + message: 'Please provide your name', + requestedSchema: schema + }) + ).resolves.toEqual({ + action: 'decline' + }); + + // Test cancel - should not validate + await expect( + server.elicitInput({ + message: 'Please provide your name', + requestedSchema: schema + }) + ).resolves.toEqual({ + action: 'cancel' + }); +}); + +test('should respect server notification capabilities', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + logging: {} + }, + enforceStrictCapabilities: true + } + ); + + const [_clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await server.connect(serverTransport); + + // This should work because logging is supported by the server + await expect( + server.sendLoggingMessage({ + level: 'info', + data: 'Test log message' + }) + ).resolves.not.toThrow(); + + // This should throw because resource notificaitons are not supported by the server + await expect(server.sendResourceUpdated({ uri: 'test://resource' })).rejects.toThrow(/^Server does not support/); +}); + +test('should only allow setRequestHandler for declared capabilities', () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + prompts: {}, + resources: {} + } + } + ); + + // These should work because the capabilities are declared + expect(() => { + server.setRequestHandler(ListPromptsRequestSchema, () => ({ prompts: [] })); + }).not.toThrow(); + + expect(() => { + server.setRequestHandler(ListResourcesRequestSchema, () => ({ + resources: [] + })); + }).not.toThrow(); + + // These should throw because the capabilities are not declared + expect(() => { + server.setRequestHandler(ListToolsRequestSchema, () => ({ tools: [] })); + }).toThrow(/^Server does not support tools/); + + expect(() => { + server.setRequestHandler(SetLevelRequestSchema, () => ({})); + }).toThrow(/^Server does not support logging/); +}); + +/* + Test that custom request/notification/result schemas can be used with the Server class. + */ +test('should typecheck', () => { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const GetWeatherRequestSchema = (RequestSchema as unknown as z.ZodObject).extend({ + method: z.literal('weather/get'), + params: z.object({ + city: z.string() + }) + }) as AnyObjectSchema; + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const GetForecastRequestSchema = (RequestSchema as unknown as z.ZodObject).extend({ + method: z.literal('weather/forecast'), + params: z.object({ + city: z.string(), + days: z.number() + }) + }) as AnyObjectSchema; + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const WeatherForecastNotificationSchema = (NotificationSchema as unknown as z.ZodObject).extend({ + method: z.literal('weather/alert'), + params: z.object({ + severity: z.enum(['warning', 'watch']), + message: z.string() + }) + }) as AnyObjectSchema; + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const WeatherRequestSchema = (GetWeatherRequestSchema as unknown as z.ZodObject).or( + // eslint-disable-next-line @typescript-eslint/no-explicit-any + GetForecastRequestSchema as unknown as z.ZodObject + ) as AnyObjectSchema; + const WeatherNotificationSchema = WeatherForecastNotificationSchema as AnyObjectSchema; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const WeatherResultSchema = (ResultSchema as unknown as z.ZodObject).extend({ + temperature: z.number(), + conditions: z.string() + }) as AnyObjectSchema; + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + type InferSchema = T extends z.ZodType ? Output : never; + type WeatherRequest = InferSchema; + type WeatherNotification = InferSchema; + type WeatherResult = InferSchema; + + // Create a typed Server for weather data + const weatherServer = new Server( + { + name: 'WeatherServer', + version: '1.0.0' + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {} + } + } + ); + + // Typecheck that only valid weather requests/notifications/results are allowed + weatherServer.setRequestHandler(GetWeatherRequestSchema, _request => { + return { + temperature: 72, + conditions: 'sunny' + }; + }); + + weatherServer.setNotificationHandler(WeatherForecastNotificationSchema, notification => { + // Type assertion needed for v3/v4 schema mixing + const params = notification.params as { message: string; severity: 'warning' | 'watch' }; + console.log(`Weather alert: ${params.message}`); + }); +}); + +test('should handle server cancelling a request', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: {} + } + ); + + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + sampling: {} + } + } + ); + + // Set up client to delay responding to createMessage + client.setRequestHandler(CreateMessageRequestSchema, async (_request, _extra) => { + await new Promise(resolve => setTimeout(resolve, 1000)); + return { + model: 'test', + role: 'assistant', + content: { + type: 'text', + text: 'Test response' + } + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Set up abort controller + const controller = new AbortController(); + + // Issue request but cancel it immediately + const createMessagePromise = server.createMessage( + { + messages: [], + maxTokens: 10 + }, + { + signal: controller.signal + } + ); + controller.abort('Cancelled by test'); + + // Request should be rejected + await expect(createMessagePromise).rejects.toBe('Cancelled by test'); +}); + +test('should handle request timeout', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: {} + } + ); + + // Set up client that delays responses + const client = new Client( + { + name: 'test client', + version: '1.0' + }, + { + capabilities: { + sampling: {} + } + } + ); + + client.setRequestHandler(CreateMessageRequestSchema, async (_request, extra) => { + await new Promise((resolve, reject) => { + const timeout = setTimeout(resolve, 100); + extra.signal.addEventListener('abort', () => { + clearTimeout(timeout); + reject(extra.signal.reason); + }); + }); + + return { + model: 'test', + role: 'assistant', + content: { + type: 'text', + text: 'Test response' + } + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Request with 0 msec timeout should fail immediately + await expect( + server.createMessage( + { + messages: [], + maxTokens: 10 + }, + { timeout: 0 } + ) + ).rejects.toMatchObject({ + code: ErrorCode.RequestTimeout + }); +}); + +/* + Test automatic log level handling for transports with and without sessionId + */ +test('should respect log level for transport without sessionId', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {} + }, + enforceStrictCapabilities: true + } + ); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + expect(clientTransport.sessionId).toEqual(undefined); + + // Client sets logging level to warning + await client.setLoggingLevel('warning'); + + // This one will make it through + const warningParams: LoggingMessageNotification['params'] = { + level: 'warning', + logger: 'test server', + data: 'Warning message' + }; + + // This one will not + const debugParams: LoggingMessageNotification['params'] = { + level: 'debug', + logger: 'test server', + data: 'Debug message' + }; + + // Test the one that makes it through + clientTransport.onmessage = vi.fn().mockImplementation(message => { + expect(message).toEqual({ + jsonrpc: '2.0', + method: 'notifications/message', + params: warningParams + }); + }); + + // This one will not make it through + await server.sendLoggingMessage(debugParams); + expect(clientTransport.onmessage).not.toHaveBeenCalled(); + + // This one will, triggering the above test in clientTransport.onmessage + await server.sendLoggingMessage(warningParams); + expect(clientTransport.onmessage).toHaveBeenCalled(); +}); + +test('should respect log level for transport with sessionId', async () => { + const server = new Server( + { + name: 'test server', + version: '1.0' + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {} + }, + enforceStrictCapabilities: true + } + ); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + // Add a session id to the transports + const SESSION_ID = 'test-session-id'; + clientTransport.sessionId = SESSION_ID; + serverTransport.sessionId = SESSION_ID; + + expect(clientTransport.sessionId).toBeDefined(); + expect(serverTransport.sessionId).toBeDefined(); + + await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]); + + // Client sets logging level to warning + await client.setLoggingLevel('warning'); + + // This one will make it through + const warningParams: LoggingMessageNotification['params'] = { + level: 'warning', + logger: 'test server', + data: 'Warning message' + }; + + // This one will not + const debugParams: LoggingMessageNotification['params'] = { + level: 'debug', + logger: 'test server', + data: 'Debug message' + }; + + // Test the one that makes it through + clientTransport.onmessage = vi.fn().mockImplementation(message => { + expect(message).toEqual({ + jsonrpc: '2.0', + method: 'notifications/message', + params: warningParams + }); + }); + + // This one will not make it through + await server.sendLoggingMessage(debugParams, SESSION_ID); + expect(clientTransport.onmessage).not.toHaveBeenCalled(); + + // This one will, triggering the above test in clientTransport.onmessage + await server.sendLoggingMessage(warningParams, SESSION_ID); + expect(clientTransport.onmessage).toHaveBeenCalled(); +}); diff --git a/src/server/v3/mcp.v3.test.ts b/src/server/v3/mcp.v3.test.ts new file mode 100644 index 000000000..8348906d1 --- /dev/null +++ b/src/server/v3/mcp.v3.test.ts @@ -0,0 +1,4519 @@ +import * as z from 'zod/v3'; +import { Client } from '../../client/index.js'; +import { InMemoryTransport } from '../../inMemory.js'; +import { getDisplayName } from '../../shared/metadataUtils.js'; +import { UriTemplate } from '../../shared/uriTemplate.js'; +import { + CallToolResultSchema, + CompleteResultSchema, + ElicitRequestSchema, + GetPromptResultSchema, + ListPromptsResultSchema, + ListResourcesResultSchema, + ListResourceTemplatesResultSchema, + ListToolsResultSchema, + LoggingMessageNotificationSchema, + type Notification, + ReadResourceResultSchema, + type TextContent +} from '../../types.js'; +import { completable } from '../completable.js'; +import { McpServer, ResourceTemplate } from '../mcp.js'; + +describe('McpServer', () => { + /*** + * Test: Basic Server Instance + */ + test('should expose underlying Server instance', () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + expect(mcpServer.server).toBeDefined(); + }); + + /*** + * Test: Notification Sending via Server + */ + test('should allow sending notifications via Server', async () => { + const mcpServer = new McpServer( + { + name: 'test server', + version: '1.0' + }, + { capabilities: { logging: {} } } + ); + + const notifications: Notification[] = []; + const client = new Client({ + name: 'test client', + version: '1.0' + }); + client.fallbackNotificationHandler = async notification => { + notifications.push(notification); + }; + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + // This should work because we're using the underlying server + await expect( + mcpServer.server.sendLoggingMessage({ + level: 'info', + data: 'Test log message' + }) + ).resolves.not.toThrow(); + + expect(notifications).toMatchObject([ + { + method: 'notifications/message', + params: { + level: 'info', + data: 'Test log message' + } + } + ]); + }); + + /*** + * Test: Progress Notification with Message Field + */ + test('should send progress notifications with message field', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + // Create a tool that sends progress updates + mcpServer.tool( + 'long-operation', + 'A long running operation with progress updates', + { + steps: z.number().min(1).describe('Number of steps to perform') + }, + async ({ steps }, { sendNotification, _meta }) => { + const progressToken = _meta?.progressToken; + + if (progressToken) { + // Send progress notification for each step + for (let i = 1; i <= steps; i++) { + await sendNotification({ + method: 'notifications/progress', + params: { + progressToken, + progress: i, + total: steps, + message: `Completed step ${i} of ${steps}` + } + }); + } + } + + return { + content: [ + { + type: 'text' as const, + text: `Operation completed with ${steps} steps` + } + ] + }; + } + ); + + const progressUpdates: Array<{ + progress: number; + total?: number; + message?: string; + }> = []; + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + // Call the tool with progress tracking + await client.request( + { + method: 'tools/call', + params: { + name: 'long-operation', + arguments: { steps: 3 }, + _meta: { + progressToken: 'progress-test-1' + } + } + }, + CallToolResultSchema, + { + onprogress: progress => { + progressUpdates.push(progress); + } + } + ); + + // Verify progress notifications were received with message field + expect(progressUpdates).toHaveLength(3); + expect(progressUpdates[0]).toMatchObject({ + progress: 1, + total: 3, + message: 'Completed step 1 of 3' + }); + expect(progressUpdates[1]).toMatchObject({ + progress: 2, + total: 3, + message: 'Completed step 2 of 3' + }); + expect(progressUpdates[2]).toMatchObject({ + progress: 3, + total: 3, + message: 'Completed step 3 of 3' + }); + }); +}); + +describe('ResourceTemplate', () => { + /*** + * Test: ResourceTemplate Creation with String Pattern + */ + test('should create ResourceTemplate with string pattern', () => { + const template = new ResourceTemplate('test://{category}/{id}', { + list: undefined + }); + expect(template.uriTemplate.toString()).toBe('test://{category}/{id}'); + expect(template.listCallback).toBeUndefined(); + }); + + /*** + * Test: ResourceTemplate Creation with UriTemplate Instance + */ + test('should create ResourceTemplate with UriTemplate', () => { + const uriTemplate = new UriTemplate('test://{category}/{id}'); + const template = new ResourceTemplate(uriTemplate, { list: undefined }); + expect(template.uriTemplate).toBe(uriTemplate); + expect(template.listCallback).toBeUndefined(); + }); + + /*** + * Test: ResourceTemplate with List Callback + */ + test('should create ResourceTemplate with list callback', async () => { + const list = vi.fn().mockResolvedValue({ + resources: [{ name: 'Test', uri: 'test://example' }] + }); + + const template = new ResourceTemplate('test://{id}', { list }); + expect(template.listCallback).toBe(list); + + const abortController = new AbortController(); + const result = await template.listCallback?.({ + signal: abortController.signal, + requestId: 'not-implemented', + sendRequest: () => { + throw new Error('Not implemented'); + }, + sendNotification: () => { + throw new Error('Not implemented'); + } + }); + expect(result?.resources).toHaveLength(1); + expect(list).toHaveBeenCalled(); + }); +}); + +describe('tool()', () => { + afterEach(() => { + vi.restoreAllMocks(); + }); + + /*** + * Test: Zero-Argument Tool Registration + */ + test('should register zero-argument tool', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const notifications: Notification[] = []; + const client = new Client({ + name: 'test client', + version: '1.0' + }); + client.fallbackNotificationHandler = async notification => { + notifications.push(notification); + }; + + mcpServer.tool('test', async () => ({ + content: [ + { + type: 'text', + text: 'Test response' + } + ] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + const result = await client.request( + { + method: 'tools/list' + }, + ListToolsResultSchema + ); + + expect(result.tools).toHaveLength(1); + expect(result.tools[0].name).toBe('test'); + expect(result.tools[0].inputSchema).toEqual({ + type: 'object', + properties: {} + }); + + // Adding the tool before the connection was established means no notification was sent + expect(notifications).toHaveLength(0); + + // Adding another tool triggers the update notification + mcpServer.tool('test2', async () => ({ + content: [ + { + type: 'text', + text: 'Test response' + } + ] + })); + + // Yield event loop to let the notification fly + await new Promise(process.nextTick); + + expect(notifications).toMatchObject([ + { + method: 'notifications/tools/list_changed' + } + ]); + }); + + /*** + * Test: Updating Existing Tool + */ + test('should update existing tool', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const notifications: Notification[] = []; + const client = new Client({ + name: 'test client', + version: '1.0' + }); + client.fallbackNotificationHandler = async notification => { + notifications.push(notification); + }; + + // Register initial tool + const tool = mcpServer.tool('test', async () => ({ + content: [ + { + type: 'text', + text: 'Initial response' + } + ] + })); + + // Update the tool + tool.update({ + callback: async () => ({ + content: [ + { + type: 'text', + text: 'Updated response' + } + ] + }) + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Call the tool and verify we get the updated response + const result = await client.request( + { + method: 'tools/call', + params: { + name: 'test' + } + }, + CallToolResultSchema + ); + + expect(result.content).toEqual([ + { + type: 'text', + text: 'Updated response' + } + ]); + + // Update happened before transport was connected, so no notifications should be expected + expect(notifications).toHaveLength(0); + }); + + /*** + * Test: Updating Tool with Schema + */ + test('should update tool with schema', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const notifications: Notification[] = []; + const client = new Client({ + name: 'test client', + version: '1.0' + }); + client.fallbackNotificationHandler = async notification => { + notifications.push(notification); + }; + + // Register initial tool + const tool = mcpServer.tool( + 'test', + { + name: z.string() + }, + async ({ name }) => ({ + content: [ + { + type: 'text', + text: `Initial: ${name}` + } + ] + }) + ); + + // Update the tool with a different schema + tool.update({ + paramsSchema: { + name: z.string(), + value: z.number() + }, + callback: async ({ name, value }) => ({ + content: [ + { + type: 'text', + text: `Updated: ${name}, ${value}` + } + ] + }) + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Verify the schema was updated + const listResult = await client.request( + { + method: 'tools/list' + }, + ListToolsResultSchema + ); + + expect(listResult.tools[0].inputSchema).toMatchObject({ + properties: { + name: { type: 'string' }, + value: { type: 'number' } + } + }); + + // Call the tool with the new schema + const callResult = await client.request( + { + method: 'tools/call', + params: { + name: 'test', + arguments: { + name: 'test', + value: 42 + } + } + }, + CallToolResultSchema + ); + + expect(callResult.content).toEqual([ + { + type: 'text', + text: 'Updated: test, 42' + } + ]); + + // Update happened before transport was connected, so no notifications should be expected + expect(notifications).toHaveLength(0); + }); + + /*** + * Test: Tool List Changed Notifications + */ + test('should send tool list changed notifications when connected', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const notifications: Notification[] = []; + const client = new Client({ + name: 'test client', + version: '1.0' + }); + client.fallbackNotificationHandler = async notification => { + notifications.push(notification); + }; + + // Register initial tool + const tool = mcpServer.tool('test', async () => ({ + content: [ + { + type: 'text', + text: 'Test response' + } + ] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + expect(notifications).toHaveLength(0); + + // Now update the tool + tool.update({ + callback: async () => ({ + content: [ + { + type: 'text', + text: 'Updated response' + } + ] + }) + }); + + // Yield event loop to let the notification fly + await new Promise(process.nextTick); + + expect(notifications).toMatchObject([{ method: 'notifications/tools/list_changed' }]); + + // Now delete the tool + tool.remove(); + + // Yield event loop to let the notification fly + await new Promise(process.nextTick); + + expect(notifications).toMatchObject([ + { method: 'notifications/tools/list_changed' }, + { method: 'notifications/tools/list_changed' } + ]); + }); + + /*** + * Test: Tool Registration with Parameters + */ + test('should register tool with params', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + // old api + mcpServer.tool( + 'test', + { + name: z.string(), + value: z.number() + }, + async ({ name, value }) => ({ + content: [ + { + type: 'text', + text: `${name}: ${value}` + } + ] + }) + ); + + // new api + mcpServer.registerTool( + 'test (new api)', + { + inputSchema: { name: z.string(), value: z.number() } + }, + async ({ name, value }) => ({ + content: [{ type: 'text', text: `${name}: ${value}` }] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'tools/list' + }, + ListToolsResultSchema + ); + + expect(result.tools).toHaveLength(2); + expect(result.tools[0].name).toBe('test'); + expect(result.tools[0].inputSchema).toMatchObject({ + type: 'object', + properties: { + name: { type: 'string' }, + value: { type: 'number' } + } + }); + expect(result.tools[1].name).toBe('test (new api)'); + expect(result.tools[1].inputSchema).toEqual(result.tools[0].inputSchema); + }); + + /*** + * Test: Tool Registration with Description + */ + test('should register tool with description', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + // old api + mcpServer.tool('test', 'Test description', async () => ({ + content: [ + { + type: 'text', + text: 'Test response' + } + ] + })); + + // new api + mcpServer.registerTool( + 'test (new api)', + { + description: 'Test description' + }, + async () => ({ + content: [ + { + type: 'text' as const, + text: 'Test response' + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'tools/list' + }, + ListToolsResultSchema + ); + + expect(result.tools).toHaveLength(2); + expect(result.tools[0].name).toBe('test'); + expect(result.tools[0].description).toBe('Test description'); + expect(result.tools[1].name).toBe('test (new api)'); + expect(result.tools[1].description).toBe('Test description'); + }); + + /*** + * Test: Tool Registration with Annotations + */ + test('should register tool with annotations', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.tool('test', { title: 'Test Tool', readOnlyHint: true }, async () => ({ + content: [ + { + type: 'text', + text: 'Test response' + } + ] + })); + + mcpServer.registerTool( + 'test (new api)', + { + annotations: { title: 'Test Tool', readOnlyHint: true } + }, + async () => ({ + content: [ + { + type: 'text' as const, + text: 'Test response' + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'tools/list' + }, + ListToolsResultSchema + ); + + expect(result.tools).toHaveLength(2); + expect(result.tools[0].name).toBe('test'); + expect(result.tools[0].annotations).toEqual({ + title: 'Test Tool', + readOnlyHint: true + }); + expect(result.tools[1].name).toBe('test (new api)'); + expect(result.tools[1].annotations).toEqual({ + title: 'Test Tool', + readOnlyHint: true + }); + }); + + /*** + * Test: Tool Registration with Parameters and Annotations + */ + test('should register tool with params and annotations', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.tool('test', { name: z.string() }, { title: 'Test Tool', readOnlyHint: true }, async ({ name }) => ({ + content: [{ type: 'text', text: `Hello, ${name}!` }] + })); + + mcpServer.registerTool( + 'test (new api)', + { + inputSchema: { name: z.string() }, + annotations: { title: 'Test Tool', readOnlyHint: true } + }, + async ({ name }) => ({ + content: [{ type: 'text', text: `Hello, ${name}!` }] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request({ method: 'tools/list' }, ListToolsResultSchema); + + expect(result.tools).toHaveLength(2); + expect(result.tools[0].name).toBe('test'); + expect(result.tools[0].inputSchema).toMatchObject({ + type: 'object', + properties: { name: { type: 'string' } } + }); + expect(result.tools[0].annotations).toEqual({ + title: 'Test Tool', + readOnlyHint: true + }); + expect(result.tools[1].name).toBe('test (new api)'); + expect(result.tools[1].inputSchema).toEqual(result.tools[0].inputSchema); + expect(result.tools[1].annotations).toEqual(result.tools[0].annotations); + }); + + /*** + * Test: Tool Registration with Description, Parameters, and Annotations + */ + test('should register tool with description, params, and annotations', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.tool( + 'test', + 'A tool with everything', + { name: z.string() }, + { title: 'Complete Test Tool', readOnlyHint: true, openWorldHint: false }, + async ({ name }) => ({ + content: [{ type: 'text', text: `Hello, ${name}!` }] + }) + ); + + mcpServer.registerTool( + 'test (new api)', + { + description: 'A tool with everything', + inputSchema: { name: z.string() }, + annotations: { + title: 'Complete Test Tool', + readOnlyHint: true, + openWorldHint: false + } + }, + async ({ name }) => ({ + content: [{ type: 'text', text: `Hello, ${name}!` }] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request({ method: 'tools/list' }, ListToolsResultSchema); + + expect(result.tools).toHaveLength(2); + expect(result.tools[0].name).toBe('test'); + expect(result.tools[0].description).toBe('A tool with everything'); + expect(result.tools[0].inputSchema).toMatchObject({ + type: 'object', + properties: { name: { type: 'string' } } + }); + expect(result.tools[0].annotations).toEqual({ + title: 'Complete Test Tool', + readOnlyHint: true, + openWorldHint: false + }); + expect(result.tools[1].name).toBe('test (new api)'); + expect(result.tools[1].description).toBe('A tool with everything'); + expect(result.tools[1].inputSchema).toEqual(result.tools[0].inputSchema); + expect(result.tools[1].annotations).toEqual(result.tools[0].annotations); + }); + + /*** + * Test: Tool Registration with Description, Empty Parameters, and Annotations + */ + test('should register tool with description, empty params, and annotations', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.tool( + 'test', + 'A tool with everything but empty params', + {}, + { + title: 'Complete Test Tool with empty params', + readOnlyHint: true, + openWorldHint: false + }, + async () => ({ + content: [{ type: 'text', text: 'Test response' }] + }) + ); + + mcpServer.registerTool( + 'test (new api)', + { + description: 'A tool with everything but empty params', + inputSchema: {}, + annotations: { + title: 'Complete Test Tool with empty params', + readOnlyHint: true, + openWorldHint: false + } + }, + async () => ({ + content: [{ type: 'text' as const, text: 'Test response' }] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request({ method: 'tools/list' }, ListToolsResultSchema); + + expect(result.tools).toHaveLength(2); + expect(result.tools[0].name).toBe('test'); + expect(result.tools[0].description).toBe('A tool with everything but empty params'); + expect(result.tools[0].inputSchema).toMatchObject({ + type: 'object', + properties: {} + }); + expect(result.tools[0].annotations).toEqual({ + title: 'Complete Test Tool with empty params', + readOnlyHint: true, + openWorldHint: false + }); + expect(result.tools[1].name).toBe('test (new api)'); + expect(result.tools[1].description).toBe('A tool with everything but empty params'); + expect(result.tools[1].inputSchema).toEqual(result.tools[0].inputSchema); + expect(result.tools[1].annotations).toEqual(result.tools[0].annotations); + }); + + /*** + * Test: Tool Argument Validation + */ + test('should validate tool args', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.tool( + 'test', + { + name: z.string(), + value: z.number() + }, + async ({ name, value }) => ({ + content: [ + { + type: 'text', + text: `${name}: ${value}` + } + ] + }) + ); + + mcpServer.registerTool( + 'test (new api)', + { + inputSchema: { + name: z.string(), + value: z.number() + } + }, + async ({ name, value }) => ({ + content: [ + { + type: 'text', + text: `${name}: ${value}` + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'tools/call', + params: { + name: 'test', + arguments: { + name: 'test', + value: 'not a number' + } + } + }, + CallToolResultSchema + ); + + expect(result.isError).toBe(true); + expect(result.content).toEqual( + expect.arrayContaining([ + { + type: 'text', + text: expect.stringContaining('Input validation error: Invalid arguments for tool test') + } + ]) + ); + + const result2 = await client.request( + { + method: 'tools/call', + params: { + name: 'test (new api)', + arguments: { + name: 'test', + value: 'not a number' + } + } + }, + CallToolResultSchema + ); + + expect(result2.isError).toBe(true); + expect(result2.content).toEqual( + expect.arrayContaining([ + { + type: 'text', + text: expect.stringContaining('Input validation error: Invalid arguments for tool test (new api)') + } + ]) + ); + }); + + /*** + * Test: Preventing Duplicate Tool Registration + */ + test('should prevent duplicate tool registration', () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + mcpServer.tool('test', async () => ({ + content: [ + { + type: 'text', + text: 'Test response' + } + ] + })); + + expect(() => { + mcpServer.tool('test', async () => ({ + content: [ + { + type: 'text', + text: 'Test response 2' + } + ] + })); + }).toThrow(/already registered/); + }); + + /*** + * Test: Multiple Tool Registration + */ + test('should allow registering multiple tools', () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + // This should succeed + mcpServer.tool('tool1', () => ({ content: [] })); + + // This should also succeed and not throw about request handlers + mcpServer.tool('tool2', () => ({ content: [] })); + }); + + /*** + * Test: Tool with Output Schema and Structured Content + */ + test('should support tool with outputSchema and structuredContent', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + // Register a tool with outputSchema + mcpServer.registerTool( + 'test', + { + description: 'Test tool with structured output', + inputSchema: { + input: z.string() + }, + outputSchema: { + processedInput: z.string(), + resultType: z.string(), + timestamp: z.string() + } + }, + async ({ input }) => ({ + structuredContent: { + processedInput: input, + resultType: 'structured', + timestamp: '2023-01-01T00:00:00Z' + }, + content: [ + { + type: 'text', + text: JSON.stringify({ + processedInput: input, + resultType: 'structured', + timestamp: '2023-01-01T00:00:00Z' + }) + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + // Verify the tool registration includes outputSchema + const listResult = await client.request( + { + method: 'tools/list' + }, + ListToolsResultSchema + ); + + expect(listResult.tools).toHaveLength(1); + expect(listResult.tools[0].outputSchema).toMatchObject({ + type: 'object', + properties: { + processedInput: { type: 'string' }, + resultType: { type: 'string' }, + timestamp: { type: 'string' } + }, + required: ['processedInput', 'resultType', 'timestamp'] + }); + + // Call the tool and verify it returns valid structuredContent + const result = await client.request( + { + method: 'tools/call', + params: { + name: 'test', + arguments: { + input: 'hello' + } + } + }, + CallToolResultSchema + ); + + expect(result.structuredContent).toBeDefined(); + const structuredContent = result.structuredContent as { + processedInput: string; + resultType: string; + timestamp: string; + }; + expect(structuredContent.processedInput).toBe('hello'); + expect(structuredContent.resultType).toBe('structured'); + expect(structuredContent.timestamp).toBe('2023-01-01T00:00:00Z'); + + // For backward compatibility, content is auto-generated from structuredContent + expect(result.content).toBeDefined(); + expect(result.content!).toHaveLength(1); + expect(result.content![0]).toMatchObject({ type: 'text' }); + const textContent = result.content![0] as TextContent; + expect(JSON.parse(textContent.text)).toEqual(result.structuredContent); + }); + + /*** + * Test: Tool with Output Schema Must Provide Structured Content + */ + test('should throw error when tool with outputSchema returns no structuredContent', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + // Register a tool with outputSchema that returns only content without structuredContent + mcpServer.registerTool( + 'test', + { + description: 'Test tool with output schema but missing structured content', + inputSchema: { + input: z.string() + }, + outputSchema: { + processedInput: z.string(), + resultType: z.string() + } + }, + async ({ input }) => ({ + // Only return content without structuredContent + content: [ + { + type: 'text', + text: `Processed: ${input}` + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + // Call the tool and expect it to throw an error + const result = await client.callTool({ + name: 'test', + arguments: { + input: 'hello' + } + }); + + expect(result.isError).toBe(true); + expect(result.content).toEqual( + expect.arrayContaining([ + { + type: 'text', + text: expect.stringContaining( + 'Output validation error: Tool test has an output schema but no structured content was provided' + ) + } + ]) + ); + }); + /*** + * Test: Tool with Output Schema Must Provide Structured Content + */ + test('should skip outputSchema validation when isError is true', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.registerTool( + 'test', + { + description: 'Test tool with output schema but missing structured content', + inputSchema: { + input: z.string() + }, + outputSchema: { + processedInput: z.string(), + resultType: z.string() + } + }, + async ({ input }) => ({ + content: [ + { + type: 'text', + text: `Processed: ${input}` + } + ], + isError: true + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + await expect( + client.callTool({ + name: 'test', + arguments: { + input: 'hello' + } + }) + ).resolves.toStrictEqual({ + content: [ + { + type: 'text', + text: `Processed: hello` + } + ], + isError: true + }); + }); + + /*** + * Test: Schema Validation Failure for Invalid Structured Content + */ + test('should fail schema validation when tool returns invalid structuredContent', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + // Register a tool with outputSchema that returns invalid data + mcpServer.registerTool( + 'test', + { + description: 'Test tool with invalid structured output', + inputSchema: { + input: z.string() + }, + outputSchema: { + processedInput: z.string(), + resultType: z.string(), + timestamp: z.string() + } + }, + async ({ input }) => ({ + content: [ + { + type: 'text', + text: JSON.stringify({ + processedInput: input, + resultType: 'structured', + // Missing required 'timestamp' field + someExtraField: 'unexpected' // Extra field not in schema + }) + } + ], + structuredContent: { + processedInput: input, + resultType: 'structured', + // Missing required 'timestamp' field + someExtraField: 'unexpected' // Extra field not in schema + } + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + // Call the tool and expect it to throw a server-side validation error + const result = await client.callTool({ + name: 'test', + arguments: { + input: 'hello' + } + }); + + expect(result.isError).toBe(true); + expect(result.content).toEqual( + expect.arrayContaining([ + { + type: 'text', + text: expect.stringContaining('Output validation error: Invalid structured content for tool test') + } + ]) + ); + }); + + /*** + * Test: Pass Session ID to Tool Callback + */ + test('should pass sessionId to tool callback via RequestHandlerExtra', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + let receivedSessionId: string | undefined; + mcpServer.tool('test-tool', async extra => { + receivedSessionId = extra.sessionId; + return { + content: [ + { + type: 'text', + text: 'Test response' + } + ] + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + // Set a test sessionId on the server transport + serverTransport.sessionId = 'test-session-123'; + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + await client.request( + { + method: 'tools/call', + params: { + name: 'test-tool' + } + }, + CallToolResultSchema + ); + + expect(receivedSessionId).toBe('test-session-123'); + }); + + /*** + * Test: Pass Request ID to Tool Callback + */ + test('should pass requestId to tool callback via RequestHandlerExtra', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + let receivedRequestId: string | number | undefined; + mcpServer.tool('request-id-test', async extra => { + receivedRequestId = extra.requestId; + return { + content: [ + { + type: 'text', + text: `Received request ID: ${extra.requestId}` + } + ] + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'tools/call', + params: { + name: 'request-id-test' + } + }, + CallToolResultSchema + ); + + expect(receivedRequestId).toBeDefined(); + expect(typeof receivedRequestId === 'string' || typeof receivedRequestId === 'number').toBe(true); + expect(result.content).toEqual( + expect.arrayContaining([ + { + type: 'text', + text: expect.stringContaining('Received request ID:') + } + ]) + ); + }); + + /*** + * Test: Send Notification within Tool Call + */ + test('should provide sendNotification within tool call', async () => { + const mcpServer = new McpServer( + { + name: 'test server', + version: '1.0' + }, + { capabilities: { logging: {} } } + ); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + let receivedLogMessage: string | undefined; + const loggingMessage = 'hello here is log message 1'; + + client.setNotificationHandler(LoggingMessageNotificationSchema, notification => { + receivedLogMessage = notification.params.data as string; + }); + + mcpServer.tool('test-tool', async ({ sendNotification }) => { + await sendNotification({ + method: 'notifications/message', + params: { level: 'debug', data: loggingMessage } + }); + return { + content: [ + { + type: 'text', + text: 'Test response' + } + ] + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + await client.request( + { + method: 'tools/call', + params: { + name: 'test-tool' + } + }, + CallToolResultSchema + ); + expect(receivedLogMessage).toBe(loggingMessage); + }); + + /*** + * Test: Client to Server Tool Call + */ + test('should allow client to call server tools', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.tool( + 'test', + 'Test tool', + { + input: z.string() + }, + async ({ input }) => ({ + content: [ + { + type: 'text', + text: `Processed: ${input}` + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'tools/call', + params: { + name: 'test', + arguments: { + input: 'hello' + } + } + }, + CallToolResultSchema + ); + + expect(result.content).toEqual([ + { + type: 'text', + text: 'Processed: hello' + } + ]); + }); + + /*** + * Test: Graceful Tool Error Handling + */ + test('should handle server tool errors gracefully', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.tool('error-test', async () => { + throw new Error('Tool execution failed'); + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'tools/call', + params: { + name: 'error-test' + } + }, + CallToolResultSchema + ); + + expect(result.isError).toBe(true); + expect(result.content).toEqual([ + { + type: 'text', + text: 'Tool execution failed' + } + ]); + }); + + /*** + * Test: McpError for Invalid Tool Name + */ + test('should throw McpError for invalid tool name', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.tool('test-tool', async () => ({ + content: [ + { + type: 'text', + text: 'Test response' + } + ] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'tools/call', + params: { + name: 'nonexistent-tool' + } + }, + CallToolResultSchema + ); + + expect(result.isError).toBe(true); + expect(result.content).toEqual( + expect.arrayContaining([ + { + type: 'text', + text: expect.stringContaining('Tool nonexistent-tool not found') + } + ]) + ); + }); + + /*** + * Test: Tool Registration with _meta field + */ + test('should register tool with _meta field and include it in list response', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + const metaData = { + author: 'test-author', + version: '1.2.3', + category: 'utility', + tags: ['test', 'example'] + }; + + mcpServer.registerTool( + 'test-with-meta', + { + description: 'A tool with _meta field', + inputSchema: { name: z.string() }, + _meta: metaData + }, + async ({ name }) => ({ + content: [{ type: 'text', text: `Hello, ${name}!` }] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request({ method: 'tools/list' }, ListToolsResultSchema); + + expect(result.tools).toHaveLength(1); + expect(result.tools[0].name).toBe('test-with-meta'); + expect(result.tools[0].description).toBe('A tool with _meta field'); + expect(result.tools[0]._meta).toEqual(metaData); + }); + + /*** + * Test: Tool Registration without _meta field should have undefined _meta + */ + test('should register tool without _meta field and have undefined _meta in response', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.registerTool( + 'test-without-meta', + { + description: 'A tool without _meta field', + inputSchema: { name: z.string() } + }, + async ({ name }) => ({ + content: [{ type: 'text', text: `Hello, ${name}!` }] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request({ method: 'tools/list' }, ListToolsResultSchema); + + expect(result.tools).toHaveLength(1); + expect(result.tools[0].name).toBe('test-without-meta'); + expect(result.tools[0]._meta).toBeUndefined(); + }); + + test('should validate tool names according to SEP specification', () => { + // Create a new server instance for this test + const testServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + // Spy on console.warn to verify warnings are logged + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}); + + // Test valid tool names + testServer.registerTool( + 'valid-tool-name', + { + description: 'A valid tool name' + }, + async () => ({ content: [{ type: 'text' as const, text: 'Success' }] }) + ); + + // Test tool name with warnings (starts with dash) + testServer.registerTool( + '-warning-tool', + { + description: 'A tool name that generates warnings' + }, + async () => ({ content: [{ type: 'text' as const, text: 'Success' }] }) + ); + + // Test invalid tool name (contains spaces) + testServer.registerTool( + 'invalid tool name', + { + description: 'An invalid tool name' + }, + async () => ({ content: [{ type: 'text' as const, text: 'Success' }] }) + ); + + // Verify that warnings were issued (both for warnings and validation failures) + expect(warnSpy).toHaveBeenCalled(); + + // Verify specific warning content + const warningCalls = warnSpy.mock.calls.map(call => call.join(' ')); + expect(warningCalls.some(call => call.includes('Tool name starts or ends with a dash'))).toBe(true); + expect(warningCalls.some(call => call.includes('Tool name contains spaces'))).toBe(true); + expect(warningCalls.some(call => call.includes('Tool name contains invalid characters'))).toBe(true); + + // Clean up spies + warnSpy.mockRestore(); + }); +}); + +describe('resource()', () => { + /*** + * Test: Resource Registration with URI and Read Callback + */ + test('should register resource with uri and readCallback', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.resource('test', 'test://resource', async () => ({ + contents: [ + { + uri: 'test://resource', + text: 'Test content' + } + ] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'resources/list' + }, + ListResourcesResultSchema + ); + + expect(result.resources).toHaveLength(1); + expect(result.resources[0].name).toBe('test'); + expect(result.resources[0].uri).toBe('test://resource'); + }); + + /*** + * Test: Update Resource with URI + */ + test('should update resource with uri', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const notifications: Notification[] = []; + const client = new Client({ + name: 'test client', + version: '1.0' + }); + client.fallbackNotificationHandler = async notification => { + notifications.push(notification); + }; + + // Register initial resource + const resource = mcpServer.resource('test', 'test://resource', async () => ({ + contents: [ + { + uri: 'test://resource', + text: 'Initial content' + } + ] + })); + + // Update the resource + resource.update({ + callback: async () => ({ + contents: [ + { + uri: 'test://resource', + text: 'Updated content' + } + ] + }) + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Read the resource and verify we get the updated content + const result = await client.request( + { + method: 'resources/read', + params: { + uri: 'test://resource' + } + }, + ReadResourceResultSchema + ); + + expect(result.contents).toHaveLength(1); + expect(result.contents).toEqual( + expect.arrayContaining([ + { + text: expect.stringContaining('Updated content'), + uri: 'test://resource' + } + ]) + ); + + // Update happened before transport was connected, so no notifications should be expected + expect(notifications).toHaveLength(0); + }); + + /*** + * Test: Update Resource Template + */ + test('should update resource template', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const notifications: Notification[] = []; + const client = new Client({ + name: 'test client', + version: '1.0' + }); + client.fallbackNotificationHandler = async notification => { + notifications.push(notification); + }; + + // Register initial resource template + const resourceTemplate = mcpServer.resource( + 'test', + new ResourceTemplate('test://resource/{id}', { list: undefined }), + async uri => ({ + contents: [ + { + uri: uri.href, + text: 'Initial content' + } + ] + }) + ); + + // Update the resource template + resourceTemplate.update({ + callback: async uri => ({ + contents: [ + { + uri: uri.href, + text: 'Updated content' + } + ] + }) + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Read the resource and verify we get the updated content + const result = await client.request( + { + method: 'resources/read', + params: { + uri: 'test://resource/123' + } + }, + ReadResourceResultSchema + ); + + expect(result.contents).toHaveLength(1); + expect(result.contents).toEqual( + expect.arrayContaining([ + { + text: expect.stringContaining('Updated content'), + uri: 'test://resource/123' + } + ]) + ); + + // Update happened before transport was connected, so no notifications should be expected + expect(notifications).toHaveLength(0); + }); + + /*** + * Test: Resource List Changed Notification + */ + test('should send resource list changed notification when connected', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const notifications: Notification[] = []; + const client = new Client({ + name: 'test client', + version: '1.0' + }); + client.fallbackNotificationHandler = async notification => { + notifications.push(notification); + }; + + // Register initial resource + const resource = mcpServer.resource('test', 'test://resource', async () => ({ + contents: [ + { + uri: 'test://resource', + text: 'Test content' + } + ] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + expect(notifications).toHaveLength(0); + + // Now update the resource while connected + resource.update({ + callback: async () => ({ + contents: [ + { + uri: 'test://resource', + text: 'Updated content' + } + ] + }) + }); + + // Yield event loop to let the notification fly + await new Promise(process.nextTick); + + expect(notifications).toMatchObject([{ method: 'notifications/resources/list_changed' }]); + }); + + /*** + * Test: Remove Resource and Send Notification + */ + test('should remove resource and send notification when connected', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const notifications: Notification[] = []; + const client = new Client({ + name: 'test client', + version: '1.0' + }); + client.fallbackNotificationHandler = async notification => { + notifications.push(notification); + }; + + // Register initial resources + const resource1 = mcpServer.resource('resource1', 'test://resource1', async () => ({ + contents: [{ uri: 'test://resource1', text: 'Resource 1 content' }] + })); + + mcpServer.resource('resource2', 'test://resource2', async () => ({ + contents: [{ uri: 'test://resource2', text: 'Resource 2 content' }] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Verify both resources are registered + let result = await client.request({ method: 'resources/list' }, ListResourcesResultSchema); + + expect(result.resources).toHaveLength(2); + + expect(notifications).toHaveLength(0); + + // Remove a resource + resource1.remove(); + + // Yield event loop to let the notification fly + await new Promise(process.nextTick); + + // Should have sent notification + expect(notifications).toMatchObject([{ method: 'notifications/resources/list_changed' }]); + + // Verify the resource was removed + result = await client.request({ method: 'resources/list' }, ListResourcesResultSchema); + + expect(result.resources).toHaveLength(1); + expect(result.resources[0].uri).toBe('test://resource2'); + }); + + /*** + * Test: Remove Resource Template and Send Notification + */ + test('should remove resource template and send notification when connected', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const notifications: Notification[] = []; + const client = new Client({ + name: 'test client', + version: '1.0' + }); + client.fallbackNotificationHandler = async notification => { + notifications.push(notification); + }; + + // Register resource template + const resourceTemplate = mcpServer.resource( + 'template', + new ResourceTemplate('test://resource/{id}', { list: undefined }), + async uri => ({ + contents: [ + { + uri: uri.href, + text: 'Template content' + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Verify template is registered + const result = await client.request({ method: 'resources/templates/list' }, ListResourceTemplatesResultSchema); + + expect(result.resourceTemplates).toHaveLength(1); + expect(notifications).toHaveLength(0); + + // Remove the template + resourceTemplate.remove(); + + // Yield event loop to let the notification fly + await new Promise(process.nextTick); + + // Should have sent notification + expect(notifications).toMatchObject([{ method: 'notifications/resources/list_changed' }]); + + // Verify the template was removed + const result2 = await client.request({ method: 'resources/templates/list' }, ListResourceTemplatesResultSchema); + + expect(result2.resourceTemplates).toHaveLength(0); + }); + + /*** + * Test: Resource Registration with Metadata + */ + test('should register resource with metadata', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.resource( + 'test', + 'test://resource', + { + description: 'Test resource', + mimeType: 'text/plain' + }, + async () => ({ + contents: [ + { + uri: 'test://resource', + text: 'Test content' + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'resources/list' + }, + ListResourcesResultSchema + ); + + expect(result.resources).toHaveLength(1); + expect(result.resources[0].description).toBe('Test resource'); + expect(result.resources[0].mimeType).toBe('text/plain'); + }); + + /*** + * Test: Resource Template Registration + */ + test('should register resource template', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.resource('test', new ResourceTemplate('test://resource/{id}', { list: undefined }), async () => ({ + contents: [ + { + uri: 'test://resource/123', + text: 'Test content' + } + ] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'resources/templates/list' + }, + ListResourceTemplatesResultSchema + ); + + expect(result.resourceTemplates).toHaveLength(1); + expect(result.resourceTemplates[0].name).toBe('test'); + expect(result.resourceTemplates[0].uriTemplate).toBe('test://resource/{id}'); + }); + + /*** + * Test: Resource Template with List Callback + */ + test('should register resource template with listCallback', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.resource( + 'test', + new ResourceTemplate('test://resource/{id}', { + list: async () => ({ + resources: [ + { + name: 'Resource 1', + uri: 'test://resource/1' + }, + { + name: 'Resource 2', + uri: 'test://resource/2' + } + ] + }) + }), + async uri => ({ + contents: [ + { + uri: uri.href, + text: 'Test content' + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'resources/list' + }, + ListResourcesResultSchema + ); + + expect(result.resources).toHaveLength(2); + expect(result.resources[0].name).toBe('Resource 1'); + expect(result.resources[0].uri).toBe('test://resource/1'); + expect(result.resources[1].name).toBe('Resource 2'); + expect(result.resources[1].uri).toBe('test://resource/2'); + }); + + /*** + * Test: Template Variables to Read Callback + */ + test('should pass template variables to readCallback', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.resource( + 'test', + new ResourceTemplate('test://resource/{category}/{id}', { + list: undefined + }), + async (uri, { category, id }) => ({ + contents: [ + { + uri: uri.href, + text: `Category: ${category}, ID: ${id}` + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'resources/read', + params: { + uri: 'test://resource/books/123' + } + }, + ReadResourceResultSchema + ); + + expect(result.contents).toEqual( + expect.arrayContaining([ + { + text: expect.stringContaining('Category: books, ID: 123'), + uri: 'test://resource/books/123' + } + ]) + ); + }); + + /*** + * Test: Preventing Duplicate Resource Registration + */ + test('should prevent duplicate resource registration', () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + mcpServer.resource('test', 'test://resource', async () => ({ + contents: [ + { + uri: 'test://resource', + text: 'Test content' + } + ] + })); + + expect(() => { + mcpServer.resource('test2', 'test://resource', async () => ({ + contents: [ + { + uri: 'test://resource', + text: 'Test content 2' + } + ] + })); + }).toThrow(/already registered/); + }); + + /*** + * Test: Multiple Resource Registration + */ + test('should allow registering multiple resources', () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + // This should succeed + mcpServer.resource('resource1', 'test://resource1', async () => ({ + contents: [ + { + uri: 'test://resource1', + text: 'Test content 1' + } + ] + })); + + // This should also succeed and not throw about request handlers + mcpServer.resource('resource2', 'test://resource2', async () => ({ + contents: [ + { + uri: 'test://resource2', + text: 'Test content 2' + } + ] + })); + }); + + /*** + * Test: Preventing Duplicate Resource Template Registration + */ + test('should prevent duplicate resource template registration', () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + mcpServer.resource('test', new ResourceTemplate('test://resource/{id}', { list: undefined }), async () => ({ + contents: [ + { + uri: 'test://resource/123', + text: 'Test content' + } + ] + })); + + expect(() => { + mcpServer.resource('test', new ResourceTemplate('test://resource/{id}', { list: undefined }), async () => ({ + contents: [ + { + uri: 'test://resource/123', + text: 'Test content 2' + } + ] + })); + }).toThrow(/already registered/); + }); + + /*** + * Test: Graceful Resource Read Error Handling + */ + test('should handle resource read errors gracefully', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.resource('error-test', 'test://error', async () => { + throw new Error('Resource read failed'); + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + await expect( + client.request( + { + method: 'resources/read', + params: { + uri: 'test://error' + } + }, + ReadResourceResultSchema + ) + ).rejects.toThrow(/Resource read failed/); + }); + + /*** + * Test: McpError for Invalid Resource URI + */ + test('should throw McpError for invalid resource URI', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.resource('test', 'test://resource', async () => ({ + contents: [ + { + uri: 'test://resource', + text: 'Test content' + } + ] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + await expect( + client.request( + { + method: 'resources/read', + params: { + uri: 'test://nonexistent' + } + }, + ReadResourceResultSchema + ) + ).rejects.toThrow(/Resource test:\/\/nonexistent not found/); + }); + + /*** + * Test: Registering a resource template with a complete callback should update server capabilities to advertise support for completion + */ + test('should advertise support for completion when a resource template with a complete callback is defined', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.resource( + 'test', + new ResourceTemplate('test://resource/{category}', { + list: undefined, + complete: { + category: () => ['books', 'movies', 'music'] + } + }), + async () => ({ + contents: [ + { + uri: 'test://resource/test', + text: 'Test content' + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + expect(client.getServerCapabilities()).toMatchObject({ completions: {} }); + }); + + /*** + * Test: Resource Template Parameter Completion + */ + test('should support completion of resource template parameters', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.resource( + 'test', + new ResourceTemplate('test://resource/{category}', { + list: undefined, + complete: { + category: () => ['books', 'movies', 'music'] + } + }), + async () => ({ + contents: [ + { + uri: 'test://resource/test', + text: 'Test content' + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'completion/complete', + params: { + ref: { + type: 'ref/resource', + uri: 'test://resource/{category}' + }, + argument: { + name: 'category', + value: '' + } + } + }, + CompleteResultSchema + ); + + expect(result.completion.values).toEqual(['books', 'movies', 'music']); + expect(result.completion.total).toBe(3); + }); + + /*** + * Test: Filtered Resource Template Parameter Completion + */ + test('should support filtered completion of resource template parameters', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.resource( + 'test', + new ResourceTemplate('test://resource/{category}', { + list: undefined, + complete: { + category: (test: string) => ['books', 'movies', 'music'].filter(value => value.startsWith(test)) + } + }), + async () => ({ + contents: [ + { + uri: 'test://resource/test', + text: 'Test content' + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'completion/complete', + params: { + ref: { + type: 'ref/resource', + uri: 'test://resource/{category}' + }, + argument: { + name: 'category', + value: 'm' + } + } + }, + CompleteResultSchema + ); + + expect(result.completion.values).toEqual(['movies', 'music']); + expect(result.completion.total).toBe(2); + }); + + /*** + * Test: Pass Request ID to Resource Callback + */ + test('should pass requestId to resource callback via RequestHandlerExtra', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + let receivedRequestId: string | number | undefined; + mcpServer.resource('request-id-test', 'test://resource', async (_uri, extra) => { + receivedRequestId = extra.requestId; + return { + contents: [ + { + uri: 'test://resource', + text: `Received request ID: ${extra.requestId}` + } + ] + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'resources/read', + params: { + uri: 'test://resource' + } + }, + ReadResourceResultSchema + ); + + expect(receivedRequestId).toBeDefined(); + expect(typeof receivedRequestId === 'string' || typeof receivedRequestId === 'number').toBe(true); + expect(result.contents).toEqual( + expect.arrayContaining([ + { + text: expect.stringContaining(`Received request ID:`), + uri: 'test://resource' + } + ]) + ); + }); +}); + +describe('prompt()', () => { + /*** + * Test: Zero-Argument Prompt Registration + */ + test('should register zero-argument prompt', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.prompt('test', async () => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: 'Test response' + } + } + ] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'prompts/list' + }, + ListPromptsResultSchema + ); + + expect(result.prompts).toHaveLength(1); + expect(result.prompts[0].name).toBe('test'); + expect(result.prompts[0].arguments).toBeUndefined(); + }); + /*** + * Test: Updating Existing Prompt + */ + test('should update existing prompt', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const notifications: Notification[] = []; + const client = new Client({ + name: 'test client', + version: '1.0' + }); + client.fallbackNotificationHandler = async notification => { + notifications.push(notification); + }; + + // Register initial prompt + const prompt = mcpServer.prompt('test', async () => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: 'Initial response' + } + } + ] + })); + + // Update the prompt + prompt.update({ + callback: async () => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: 'Updated response' + } + } + ] + }) + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Call the prompt and verify we get the updated response + const result = await client.request( + { + method: 'prompts/get', + params: { + name: 'test' + } + }, + GetPromptResultSchema + ); + + expect(result.messages).toHaveLength(1); + expect(result.messages).toEqual( + expect.arrayContaining([ + { + role: 'assistant', + content: { + type: 'text', + text: 'Updated response' + } + } + ]) + ); + + // Update happened before transport was connected, so no notifications should be expected + expect(notifications).toHaveLength(0); + }); + + /*** + * Test: Updating Prompt with Schema + */ + test('should update prompt with schema', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const notifications: Notification[] = []; + const client = new Client({ + name: 'test client', + version: '1.0' + }); + client.fallbackNotificationHandler = async notification => { + notifications.push(notification); + }; + + // Register initial prompt + const prompt = mcpServer.prompt( + 'test', + { + name: z.string() + }, + async ({ name }) => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: `Initial: ${name}` + } + } + ] + }) + ); + + // Update the prompt with a different schema + prompt.update({ + argsSchema: { + name: z.string(), + value: z.string() + }, + callback: async ({ name, value }) => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: `Updated: ${name}, ${value}` + } + } + ] + }) + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Verify the schema was updated + const listResult = await client.request( + { + method: 'prompts/list' + }, + ListPromptsResultSchema + ); + + expect(listResult.prompts[0].arguments).toHaveLength(2); + expect(listResult.prompts[0].arguments?.map(a => a.name).sort()).toEqual(['name', 'value']); + + // Call the prompt with the new schema + const getResult = await client.request( + { + method: 'prompts/get', + params: { + name: 'test', + arguments: { + name: 'test', + value: 'value' + } + } + }, + GetPromptResultSchema + ); + + expect(getResult.messages).toHaveLength(1); + expect(getResult.messages).toEqual( + expect.arrayContaining([ + { + role: 'assistant', + content: { + type: 'text', + text: 'Updated: test, value' + } + } + ]) + ); + + // Update happened before transport was connected, so no notifications should be expected + expect(notifications).toHaveLength(0); + }); + + /*** + * Test: Prompt List Changed Notification + */ + test('should send prompt list changed notification when connected', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const notifications: Notification[] = []; + const client = new Client({ + name: 'test client', + version: '1.0' + }); + client.fallbackNotificationHandler = async notification => { + notifications.push(notification); + }; + + // Register initial prompt + const prompt = mcpServer.prompt('test', async () => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: 'Test response' + } + } + ] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + expect(notifications).toHaveLength(0); + + // Now update the prompt while connected + prompt.update({ + callback: async () => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: 'Updated response' + } + } + ] + }) + }); + + // Yield event loop to let the notification fly + await new Promise(process.nextTick); + + expect(notifications).toMatchObject([{ method: 'notifications/prompts/list_changed' }]); + }); + + /*** + * Test: Remove Prompt and Send Notification + */ + test('should remove prompt and send notification when connected', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const notifications: Notification[] = []; + const client = new Client({ + name: 'test client', + version: '1.0' + }); + client.fallbackNotificationHandler = async notification => { + notifications.push(notification); + }; + + // Register initial prompts + const prompt1 = mcpServer.prompt('prompt1', async () => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: 'Prompt 1 response' + } + } + ] + })); + + mcpServer.prompt('prompt2', async () => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: 'Prompt 2 response' + } + } + ] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + // Verify both prompts are registered + let result = await client.request({ method: 'prompts/list' }, ListPromptsResultSchema); + + expect(result.prompts).toHaveLength(2); + expect(result.prompts.map(p => p.name).sort()).toEqual(['prompt1', 'prompt2']); + + expect(notifications).toHaveLength(0); + + // Remove a prompt + prompt1.remove(); + + // Yield event loop to let the notification fly + await new Promise(process.nextTick); + + // Should have sent notification + expect(notifications).toMatchObject([{ method: 'notifications/prompts/list_changed' }]); + + // Verify the prompt was removed + result = await client.request({ method: 'prompts/list' }, ListPromptsResultSchema); + + expect(result.prompts).toHaveLength(1); + expect(result.prompts[0].name).toBe('prompt2'); + }); + + /*** + * Test: Prompt Registration with Arguments Schema + */ + test('should register prompt with args schema', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.prompt( + 'test', + { + name: z.string(), + value: z.string() + }, + async ({ name, value }) => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: `${name}: ${value}` + } + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'prompts/list' + }, + ListPromptsResultSchema + ); + + expect(result.prompts).toHaveLength(1); + expect(result.prompts[0].name).toBe('test'); + expect(result.prompts[0].arguments).toEqual([ + { name: 'name', required: true }, + { name: 'value', required: true } + ]); + }); + + /*** + * Test: Prompt Registration with Description + */ + test('should register prompt with description', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.prompt('test', 'Test description', async () => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: 'Test response' + } + } + ] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'prompts/list' + }, + ListPromptsResultSchema + ); + + expect(result.prompts).toHaveLength(1); + expect(result.prompts[0].name).toBe('test'); + expect(result.prompts[0].description).toBe('Test description'); + }); + + /*** + * Test: Prompt Argument Validation + */ + test('should validate prompt args', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.prompt( + 'test', + { + name: z.string(), + value: z.string().min(3) + }, + async ({ name, value }) => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: `${name}: ${value}` + } + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + await expect( + client.request( + { + method: 'prompts/get', + params: { + name: 'test', + arguments: { + name: 'test', + value: 'ab' // Too short + } + } + }, + GetPromptResultSchema + ) + ).rejects.toThrow(/Invalid arguments/); + }); + + /*** + * Test: Preventing Duplicate Prompt Registration + */ + test('should prevent duplicate prompt registration', () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + mcpServer.prompt('test', async () => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: 'Test response' + } + } + ] + })); + + expect(() => { + mcpServer.prompt('test', async () => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: 'Test response 2' + } + } + ] + })); + }).toThrow(/already registered/); + }); + + /*** + * Test: Multiple Prompt Registration + */ + test('should allow registering multiple prompts', () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + // This should succeed + mcpServer.prompt('prompt1', async () => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: 'Test response 1' + } + } + ] + })); + + // This should also succeed and not throw about request handlers + mcpServer.prompt('prompt2', async () => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: 'Test response 2' + } + } + ] + })); + }); + + /*** + * Test: Prompt Registration with Arguments + */ + test('should allow registering prompts with arguments', () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + // This should succeed + mcpServer.prompt('echo', { message: z.string() }, ({ message }) => ({ + messages: [ + { + role: 'user', + content: { + type: 'text', + text: `Please process this message: ${message}` + } + } + ] + })); + }); + + /*** + * Test: Resources and Prompts with Completion Handlers + */ + test('should allow registering both resources and prompts with completion handlers', () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + // Register a resource with completion + mcpServer.resource( + 'test', + new ResourceTemplate('test://resource/{category}', { + list: undefined, + complete: { + category: () => ['books', 'movies', 'music'] + } + }), + async () => ({ + contents: [ + { + uri: 'test://resource/test', + text: 'Test content' + } + ] + }) + ); + + // Register a prompt with completion + mcpServer.prompt('echo', { message: completable(z.string(), () => ['hello', 'world']) }, ({ message }) => ({ + messages: [ + { + role: 'user', + content: { + type: 'text', + text: `Please process this message: ${message}` + } + } + ] + })); + }); + + /*** + * Test: McpError for Invalid Prompt Name + */ + test('should throw McpError for invalid prompt name', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.prompt('test-prompt', async () => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: 'Test response' + } + } + ] + })); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + await expect( + client.request( + { + method: 'prompts/get', + params: { + name: 'nonexistent-prompt' + } + }, + GetPromptResultSchema + ) + ).rejects.toThrow(/Prompt nonexistent-prompt not found/); + }); + + /*** + * Test: Registering a prompt with a completable argument should update server capabilities to advertise support for completion + */ + test('should advertise support for completion when a prompt with a completable argument is defined', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.prompt( + 'test-prompt', + { + name: completable(z.string(), () => ['Alice', 'Bob', 'Charlie']) + }, + async ({ name }) => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: `Hello ${name}` + } + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + expect(client.getServerCapabilities()).toMatchObject({ completions: {} }); + }); + + /*** + * Test: Prompt Argument Completion + */ + test('should support completion of prompt arguments', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.prompt( + 'test-prompt', + { + name: completable(z.string(), () => ['Alice', 'Bob', 'Charlie']) + }, + async ({ name }) => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: `Hello ${name}` + } + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'completion/complete', + params: { + ref: { + type: 'ref/prompt', + name: 'test-prompt' + }, + argument: { + name: 'name', + value: '' + } + } + }, + CompleteResultSchema + ); + + expect(result.completion.values).toEqual(['Alice', 'Bob', 'Charlie']); + expect(result.completion.total).toBe(3); + }); + + /*** + * Test: Filtered Prompt Argument Completion + */ + test('should support filtered completion of prompt arguments', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.prompt( + 'test-prompt', + { + name: completable(z.string(), test => ['Alice', 'Bob', 'Charlie'].filter(value => value.startsWith(test))) + }, + async ({ name }) => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: `Hello ${name}` + } + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'completion/complete', + params: { + ref: { + type: 'ref/prompt', + name: 'test-prompt' + }, + argument: { + name: 'name', + value: 'A' + } + } + }, + CompleteResultSchema + ); + + expect(result.completion.values).toEqual(['Alice']); + expect(result.completion.total).toBe(1); + }); + + /*** + * Test: Pass Request ID to Prompt Callback + */ + test('should pass requestId to prompt callback via RequestHandlerExtra', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + let receivedRequestId: string | number | undefined; + mcpServer.prompt('request-id-test', async extra => { + receivedRequestId = extra.requestId; + return { + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: `Received request ID: ${extra.requestId}` + } + } + ] + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'prompts/get', + params: { + name: 'request-id-test' + } + }, + GetPromptResultSchema + ); + + expect(receivedRequestId).toBeDefined(); + expect(typeof receivedRequestId === 'string' || typeof receivedRequestId === 'number').toBe(true); + expect(result.messages).toEqual( + expect.arrayContaining([ + { + role: 'assistant', + content: { + type: 'text', + text: expect.stringContaining(`Received request ID:`) + } + } + ]) + ); + }); + + /*** + * Test: Resource Template Metadata Priority + */ + test('should prioritize individual resource metadata over template metadata', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.resource( + 'test', + new ResourceTemplate('test://resource/{id}', { + list: async () => ({ + resources: [ + { + name: 'Resource 1', + uri: 'test://resource/1', + description: 'Individual resource description', + mimeType: 'text/plain' + }, + { + name: 'Resource 2', + uri: 'test://resource/2' + // This resource has no description or mimeType + } + ] + }) + }), + { + description: 'Template description', + mimeType: 'application/json' + }, + async uri => ({ + contents: [ + { + uri: uri.href, + text: 'Test content' + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'resources/list' + }, + ListResourcesResultSchema + ); + + expect(result.resources).toHaveLength(2); + + // Resource 1 should have its own metadata + expect(result.resources[0].name).toBe('Resource 1'); + expect(result.resources[0].description).toBe('Individual resource description'); + expect(result.resources[0].mimeType).toBe('text/plain'); + + // Resource 2 should inherit template metadata + expect(result.resources[1].name).toBe('Resource 2'); + expect(result.resources[1].description).toBe('Template description'); + expect(result.resources[1].mimeType).toBe('application/json'); + }); + + /*** + * Test: Resource Template Metadata Overrides All Fields + */ + test('should allow resource to override all template metadata fields', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.resource( + 'test', + new ResourceTemplate('test://resource/{id}', { + list: async () => ({ + resources: [ + { + name: 'Overridden Name', + uri: 'test://resource/1', + description: 'Overridden description', + mimeType: 'text/markdown' + // Add any other metadata fields if they exist + } + ] + }) + }), + { + title: 'Template Name', + description: 'Template description', + mimeType: 'application/json' + }, + async uri => ({ + contents: [ + { + uri: uri.href, + text: 'Test content' + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + const result = await client.request( + { + method: 'resources/list' + }, + ListResourcesResultSchema + ); + + expect(result.resources).toHaveLength(1); + + // All fields should be from the individual resource, not the template + expect(result.resources[0].name).toBe('Overridden Name'); + expect(result.resources[0].description).toBe('Overridden description'); + expect(result.resources[0].mimeType).toBe('text/markdown'); + }); +}); + +describe('Tool title precedence', () => { + test('should follow correct title precedence: title → annotations.title → name', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + // Tool 1: Only name + mcpServer.tool('tool_name_only', async () => ({ + content: [{ type: 'text', text: 'Response' }] + })); + + // Tool 2: Name and annotations.title + mcpServer.tool( + 'tool_with_annotations_title', + 'Tool with annotations title', + { + title: 'Annotations Title' + }, + async () => ({ + content: [{ type: 'text', text: 'Response' }] + }) + ); + + // Tool 3: Name and title (using registerTool) + mcpServer.registerTool( + 'tool_with_title', + { + title: 'Regular Title', + description: 'Tool with regular title' + }, + async () => ({ + content: [{ type: 'text' as const, text: 'Response' }] + }) + ); + + // Tool 4: All three - title should win + mcpServer.registerTool( + 'tool_with_all_titles', + { + title: 'Regular Title Wins', + description: 'Tool with all titles', + annotations: { + title: 'Annotations Title Should Not Show' + } + }, + async () => ({ + content: [{ type: 'text' as const, text: 'Response' }] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await Promise.all([client.connect(clientTransport), mcpServer.connect(serverTransport)]); + + const result = await client.request({ method: 'tools/list' }, ListToolsResultSchema); + + expect(result.tools).toHaveLength(4); + + // Tool 1: Only name - should display name + const tool1 = result.tools.find(t => t.name === 'tool_name_only'); + expect(tool1).toBeDefined(); + expect(getDisplayName(tool1!)).toBe('tool_name_only'); + + // Tool 2: Name and annotations.title - should display annotations.title + const tool2 = result.tools.find(t => t.name === 'tool_with_annotations_title'); + expect(tool2).toBeDefined(); + expect(tool2!.annotations?.title).toBe('Annotations Title'); + expect(getDisplayName(tool2!)).toBe('Annotations Title'); + + // Tool 3: Name and title - should display title + const tool3 = result.tools.find(t => t.name === 'tool_with_title'); + expect(tool3).toBeDefined(); + expect(tool3!.title).toBe('Regular Title'); + expect(getDisplayName(tool3!)).toBe('Regular Title'); + + // Tool 4: All three - title should take precedence + const tool4 = result.tools.find(t => t.name === 'tool_with_all_titles'); + expect(tool4).toBeDefined(); + expect(tool4!.title).toBe('Regular Title Wins'); + expect(tool4!.annotations?.title).toBe('Annotations Title Should Not Show'); + expect(getDisplayName(tool4!)).toBe('Regular Title Wins'); + }); + + test('getDisplayName unit tests for title precedence', () => { + // Test 1: Only name + expect(getDisplayName({ name: 'tool_name' })).toBe('tool_name'); + + // Test 2: Name and title - title wins + expect( + getDisplayName({ + name: 'tool_name', + title: 'Tool Title' + }) + ).toBe('Tool Title'); + + // Test 3: Name and annotations.title - annotations.title wins + expect( + getDisplayName({ + name: 'tool_name', + annotations: { title: 'Annotations Title' } + }) + ).toBe('Annotations Title'); + + // Test 4: All three - title wins (correct precedence) + expect( + getDisplayName({ + name: 'tool_name', + title: 'Regular Title', + annotations: { title: 'Annotations Title' } + }) + ).toBe('Regular Title'); + + // Test 5: Empty title should not be used + expect( + getDisplayName({ + name: 'tool_name', + title: '', + annotations: { title: 'Annotations Title' } + }) + ).toBe('Annotations Title'); + + // Test 6: Undefined vs null handling + expect( + getDisplayName({ + name: 'tool_name', + title: undefined, + annotations: { title: 'Annotations Title' } + }) + ).toBe('Annotations Title'); + }); + + test('should support resource template completion with resolved context', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.registerResource( + 'test', + new ResourceTemplate('github://repos/{owner}/{repo}', { + list: undefined, + complete: { + repo: (value, context) => { + if (context?.arguments?.['owner'] === 'org1') { + return ['project1', 'project2', 'project3'].filter(r => r.startsWith(value)); + } else if (context?.arguments?.['owner'] === 'org2') { + return ['repo1', 'repo2', 'repo3'].filter(r => r.startsWith(value)); + } + return []; + } + } + }), + { + title: 'GitHub Repository', + description: 'Repository information' + }, + async () => ({ + contents: [ + { + uri: 'github://repos/test/test', + text: 'Test content' + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + // Test with microsoft owner + const result1 = await client.request( + { + method: 'completion/complete', + params: { + ref: { + type: 'ref/resource', + uri: 'github://repos/{owner}/{repo}' + }, + argument: { + name: 'repo', + value: 'p' + }, + context: { + arguments: { + owner: 'org1' + } + } + } + }, + CompleteResultSchema + ); + + expect(result1.completion.values).toEqual(['project1', 'project2', 'project3']); + expect(result1.completion.total).toBe(3); + + // Test with facebook owner + const result2 = await client.request( + { + method: 'completion/complete', + params: { + ref: { + type: 'ref/resource', + uri: 'github://repos/{owner}/{repo}' + }, + argument: { + name: 'repo', + value: 'r' + }, + context: { + arguments: { + owner: 'org2' + } + } + } + }, + CompleteResultSchema + ); + + expect(result2.completion.values).toEqual(['repo1', 'repo2', 'repo3']); + expect(result2.completion.total).toBe(3); + + // Test with no resolved context + const result3 = await client.request( + { + method: 'completion/complete', + params: { + ref: { + type: 'ref/resource', + uri: 'github://repos/{owner}/{repo}' + }, + argument: { + name: 'repo', + value: 't' + } + } + }, + CompleteResultSchema + ); + + expect(result3.completion.values).toEqual([]); + expect(result3.completion.total).toBe(0); + }); + + test('should support prompt argument completion with resolved context', async () => { + const mcpServer = new McpServer({ + name: 'test server', + version: '1.0' + }); + + const client = new Client({ + name: 'test client', + version: '1.0' + }); + + mcpServer.registerPrompt( + 'test-prompt', + { + title: 'Team Greeting', + description: 'Generate a greeting for team members', + argsSchema: { + department: completable(z.string(), value => { + return ['engineering', 'sales', 'marketing', 'support'].filter(d => d.startsWith(value)); + }), + name: completable(z.string(), (value, context) => { + const department = context?.arguments?.['department']; + if (department === 'engineering') { + return ['Alice', 'Bob', 'Charlie'].filter(n => n.startsWith(value)); + } else if (department === 'sales') { + return ['David', 'Eve', 'Frank'].filter(n => n.startsWith(value)); + } else if (department === 'marketing') { + return ['Grace', 'Henry', 'Iris'].filter(n => n.startsWith(value)); + } + return ['Guest'].filter(n => n.startsWith(value)); + }) + } + }, + async ({ department, name }) => ({ + messages: [ + { + role: 'assistant', + content: { + type: 'text', + text: `Hello ${name}, welcome to the ${department} team!` + } + } + ] + }) + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + // Test with engineering department + const result1 = await client.request( + { + method: 'completion/complete', + params: { + ref: { + type: 'ref/prompt', + name: 'test-prompt' + }, + argument: { + name: 'name', + value: 'A' + }, + context: { + arguments: { + department: 'engineering' + } + } + } + }, + CompleteResultSchema + ); + + expect(result1.completion.values).toEqual(['Alice']); + + // Test with sales department + const result2 = await client.request( + { + method: 'completion/complete', + params: { + ref: { + type: 'ref/prompt', + name: 'test-prompt' + }, + argument: { + name: 'name', + value: 'D' + }, + context: { + arguments: { + department: 'sales' + } + } + } + }, + CompleteResultSchema + ); + + expect(result2.completion.values).toEqual(['David']); + + // Test with marketing department + const result3 = await client.request( + { + method: 'completion/complete', + params: { + ref: { + type: 'ref/prompt', + name: 'test-prompt' + }, + argument: { + name: 'name', + value: 'G' + }, + context: { + arguments: { + department: 'marketing' + } + } + } + }, + CompleteResultSchema + ); + + expect(result3.completion.values).toEqual(['Grace']); + + // Test with no resolved context + const result4 = await client.request( + { + method: 'completion/complete', + params: { + ref: { + type: 'ref/prompt', + name: 'test-prompt' + }, + argument: { + name: 'name', + value: 'G' + } + } + }, + CompleteResultSchema + ); + + expect(result4.completion.values).toEqual(['Guest']); + }); +}); + +describe('elicitInput()', () => { + const checkAvailability = vi.fn().mockResolvedValue(false); + const findAlternatives = vi.fn().mockResolvedValue([]); + const makeBooking = vi.fn().mockResolvedValue('BOOKING-123'); + + let mcpServer: McpServer; + let client: Client; + + beforeEach(() => { + vi.clearAllMocks(); + + // Create server with restaurant booking tool + mcpServer = new McpServer({ + name: 'restaurant-booking-server', + version: '1.0.0' + }); + + // Register the restaurant booking tool from README example + mcpServer.tool( + 'book-restaurant', + { + restaurant: z.string(), + date: z.string(), + partySize: z.number() + }, + async ({ restaurant, date, partySize }) => { + // Check availability + const available = await checkAvailability(restaurant, date, partySize); + + if (!available) { + // Ask user if they want to try alternative dates + const result = await mcpServer.server.elicitInput({ + message: `No tables available at ${restaurant} on ${date}. Would you like to check alternative dates?`, + requestedSchema: { + type: 'object', + properties: { + checkAlternatives: { + type: 'boolean', + title: 'Check alternative dates', + description: 'Would you like me to check other dates?' + }, + flexibleDates: { + type: 'string', + title: 'Date flexibility', + description: 'How flexible are your dates?', + enum: ['next_day', 'same_week', 'next_week'], + enumNames: ['Next day', 'Same week', 'Next week'] + } + }, + required: ['checkAlternatives'] + } + }); + + if (result.action === 'accept' && result.content?.checkAlternatives) { + const alternatives = await findAlternatives(restaurant, date, partySize, result.content.flexibleDates as string); + return { + content: [ + { + type: 'text', + text: `Found these alternatives: ${alternatives.join(', ')}` + } + ] + }; + } + + return { + content: [ + { + type: 'text', + text: 'No booking made. Original date not available.' + } + ] + }; + } + + await makeBooking(restaurant, date, partySize); + return { + content: [ + { + type: 'text', + text: `Booked table for ${partySize} at ${restaurant} on ${date}` + } + ] + }; + } + ); + + // Create client with elicitation capability + client = new Client( + { + name: 'test-client', + version: '1.0.0' + }, + { + capabilities: { + elicitation: {} + } + } + ); + }); + + test('should successfully elicit additional information', async () => { + // Mock availability check to return false + checkAvailability.mockResolvedValue(false); + findAlternatives.mockResolvedValue(['2024-12-26', '2024-12-27', '2024-12-28']); + + // Set up client to accept alternative date checking + client.setRequestHandler(ElicitRequestSchema, async request => { + expect(request.params.message).toContain('No tables available at ABC Restaurant on 2024-12-25'); + return { + action: 'accept', + content: { + checkAlternatives: true, + flexibleDates: 'same_week' + } + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + // Call the tool + const result = await client.callTool({ + name: 'book-restaurant', + arguments: { + restaurant: 'ABC Restaurant', + date: '2024-12-25', + partySize: 2 + } + }); + + expect(checkAvailability).toHaveBeenCalledWith('ABC Restaurant', '2024-12-25', 2); + expect(findAlternatives).toHaveBeenCalledWith('ABC Restaurant', '2024-12-25', 2, 'same_week'); + expect(result.content).toEqual([ + { + type: 'text', + text: 'Found these alternatives: 2024-12-26, 2024-12-27, 2024-12-28' + } + ]); + }); + + test('should handle user declining to elicitation request', async () => { + // Mock availability check to return false + checkAvailability.mockResolvedValue(false); + + // Set up client to reject alternative date checking + client.setRequestHandler(ElicitRequestSchema, async () => { + return { + action: 'accept', + content: { + checkAlternatives: false + } + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + // Call the tool + const result = await client.callTool({ + name: 'book-restaurant', + arguments: { + restaurant: 'ABC Restaurant', + date: '2024-12-25', + partySize: 2 + } + }); + + expect(checkAvailability).toHaveBeenCalledWith('ABC Restaurant', '2024-12-25', 2); + expect(findAlternatives).not.toHaveBeenCalled(); + expect(result.content).toEqual([ + { + type: 'text', + text: 'No booking made. Original date not available.' + } + ]); + }); + + test('should handle user cancelling the elicitation', async () => { + // Mock availability check to return false + checkAvailability.mockResolvedValue(false); + + // Set up client to cancel the elicitation + client.setRequestHandler(ElicitRequestSchema, async () => { + return { + action: 'cancel' + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([client.connect(clientTransport), mcpServer.server.connect(serverTransport)]); + + // Call the tool + const result = await client.callTool({ + name: 'book-restaurant', + arguments: { + restaurant: 'ABC Restaurant', + date: '2024-12-25', + partySize: 2 + } + }); + + expect(checkAvailability).toHaveBeenCalledWith('ABC Restaurant', '2024-12-25', 2); + expect(findAlternatives).not.toHaveBeenCalled(); + expect(result.content).toEqual([ + { + type: 'text', + text: 'No booking made. Original date not available.' + } + ]); + }); +}); + +describe('Tools with union and intersection schemas', () => { + test('should support union schemas', async () => { + const server = new McpServer({ + name: 'test', + version: '1.0.0' + }); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const unionSchema = z.union([ + z.object({ type: z.literal('email'), email: z.string().email() }), + z.object({ type: z.literal('phone'), phone: z.string() }) + ]); + + server.registerTool('contact', { inputSchema: unionSchema }, async args => { + if (args.type === 'email') { + return { + content: [{ type: 'text', text: `Email contact: ${args.email}` }] + }; + } else { + return { + content: [{ type: 'text', text: `Phone contact: ${args.phone}` }] + }; + } + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await server.connect(serverTransport); + await client.connect(clientTransport); + + const emailResult = await client.callTool({ + name: 'contact', + arguments: { + type: 'email', + email: 'test@example.com' + } + }); + + expect(emailResult.content).toEqual([ + { + type: 'text', + text: 'Email contact: test@example.com' + } + ]); + + const phoneResult = await client.callTool({ + name: 'contact', + arguments: { + type: 'phone', + phone: '+1234567890' + } + }); + + expect(phoneResult.content).toEqual([ + { + type: 'text', + text: 'Phone contact: +1234567890' + } + ]); + }); + + test('should support intersection schemas', async () => { + const server = new McpServer({ + name: 'test', + version: '1.0.0' + }); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const baseSchema = z.object({ id: z.string() }); + const extendedSchema = z.object({ name: z.string(), age: z.number() }); + const intersectionSchema = z.intersection(baseSchema, extendedSchema); + + server.registerTool('user', { inputSchema: intersectionSchema }, async args => { + return { + content: [ + { + type: 'text', + text: `User: ${args.id}, ${args.name}, ${args.age} years old` + } + ] + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await server.connect(serverTransport); + await client.connect(clientTransport); + + const result = await client.callTool({ + name: 'user', + arguments: { + id: '123', + name: 'John Doe', + age: 30 + } + }); + + expect(result.content).toEqual([ + { + type: 'text', + text: 'User: 123, John Doe, 30 years old' + } + ]); + }); + + test('should support complex nested schemas', async () => { + const server = new McpServer({ + name: 'test', + version: '1.0.0' + }); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const schema = z.object({ + items: z.array( + z.union([ + z.object({ type: z.literal('text'), content: z.string() }), + z.object({ type: z.literal('number'), value: z.number() }) + ]) + ) + }); + + server.registerTool('process', { inputSchema: schema }, async args => { + const processed = args.items.map(item => { + if (item.type === 'text') { + return item.content.toUpperCase(); + } else { + return item.value * 2; + } + }); + return { + content: [ + { + type: 'text', + text: `Processed: ${processed.join(', ')}` + } + ] + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await server.connect(serverTransport); + await client.connect(clientTransport); + + const result = await client.callTool({ + name: 'process', + arguments: { + items: [ + { type: 'text', content: 'hello' }, + { type: 'number', value: 5 }, + { type: 'text', content: 'world' } + ] + } + }); + + expect(result.content).toEqual([ + { + type: 'text', + text: 'Processed: HELLO, 10, WORLD' + } + ]); + }); + + test('should validate union schema inputs correctly', async () => { + const server = new McpServer({ + name: 'test', + version: '1.0.0' + }); + + const client = new Client({ + name: 'test-client', + version: '1.0.0' + }); + + const unionSchema = z.union([ + z.object({ type: z.literal('a'), value: z.string() }), + z.object({ type: z.literal('b'), value: z.number() }) + ]); + + server.registerTool('union-test', { inputSchema: unionSchema }, async () => { + return { + content: [{ type: 'text', text: 'Success' }] + }; + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + await server.connect(serverTransport); + await client.connect(clientTransport); + + const invalidTypeResult = await client.callTool({ + name: 'union-test', + arguments: { + type: 'a', + value: 123 + } + }); + + expect(invalidTypeResult.isError).toBe(true); + expect(invalidTypeResult.content).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + type: 'text', + text: expect.stringContaining('Input validation error') + }) + ]) + ); + + const invalidDiscriminatorResult = await client.callTool({ + name: 'union-test', + arguments: { + type: 'c', + value: 'test' + } + }); + + expect(invalidDiscriminatorResult.isError).toBe(true); + expect(invalidDiscriminatorResult.content).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + type: 'text', + text: expect.stringContaining('Input validation error') + }) + ]) + ); + }); +}); diff --git a/src/server/v3/sse.v3.test.ts b/src/server/v3/sse.v3.test.ts new file mode 100644 index 000000000..be19726e8 --- /dev/null +++ b/src/server/v3/sse.v3.test.ts @@ -0,0 +1,711 @@ +import http from 'http'; +import { type Mocked } from 'vitest'; + +import { SSEServerTransport } from '../sse.js'; +import { McpServer } from '../mcp.js'; +import { createServer, type Server } from 'node:http'; +import { AddressInfo } from 'node:net'; +import * as z from 'zod/v3'; +import { CallToolResult, JSONRPCMessage } from '../../types.js'; + +const createMockResponse = () => { + const res = { + writeHead: vi.fn().mockReturnThis(), + write: vi.fn().mockReturnThis(), + on: vi.fn().mockReturnThis(), + end: vi.fn().mockReturnThis() + }; + + return res as unknown as Mocked; +}; + +const createMockRequest = ({ headers = {}, body }: { headers?: Record; body?: string } = {}) => { + const mockReq = { + headers, + body: body ? body : undefined, + auth: { + token: 'test-token' + }, + on: vi.fn().mockImplementation((event, listener) => { + const mockListener = listener as unknown as (...args: unknown[]) => void; + if (event === 'data') { + mockListener(Buffer.from(body || '') as unknown as Error); + } + if (event === 'error') { + mockListener(new Error('test')); + } + if (event === 'end') { + mockListener(); + } + if (event === 'close') { + setTimeout(listener, 100); + } + return mockReq; + }), + listeners: vi.fn(), + removeListener: vi.fn() + } as unknown as http.IncomingMessage; + + return mockReq; +}; + +/** + * Helper to create and start test HTTP server with MCP setup + */ +async function createTestServerWithSse(args: { mockRes: http.ServerResponse }): Promise<{ + server: Server; + transport: SSEServerTransport; + mcpServer: McpServer; + baseUrl: URL; + sessionId: string; + serverPort: number; +}> { + const mcpServer = new McpServer({ name: 'test-server', version: '1.0.0' }, { capabilities: { logging: {} } }); + + mcpServer.tool( + 'greet', + 'A simple greeting tool', + { name: z.string().describe('Name to greet') }, + async ({ name }): Promise => { + return { content: [{ type: 'text', text: `Hello, ${name}!` }] }; + } + ); + + const endpoint = '/messages'; + + const transport = new SSEServerTransport(endpoint, args.mockRes); + const sessionId = transport.sessionId; + + await mcpServer.connect(transport); + + const server = createServer(async (req, res) => { + try { + await transport.handlePostMessage(req, res); + } catch (error) { + console.error('Error handling request:', error); + if (!res.headersSent) res.writeHead(500).end(); + } + }); + + const baseUrl = await new Promise(resolve => { + server.listen(0, '127.0.0.1', () => { + const addr = server.address() as AddressInfo; + resolve(new URL(`http://127.0.0.1:${addr.port}`)); + }); + }); + + const port = (server.address() as AddressInfo).port; + + return { server, transport, mcpServer, baseUrl, sessionId, serverPort: port }; +} + +async function readAllSSEEvents(response: Response): Promise { + const reader = response.body?.getReader(); + if (!reader) throw new Error('No readable stream'); + + const events: string[] = []; + const decoder = new TextDecoder(); + + try { + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + if (value) { + events.push(decoder.decode(value)); + } + } + } finally { + reader.releaseLock(); + } + + return events; +} + +/** + * Helper to send JSON-RPC request + */ +async function sendSsePostRequest( + baseUrl: URL, + message: JSONRPCMessage | JSONRPCMessage[], + sessionId?: string, + extraHeaders?: Record +): Promise { + const headers: Record = { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + ...extraHeaders + }; + + if (sessionId) { + baseUrl.searchParams.set('sessionId', sessionId); + } + + return fetch(baseUrl, { + method: 'POST', + headers, + body: JSON.stringify(message) + }); +} + +describe('SSEServerTransport', () => { + async function initializeServer(baseUrl: URL): Promise { + const response = await sendSsePostRequest(baseUrl, { + jsonrpc: '2.0', + method: 'initialize', + params: { + clientInfo: { name: 'test-client', version: '1.0' }, + protocolVersion: '2025-03-26', + capabilities: {} + }, + + id: 'init-1' + } as JSONRPCMessage); + + expect(response.status).toBe(202); + + const text = await readAllSSEEvents(response); + + expect(text).toHaveLength(1); + expect(text[0]).toBe('Accepted'); + } + + describe('start method', () => { + it('should correctly append sessionId to a simple relative endpoint', async () => { + const mockRes = createMockResponse(); + const endpoint = '/messages'; + const transport = new SSEServerTransport(endpoint, mockRes); + const expectedSessionId = transport.sessionId; + + await transport.start(); + + expect(mockRes.writeHead).toHaveBeenCalledWith(200, expect.any(Object)); + expect(mockRes.write).toHaveBeenCalledTimes(1); + expect(mockRes.write).toHaveBeenCalledWith(`event: endpoint\ndata: /messages?sessionId=${expectedSessionId}\n\n`); + }); + + it('should correctly append sessionId to an endpoint with existing query parameters', async () => { + const mockRes = createMockResponse(); + const endpoint = '/messages?foo=bar&baz=qux'; + const transport = new SSEServerTransport(endpoint, mockRes); + const expectedSessionId = transport.sessionId; + + await transport.start(); + + expect(mockRes.writeHead).toHaveBeenCalledWith(200, expect.any(Object)); + expect(mockRes.write).toHaveBeenCalledTimes(1); + expect(mockRes.write).toHaveBeenCalledWith( + `event: endpoint\ndata: /messages?foo=bar&baz=qux&sessionId=${expectedSessionId}\n\n` + ); + }); + + it('should correctly append sessionId to an endpoint with a hash fragment', async () => { + const mockRes = createMockResponse(); + const endpoint = '/messages#section1'; + const transport = new SSEServerTransport(endpoint, mockRes); + const expectedSessionId = transport.sessionId; + + await transport.start(); + + expect(mockRes.writeHead).toHaveBeenCalledWith(200, expect.any(Object)); + expect(mockRes.write).toHaveBeenCalledTimes(1); + expect(mockRes.write).toHaveBeenCalledWith(`event: endpoint\ndata: /messages?sessionId=${expectedSessionId}#section1\n\n`); + }); + + it('should correctly append sessionId to an endpoint with query parameters and a hash fragment', async () => { + const mockRes = createMockResponse(); + const endpoint = '/messages?key=value#section2'; + const transport = new SSEServerTransport(endpoint, mockRes); + const expectedSessionId = transport.sessionId; + + await transport.start(); + + expect(mockRes.writeHead).toHaveBeenCalledWith(200, expect.any(Object)); + expect(mockRes.write).toHaveBeenCalledTimes(1); + expect(mockRes.write).toHaveBeenCalledWith( + `event: endpoint\ndata: /messages?key=value&sessionId=${expectedSessionId}#section2\n\n` + ); + }); + + it('should correctly handle the root path endpoint "/"', async () => { + const mockRes = createMockResponse(); + const endpoint = '/'; + const transport = new SSEServerTransport(endpoint, mockRes); + const expectedSessionId = transport.sessionId; + + await transport.start(); + + expect(mockRes.writeHead).toHaveBeenCalledWith(200, expect.any(Object)); + expect(mockRes.write).toHaveBeenCalledTimes(1); + expect(mockRes.write).toHaveBeenCalledWith(`event: endpoint\ndata: /?sessionId=${expectedSessionId}\n\n`); + }); + + it('should correctly handle an empty string endpoint ""', async () => { + const mockRes = createMockResponse(); + const endpoint = ''; + const transport = new SSEServerTransport(endpoint, mockRes); + const expectedSessionId = transport.sessionId; + + await transport.start(); + + expect(mockRes.writeHead).toHaveBeenCalledWith(200, expect.any(Object)); + expect(mockRes.write).toHaveBeenCalledTimes(1); + expect(mockRes.write).toHaveBeenCalledWith(`event: endpoint\ndata: /?sessionId=${expectedSessionId}\n\n`); + }); + + /** + * Test: Tool With Request Info + */ + it('should pass request info to tool callback', async () => { + const mockRes = createMockResponse(); + const { mcpServer, baseUrl, sessionId, serverPort } = await createTestServerWithSse({ mockRes }); + await initializeServer(baseUrl); + + mcpServer.tool( + 'test-request-info', + 'A simple test tool with request info', + { name: z.string().describe('Name to greet') }, + async ({ name }, { requestInfo }): Promise => { + return { + content: [ + { type: 'text', text: `Hello, ${name}!` }, + { type: 'text', text: `${JSON.stringify(requestInfo)}` } + ] + }; + } + ); + + const toolCallMessage: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'tools/call', + params: { + name: 'test-request-info', + arguments: { + name: 'Test User' + } + }, + id: 'call-1' + }; + + const response = await sendSsePostRequest(baseUrl, toolCallMessage, sessionId); + + expect(response.status).toBe(202); + + expect(mockRes.write).toHaveBeenCalledWith(`event: endpoint\ndata: /messages?sessionId=${sessionId}\n\n`); + + const expectedMessage = { + result: { + content: [ + { + type: 'text', + text: 'Hello, Test User!' + }, + { + type: 'text', + text: JSON.stringify({ + headers: { + host: `127.0.0.1:${serverPort}`, + connection: 'keep-alive', + 'content-type': 'application/json', + accept: 'application/json, text/event-stream', + 'accept-language': '*', + 'sec-fetch-mode': 'cors', + 'user-agent': 'node', + 'accept-encoding': 'gzip, deflate', + 'content-length': '124' + } + }) + } + ] + }, + jsonrpc: '2.0', + id: 'call-1' + }; + expect(mockRes.write).toHaveBeenCalledWith(`event: message\ndata: ${JSON.stringify(expectedMessage)}\n\n`); + }); + }); + + describe('handlePostMessage method', () => { + it('should return 500 if server has not started', async () => { + const mockReq = createMockRequest(); + const mockRes = createMockResponse(); + const endpoint = '/messages'; + const transport = new SSEServerTransport(endpoint, mockRes); + + const error = 'SSE connection not established'; + await expect(transport.handlePostMessage(mockReq, mockRes)).rejects.toThrow(error); + expect(mockRes.writeHead).toHaveBeenCalledWith(500); + expect(mockRes.end).toHaveBeenCalledWith(error); + }); + + it('should return 400 if content-type is not application/json', async () => { + const mockReq = createMockRequest({ headers: { 'content-type': 'text/plain' } }); + const mockRes = createMockResponse(); + const endpoint = '/messages'; + const transport = new SSEServerTransport(endpoint, mockRes); + await transport.start(); + + transport.onerror = vi.fn(); + const error = 'Unsupported content-type: text/plain'; + await expect(transport.handlePostMessage(mockReq, mockRes)).resolves.toBe(undefined); + expect(mockRes.writeHead).toHaveBeenCalledWith(400); + expect(mockRes.end).toHaveBeenCalledWith(expect.stringContaining(error)); + expect(transport.onerror).toHaveBeenCalledWith(new Error(error)); + }); + + it('should return 400 if message has not a valid schema', async () => { + const invalidMessage = JSON.stringify({ + // missing jsonrpc field + method: 'call', + params: [1, 2, 3], + id: 1 + }); + const mockReq = createMockRequest({ + headers: { 'content-type': 'application/json' }, + body: invalidMessage + }); + const mockRes = createMockResponse(); + const endpoint = '/messages'; + const transport = new SSEServerTransport(endpoint, mockRes); + await transport.start(); + + transport.onmessage = vi.fn(); + await transport.handlePostMessage(mockReq, mockRes); + expect(mockRes.writeHead).toHaveBeenCalledWith(400); + expect(transport.onmessage).not.toHaveBeenCalled(); + expect(mockRes.end).toHaveBeenCalledWith(`Invalid message: ${invalidMessage}`); + }); + + it('should return 202 if message has a valid schema', async () => { + const validMessage = JSON.stringify({ + jsonrpc: '2.0', + method: 'call', + params: { + a: 1, + b: 2, + c: 3 + }, + id: 1 + }); + const mockReq = createMockRequest({ + headers: { 'content-type': 'application/json' }, + body: validMessage + }); + const mockRes = createMockResponse(); + const endpoint = '/messages'; + const transport = new SSEServerTransport(endpoint, mockRes); + await transport.start(); + + transport.onmessage = vi.fn(); + await transport.handlePostMessage(mockReq, mockRes); + expect(mockRes.writeHead).toHaveBeenCalledWith(202); + expect(mockRes.end).toHaveBeenCalledWith('Accepted'); + expect(transport.onmessage).toHaveBeenCalledWith( + { + jsonrpc: '2.0', + method: 'call', + params: { + a: 1, + b: 2, + c: 3 + }, + id: 1 + }, + { + authInfo: { + token: 'test-token' + }, + requestInfo: { + headers: { + 'content-type': 'application/json' + } + } + } + ); + }); + }); + + describe('close method', () => { + it('should call onclose', async () => { + const mockRes = createMockResponse(); + const endpoint = '/messages'; + const transport = new SSEServerTransport(endpoint, mockRes); + await transport.start(); + transport.onclose = vi.fn(); + await transport.close(); + expect(transport.onclose).toHaveBeenCalled(); + }); + }); + + describe('send method', () => { + it('should call onsend', async () => { + const mockRes = createMockResponse(); + const endpoint = '/messages'; + const transport = new SSEServerTransport(endpoint, mockRes); + await transport.start(); + expect(mockRes.write).toHaveBeenCalledTimes(1); + expect(mockRes.write).toHaveBeenCalledWith(expect.stringContaining('event: endpoint')); + expect(mockRes.write).toHaveBeenCalledWith(expect.stringContaining(`data: /messages?sessionId=${transport.sessionId}`)); + }); + }); + + describe('DNS rebinding protection', () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + describe('Host header validation', () => { + it('should accept requests with allowed host headers', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes, { + allowedHosts: ['localhost:3000', 'example.com'], + enableDnsRebindingProtection: true + }); + await transport.start(); + + const mockReq = createMockRequest({ + headers: { + host: 'localhost:3000', + 'content-type': 'application/json' + } + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202); + expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted'); + }); + + it('should reject requests with disallowed host headers', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes, { + allowedHosts: ['localhost:3000'], + enableDnsRebindingProtection: true + }); + await transport.start(); + + const mockReq = createMockRequest({ + headers: { + host: 'evil.com', + 'content-type': 'application/json' + } + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(403); + expect(mockHandleRes.end).toHaveBeenCalledWith('Invalid Host header: evil.com'); + }); + + it('should reject requests without host header when allowedHosts is configured', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes, { + allowedHosts: ['localhost:3000'], + enableDnsRebindingProtection: true + }); + await transport.start(); + + const mockReq = createMockRequest({ + headers: { + 'content-type': 'application/json' + } + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(403); + expect(mockHandleRes.end).toHaveBeenCalledWith('Invalid Host header: undefined'); + }); + }); + + describe('Origin header validation', () => { + it('should accept requests with allowed origin headers', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes, { + allowedOrigins: ['http://localhost:3000', 'https://example.com'], + enableDnsRebindingProtection: true + }); + await transport.start(); + + const mockReq = createMockRequest({ + headers: { + origin: 'http://localhost:3000', + 'content-type': 'application/json' + } + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202); + expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted'); + }); + + it('should reject requests with disallowed origin headers', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes, { + allowedOrigins: ['http://localhost:3000'], + enableDnsRebindingProtection: true + }); + await transport.start(); + + const mockReq = createMockRequest({ + headers: { + origin: 'http://evil.com', + 'content-type': 'application/json' + } + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(403); + expect(mockHandleRes.end).toHaveBeenCalledWith('Invalid Origin header: http://evil.com'); + }); + }); + + describe('Content-Type validation', () => { + it('should accept requests with application/json content-type', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes); + await transport.start(); + + const mockReq = createMockRequest({ + headers: { + 'content-type': 'application/json' + } + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202); + expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted'); + }); + + it('should accept requests with application/json with charset', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes); + await transport.start(); + + const mockReq = createMockRequest({ + headers: { + 'content-type': 'application/json; charset=utf-8' + } + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(202); + expect(mockHandleRes.end).toHaveBeenCalledWith('Accepted'); + }); + + it('should reject requests with non-application/json content-type when protection is enabled', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes); + await transport.start(); + + const mockReq = createMockRequest({ + headers: { + 'content-type': 'text/plain' + } + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(400); + expect(mockHandleRes.end).toHaveBeenCalledWith('Error: Unsupported content-type: text/plain'); + }); + }); + + describe('enableDnsRebindingProtection option', () => { + it('should skip all validations when enableDnsRebindingProtection is false', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes, { + allowedHosts: ['localhost:3000'], + allowedOrigins: ['http://localhost:3000'], + enableDnsRebindingProtection: false + }); + await transport.start(); + + const mockReq = createMockRequest({ + headers: { + host: 'evil.com', + origin: 'http://evil.com', + 'content-type': 'text/plain' + } + }); + const mockHandleRes = createMockResponse(); + + await transport.handlePostMessage(mockReq, mockHandleRes, { jsonrpc: '2.0', method: 'test' }); + + // Should pass even with invalid headers because protection is disabled + expect(mockHandleRes.writeHead).toHaveBeenCalledWith(400); + // The error should be from content-type parsing, not DNS rebinding protection + expect(mockHandleRes.end).toHaveBeenCalledWith('Error: Unsupported content-type: text/plain'); + }); + }); + + describe('Combined validations', () => { + it('should validate both host and origin when both are configured', async () => { + const mockRes = createMockResponse(); + const transport = new SSEServerTransport('/messages', mockRes, { + allowedHosts: ['localhost:3000'], + allowedOrigins: ['http://localhost:3000'], + enableDnsRebindingProtection: true + }); + await transport.start(); + + // Valid host, invalid origin + const mockReq1 = createMockRequest({ + headers: { + host: 'localhost:3000', + origin: 'http://evil.com', + 'content-type': 'application/json' + } + }); + const mockHandleRes1 = createMockResponse(); + + await transport.handlePostMessage(mockReq1, mockHandleRes1, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes1.writeHead).toHaveBeenCalledWith(403); + expect(mockHandleRes1.end).toHaveBeenCalledWith('Invalid Origin header: http://evil.com'); + + // Invalid host, valid origin + const mockReq2 = createMockRequest({ + headers: { + host: 'evil.com', + origin: 'http://localhost:3000', + 'content-type': 'application/json' + } + }); + const mockHandleRes2 = createMockResponse(); + + await transport.handlePostMessage(mockReq2, mockHandleRes2, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes2.writeHead).toHaveBeenCalledWith(403); + expect(mockHandleRes2.end).toHaveBeenCalledWith('Invalid Host header: evil.com'); + + // Both valid + const mockReq3 = createMockRequest({ + headers: { + host: 'localhost:3000', + origin: 'http://localhost:3000', + 'content-type': 'application/json' + } + }); + const mockHandleRes3 = createMockResponse(); + + await transport.handlePostMessage(mockReq3, mockHandleRes3, { jsonrpc: '2.0', method: 'test' }); + + expect(mockHandleRes3.writeHead).toHaveBeenCalledWith(202); + expect(mockHandleRes3.end).toHaveBeenCalledWith('Accepted'); + }); + }); + }); +}); diff --git a/src/server/v3/streamableHttp.v3.test.ts b/src/server/v3/streamableHttp.v3.test.ts new file mode 100644 index 000000000..524069080 --- /dev/null +++ b/src/server/v3/streamableHttp.v3.test.ts @@ -0,0 +1,2151 @@ +import { createServer, type Server, IncomingMessage, ServerResponse } from 'node:http'; +import { createServer as netCreateServer, AddressInfo } from 'node:net'; +import { randomUUID } from 'node:crypto'; +import { EventStore, StreamableHTTPServerTransport, EventId, StreamId } from '../streamableHttp.js'; +import { McpServer } from '../mcp.js'; +import { CallToolResult, JSONRPCMessage } from '../../types.js'; +import * as z from 'zod/v3'; +import { AuthInfo } from '../auth/types.js'; + +async function getFreePort() { + return new Promise(res => { + const srv = netCreateServer(); + srv.listen(0, () => { + const address = srv.address()!; + if (typeof address === 'string') { + throw new Error('Unexpected address type: ' + typeof address); + } + const port = (address as AddressInfo).port; + srv.close(_err => res(port)); + }); + }); +} + +/** + * Test server configuration for StreamableHTTPServerTransport tests + */ +interface TestServerConfig { + sessionIdGenerator: (() => string) | undefined; + enableJsonResponse?: boolean; + customRequestHandler?: (req: IncomingMessage, res: ServerResponse, parsedBody?: unknown) => Promise; + eventStore?: EventStore; + onsessioninitialized?: (sessionId: string) => void | Promise; + onsessionclosed?: (sessionId: string) => void | Promise; +} + +/** + * Helper to create and start test HTTP server with MCP setup + */ +async function createTestServer(config: TestServerConfig = { sessionIdGenerator: () => randomUUID() }): Promise<{ + server: Server; + transport: StreamableHTTPServerTransport; + mcpServer: McpServer; + baseUrl: URL; +}> { + const mcpServer = new McpServer({ name: 'test-server', version: '1.0.0' }, { capabilities: { logging: {} } }); + + mcpServer.tool( + 'greet', + 'A simple greeting tool', + { name: z.string().describe('Name to greet') }, + async ({ name }): Promise => { + return { content: [{ type: 'text', text: `Hello, ${name}!` }] }; + } + ); + + const transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: config.sessionIdGenerator, + enableJsonResponse: config.enableJsonResponse ?? false, + eventStore: config.eventStore, + onsessioninitialized: config.onsessioninitialized, + onsessionclosed: config.onsessionclosed + }); + + await mcpServer.connect(transport); + + const server = createServer(async (req, res) => { + try { + if (config.customRequestHandler) { + await config.customRequestHandler(req, res); + } else { + await transport.handleRequest(req, res); + } + } catch (error) { + console.error('Error handling request:', error); + if (!res.headersSent) res.writeHead(500).end(); + } + }); + + const baseUrl = await new Promise(resolve => { + server.listen(0, '127.0.0.1', () => { + const addr = server.address() as AddressInfo; + resolve(new URL(`http://127.0.0.1:${addr.port}`)); + }); + }); + + return { server, transport, mcpServer, baseUrl }; +} + +/** + * Helper to create and start authenticated test HTTP server with MCP setup + */ +async function createTestAuthServer(config: TestServerConfig = { sessionIdGenerator: () => randomUUID() }): Promise<{ + server: Server; + transport: StreamableHTTPServerTransport; + mcpServer: McpServer; + baseUrl: URL; +}> { + const mcpServer = new McpServer({ name: 'test-server', version: '1.0.0' }, { capabilities: { logging: {} } }); + + mcpServer.tool( + 'profile', + 'A user profile data tool', + { active: z.boolean().describe('Profile status') }, + async ({ active }, { authInfo }): Promise => { + return { content: [{ type: 'text', text: `${active ? 'Active' : 'Inactive'} profile from token: ${authInfo?.token}!` }] }; + } + ); + + const transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: config.sessionIdGenerator, + enableJsonResponse: config.enableJsonResponse ?? false, + eventStore: config.eventStore, + onsessioninitialized: config.onsessioninitialized, + onsessionclosed: config.onsessionclosed + }); + + await mcpServer.connect(transport); + + const server = createServer(async (req: IncomingMessage & { auth?: AuthInfo }, res) => { + try { + if (config.customRequestHandler) { + await config.customRequestHandler(req, res); + } else { + req.auth = { token: req.headers['authorization']?.split(' ')[1] } as AuthInfo; + await transport.handleRequest(req, res); + } + } catch (error) { + console.error('Error handling request:', error); + if (!res.headersSent) res.writeHead(500).end(); + } + }); + + const baseUrl = await new Promise(resolve => { + server.listen(0, '127.0.0.1', () => { + const addr = server.address() as AddressInfo; + resolve(new URL(`http://127.0.0.1:${addr.port}`)); + }); + }); + + return { server, transport, mcpServer, baseUrl }; +} + +/** + * Helper to stop test server + */ +async function stopTestServer({ server, transport }: { server: Server; transport: StreamableHTTPServerTransport }): Promise { + // First close the transport to ensure all SSE streams are closed + await transport.close(); + + // Close the server without waiting indefinitely + server.close(); +} + +/** + * Common test messages + */ +const TEST_MESSAGES = { + initialize: { + jsonrpc: '2.0', + method: 'initialize', + params: { + clientInfo: { name: 'test-client', version: '1.0' }, + protocolVersion: '2025-03-26', + capabilities: {} + }, + + id: 'init-1' + } as JSONRPCMessage, + + toolsList: { + jsonrpc: '2.0', + method: 'tools/list', + params: {}, + id: 'tools-1' + } as JSONRPCMessage +}; + +/** + * Helper to extract text from SSE response + * Note: Can only be called once per response stream. For multiple reads, + * get the reader manually and read multiple times. + */ +async function readSSEEvent(response: Response): Promise { + const reader = response.body?.getReader(); + const { value } = await reader!.read(); + return new TextDecoder().decode(value); +} + +/** + * Helper to send JSON-RPC request + */ +async function sendPostRequest( + baseUrl: URL, + message: JSONRPCMessage | JSONRPCMessage[], + sessionId?: string, + extraHeaders?: Record +): Promise { + const headers: Record = { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + ...extraHeaders + }; + + if (sessionId) { + headers['mcp-session-id'] = sessionId; + // After initialization, include the protocol version header + headers['mcp-protocol-version'] = '2025-03-26'; + } + + return fetch(baseUrl, { + method: 'POST', + headers, + body: JSON.stringify(message) + }); +} + +function expectErrorResponse(data: unknown, expectedCode: number, expectedMessagePattern: RegExp): void { + expect(data).toMatchObject({ + jsonrpc: '2.0', + error: expect.objectContaining({ + code: expectedCode, + message: expect.stringMatching(expectedMessagePattern) + }) + }); +} + +describe('StreamableHTTPServerTransport', () => { + let server: Server; + let mcpServer: McpServer; + let transport: StreamableHTTPServerTransport; + let baseUrl: URL; + let sessionId: string; + + beforeEach(async () => { + const result = await createTestServer(); + server = result.server; + transport = result.transport; + mcpServer = result.mcpServer; + baseUrl = result.baseUrl; + }); + + afterEach(async () => { + await stopTestServer({ server, transport }); + }); + + async function initializeServer(): Promise { + const response = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); + + expect(response.status).toBe(200); + const newSessionId = response.headers.get('mcp-session-id'); + expect(newSessionId).toBeDefined(); + return newSessionId as string; + } + + it('should initialize server and generate session ID', async () => { + const response = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); + + expect(response.status).toBe(200); + expect(response.headers.get('content-type')).toBe('text/event-stream'); + expect(response.headers.get('mcp-session-id')).toBeDefined(); + }); + + it('should reject second initialization request', async () => { + // First initialize + const sessionId = await initializeServer(); + expect(sessionId).toBeDefined(); + + // Try second initialize + const secondInitMessage = { + ...TEST_MESSAGES.initialize, + id: 'second-init' + }; + + const response = await sendPostRequest(baseUrl, secondInitMessage); + + expect(response.status).toBe(400); + const errorData = await response.json(); + expectErrorResponse(errorData, -32600, /Server already initialized/); + }); + + it('should reject batch initialize request', async () => { + const batchInitMessages: JSONRPCMessage[] = [ + TEST_MESSAGES.initialize, + { + jsonrpc: '2.0', + method: 'initialize', + params: { + clientInfo: { name: 'test-client-2', version: '1.0' }, + protocolVersion: '2025-03-26' + }, + id: 'init-2' + } + ]; + + const response = await sendPostRequest(baseUrl, batchInitMessages); + + expect(response.status).toBe(400); + const errorData = await response.json(); + expectErrorResponse(errorData, -32600, /Only one initialization request is allowed/); + }); + + it('should handle post requests via sse response correctly', async () => { + sessionId = await initializeServer(); + + const response = await sendPostRequest(baseUrl, TEST_MESSAGES.toolsList, sessionId); + + expect(response.status).toBe(200); + + // Read the SSE stream for the response + const text = await readSSEEvent(response); + + // Parse the SSE event + const eventLines = text.split('\n'); + const dataLine = eventLines.find(line => line.startsWith('data:')); + expect(dataLine).toBeDefined(); + + const eventData = JSON.parse(dataLine!.substring(5)); + expect(eventData).toMatchObject({ + jsonrpc: '2.0', + result: expect.objectContaining({ + tools: expect.arrayContaining([ + expect.objectContaining({ + name: 'greet', + description: 'A simple greeting tool' + }) + ]) + }), + id: 'tools-1' + }); + }); + + it('should call a tool and return the result', async () => { + sessionId = await initializeServer(); + + const toolCallMessage: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'tools/call', + params: { + name: 'greet', + arguments: { + name: 'Test User' + } + }, + id: 'call-1' + }; + + const response = await sendPostRequest(baseUrl, toolCallMessage, sessionId); + expect(response.status).toBe(200); + + const text = await readSSEEvent(response); + const eventLines = text.split('\n'); + const dataLine = eventLines.find(line => line.startsWith('data:')); + expect(dataLine).toBeDefined(); + + const eventData = JSON.parse(dataLine!.substring(5)); + expect(eventData).toMatchObject({ + jsonrpc: '2.0', + result: { + content: [ + { + type: 'text', + text: 'Hello, Test User!' + } + ] + }, + id: 'call-1' + }); + }); + + /*** + * Test: Tool With Request Info + */ + it('should pass request info to tool callback', async () => { + sessionId = await initializeServer(); + + mcpServer.tool( + 'test-request-info', + 'A simple test tool with request info', + { name: z.string().describe('Name to greet') }, + async ({ name }, { requestInfo }): Promise => { + return { + content: [ + { type: 'text', text: `Hello, ${name}!` }, + { type: 'text', text: `${JSON.stringify(requestInfo)}` } + ] + }; + } + ); + + const toolCallMessage: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'tools/call', + params: { + name: 'test-request-info', + arguments: { + name: 'Test User' + } + }, + id: 'call-1' + }; + + const response = await sendPostRequest(baseUrl, toolCallMessage, sessionId); + expect(response.status).toBe(200); + + const text = await readSSEEvent(response); + const eventLines = text.split('\n'); + const dataLine = eventLines.find(line => line.startsWith('data:')); + expect(dataLine).toBeDefined(); + + const eventData = JSON.parse(dataLine!.substring(5)); + + expect(eventData).toMatchObject({ + jsonrpc: '2.0', + result: { + content: [ + { type: 'text', text: 'Hello, Test User!' }, + { type: 'text', text: expect.any(String) } + ] + }, + id: 'call-1' + }); + + const requestInfo = JSON.parse(eventData.result.content[1].text); + expect(requestInfo).toMatchObject({ + headers: { + 'content-type': 'application/json', + accept: 'application/json, text/event-stream', + connection: 'keep-alive', + 'mcp-session-id': sessionId, + 'accept-language': '*', + 'user-agent': expect.any(String), + 'accept-encoding': expect.any(String), + 'content-length': expect.any(String) + } + }); + }); + + it('should reject requests without a valid session ID', async () => { + const response = await sendPostRequest(baseUrl, TEST_MESSAGES.toolsList); + + expect(response.status).toBe(400); + const errorData = await response.json(); + expectErrorResponse(errorData, -32000, /Bad Request/); + expect(errorData.id).toBeNull(); + }); + + it('should reject invalid session ID', async () => { + // First initialize to be in valid state + await initializeServer(); + + // Now try with invalid session ID + const response = await sendPostRequest(baseUrl, TEST_MESSAGES.toolsList, 'invalid-session-id'); + + expect(response.status).toBe(404); + const errorData = await response.json(); + expectErrorResponse(errorData, -32001, /Session not found/); + }); + + it('should establish standalone SSE stream and receive server-initiated messages', async () => { + // First initialize to get a session ID + sessionId = await initializeServer(); + + // Open a standalone SSE stream + const sseResponse = await fetch(baseUrl, { + method: 'GET', + headers: { + Accept: 'text/event-stream', + 'mcp-session-id': sessionId, + 'mcp-protocol-version': '2025-03-26' + } + }); + + expect(sseResponse.status).toBe(200); + expect(sseResponse.headers.get('content-type')).toBe('text/event-stream'); + + // Send a notification (server-initiated message) that should appear on SSE stream + const notification: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'notifications/message', + params: { level: 'info', data: 'Test notification' } + }; + + // Send the notification via transport + await transport.send(notification); + + // Read from the stream and verify we got the notification + const text = await readSSEEvent(sseResponse); + + const eventLines = text.split('\n'); + const dataLine = eventLines.find(line => line.startsWith('data:')); + expect(dataLine).toBeDefined(); + + const eventData = JSON.parse(dataLine!.substring(5)); + expect(eventData).toMatchObject({ + jsonrpc: '2.0', + method: 'notifications/message', + params: { level: 'info', data: 'Test notification' } + }); + }); + + it('should not close GET SSE stream after sending multiple server notifications', async () => { + sessionId = await initializeServer(); + + // Open a standalone SSE stream + const sseResponse = await fetch(baseUrl, { + method: 'GET', + headers: { + Accept: 'text/event-stream', + 'mcp-session-id': sessionId, + 'mcp-protocol-version': '2025-03-26' + } + }); + + expect(sseResponse.status).toBe(200); + const reader = sseResponse.body?.getReader(); + + // Send multiple notifications + const notification1: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'notifications/message', + params: { level: 'info', data: 'First notification' } + }; + + // Just send one and verify it comes through - then the stream should stay open + await transport.send(notification1); + + const { value, done } = await reader!.read(); + const text = new TextDecoder().decode(value); + expect(text).toContain('First notification'); + expect(done).toBe(false); // Stream should still be open + }); + + it('should reject second SSE stream for the same session', async () => { + sessionId = await initializeServer(); + + // Open first SSE stream + const firstStream = await fetch(baseUrl, { + method: 'GET', + headers: { + Accept: 'text/event-stream', + 'mcp-session-id': sessionId, + 'mcp-protocol-version': '2025-03-26' + } + }); + + expect(firstStream.status).toBe(200); + + // Try to open a second SSE stream with the same session ID + const secondStream = await fetch(baseUrl, { + method: 'GET', + headers: { + Accept: 'text/event-stream', + 'mcp-session-id': sessionId, + 'mcp-protocol-version': '2025-03-26' + } + }); + + // Should be rejected + expect(secondStream.status).toBe(409); // Conflict + const errorData = await secondStream.json(); + expectErrorResponse(errorData, -32000, /Only one SSE stream is allowed per session/); + }); + + it('should reject GET requests without Accept: text/event-stream header', async () => { + sessionId = await initializeServer(); + + // Try GET without proper Accept header + const response = await fetch(baseUrl, { + method: 'GET', + headers: { + Accept: 'application/json', + 'mcp-session-id': sessionId, + 'mcp-protocol-version': '2025-03-26' + } + }); + + expect(response.status).toBe(406); + const errorData = await response.json(); + expectErrorResponse(errorData, -32000, /Client must accept text\/event-stream/); + }); + + it('should reject POST requests without proper Accept header', async () => { + sessionId = await initializeServer(); + + // Try POST without Accept: text/event-stream + const response = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json', // Missing text/event-stream + 'mcp-session-id': sessionId + }, + body: JSON.stringify(TEST_MESSAGES.toolsList) + }); + + expect(response.status).toBe(406); + const errorData = await response.json(); + expectErrorResponse(errorData, -32000, /Client must accept both application\/json and text\/event-stream/); + }); + + it('should reject unsupported Content-Type', async () => { + sessionId = await initializeServer(); + + // Try POST with text/plain Content-Type + const response = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'text/plain', + Accept: 'application/json, text/event-stream', + 'mcp-session-id': sessionId + }, + body: 'This is plain text' + }); + + expect(response.status).toBe(415); + const errorData = await response.json(); + expectErrorResponse(errorData, -32000, /Content-Type must be application\/json/); + }); + + it('should handle JSON-RPC batch notification messages with 202 response', async () => { + sessionId = await initializeServer(); + + // Send batch of notifications (no IDs) + const batchNotifications: JSONRPCMessage[] = [ + { jsonrpc: '2.0', method: 'someNotification1', params: {} }, + { jsonrpc: '2.0', method: 'someNotification2', params: {} } + ]; + const response = await sendPostRequest(baseUrl, batchNotifications, sessionId); + + expect(response.status).toBe(202); + }); + + it('should handle batch request messages with SSE stream for responses', async () => { + sessionId = await initializeServer(); + + // Send batch of requests + const batchRequests: JSONRPCMessage[] = [ + { jsonrpc: '2.0', method: 'tools/list', params: {}, id: 'req-1' }, + { jsonrpc: '2.0', method: 'tools/call', params: { name: 'greet', arguments: { name: 'BatchUser' } }, id: 'req-2' } + ]; + const response = await sendPostRequest(baseUrl, batchRequests, sessionId); + + expect(response.status).toBe(200); + expect(response.headers.get('content-type')).toBe('text/event-stream'); + + const reader = response.body?.getReader(); + + // The responses may come in any order or together in one chunk + const { value } = await reader!.read(); + const text = new TextDecoder().decode(value); + + // Check that both responses were sent on the same stream + expect(text).toContain('"id":"req-1"'); + expect(text).toContain('"tools"'); // tools/list result + expect(text).toContain('"id":"req-2"'); + expect(text).toContain('Hello, BatchUser'); // tools/call result + }); + + it('should properly handle invalid JSON data', async () => { + sessionId = await initializeServer(); + + // Send invalid JSON + const response = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + 'mcp-session-id': sessionId + }, + body: 'This is not valid JSON' + }); + + expect(response.status).toBe(400); + const errorData = await response.json(); + expectErrorResponse(errorData, -32700, /Parse error/); + }); + + it('should return 400 error for invalid JSON-RPC messages', async () => { + sessionId = await initializeServer(); + + // Invalid JSON-RPC (missing required jsonrpc version) + const invalidMessage = { method: 'tools/list', params: {}, id: 1 }; // missing jsonrpc version + const response = await sendPostRequest(baseUrl, invalidMessage as JSONRPCMessage, sessionId); + + expect(response.status).toBe(400); + const errorData = await response.json(); + expect(errorData).toMatchObject({ + jsonrpc: '2.0', + error: expect.anything() + }); + }); + + it('should reject requests to uninitialized server', async () => { + // Create a new HTTP server and transport without initializing + const { server: uninitializedServer, transport: uninitializedTransport, baseUrl: uninitializedUrl } = await createTestServer(); + // Transport not used in test but needed for cleanup + + // No initialization, just send a request directly + const uninitializedMessage: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'tools/list', + params: {}, + id: 'uninitialized-test' + }; + + // Send a request to uninitialized server + const response = await sendPostRequest(uninitializedUrl, uninitializedMessage, 'any-session-id'); + + expect(response.status).toBe(400); + const errorData = await response.json(); + expectErrorResponse(errorData, -32000, /Server not initialized/); + + // Cleanup + await stopTestServer({ server: uninitializedServer, transport: uninitializedTransport }); + }); + + it('should send response messages to the connection that sent the request', async () => { + sessionId = await initializeServer(); + + const message1: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'tools/list', + params: {}, + id: 'req-1' + }; + + const message2: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'tools/call', + params: { + name: 'greet', + arguments: { name: 'Connection2' } + }, + id: 'req-2' + }; + + // Make two concurrent fetch connections for different requests + const req1 = sendPostRequest(baseUrl, message1, sessionId); + const req2 = sendPostRequest(baseUrl, message2, sessionId); + + // Get both responses + const [response1, response2] = await Promise.all([req1, req2]); + const reader1 = response1.body?.getReader(); + const reader2 = response2.body?.getReader(); + + // Read responses from each stream (requires each receives its specific response) + const { value: value1 } = await reader1!.read(); + const text1 = new TextDecoder().decode(value1); + expect(text1).toContain('"id":"req-1"'); + expect(text1).toContain('"tools"'); // tools/list result + + const { value: value2 } = await reader2!.read(); + const text2 = new TextDecoder().decode(value2); + expect(text2).toContain('"id":"req-2"'); + expect(text2).toContain('Hello, Connection2'); // tools/call result + }); + + it('should keep stream open after sending server notifications', async () => { + sessionId = await initializeServer(); + + // Open a standalone SSE stream + const sseResponse = await fetch(baseUrl, { + method: 'GET', + headers: { + Accept: 'text/event-stream', + 'mcp-session-id': sessionId, + 'mcp-protocol-version': '2025-03-26' + } + }); + + // Send several server-initiated notifications + await transport.send({ + jsonrpc: '2.0', + method: 'notifications/message', + params: { level: 'info', data: 'First notification' } + }); + + await transport.send({ + jsonrpc: '2.0', + method: 'notifications/message', + params: { level: 'info', data: 'Second notification' } + }); + + // Stream should still be open - it should not close after sending notifications + expect(sseResponse.bodyUsed).toBe(false); + }); + + // The current implementation will close the entire transport for DELETE + // Creating a temporary transport/server where we don't care if it gets closed + it('should properly handle DELETE requests and close session', async () => { + // Setup a temporary server for this test + const tempResult = await createTestServer(); + const tempServer = tempResult.server; + const tempUrl = tempResult.baseUrl; + + // Initialize to get a session ID + const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + const tempSessionId = initResponse.headers.get('mcp-session-id'); + + // Now DELETE the session + const deleteResponse = await fetch(tempUrl, { + method: 'DELETE', + headers: { + 'mcp-session-id': tempSessionId || '', + 'mcp-protocol-version': '2025-03-26' + } + }); + + expect(deleteResponse.status).toBe(200); + + // Clean up - don't wait indefinitely for server close + tempServer.close(); + }); + + it('should reject DELETE requests with invalid session ID', async () => { + // Initialize the server first to activate it + sessionId = await initializeServer(); + + // Try to delete with invalid session ID + const response = await fetch(baseUrl, { + method: 'DELETE', + headers: { + 'mcp-session-id': 'invalid-session-id', + 'mcp-protocol-version': '2025-03-26' + } + }); + + expect(response.status).toBe(404); + const errorData = await response.json(); + expectErrorResponse(errorData, -32001, /Session not found/); + }); + + describe('protocol version header validation', () => { + it('should accept requests with matching protocol version', async () => { + sessionId = await initializeServer(); + + // Send request with matching protocol version + const response = await sendPostRequest(baseUrl, TEST_MESSAGES.toolsList, sessionId); + + expect(response.status).toBe(200); + }); + + it('should accept requests without protocol version header', async () => { + sessionId = await initializeServer(); + + // Send request without protocol version header + const response = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + 'mcp-session-id': sessionId + // No mcp-protocol-version header + }, + body: JSON.stringify(TEST_MESSAGES.toolsList) + }); + + expect(response.status).toBe(200); + }); + + it('should reject requests with unsupported protocol version', async () => { + sessionId = await initializeServer(); + + // Send request with unsupported protocol version + const response = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + 'mcp-session-id': sessionId, + 'mcp-protocol-version': '1999-01-01' // Unsupported version + }, + body: JSON.stringify(TEST_MESSAGES.toolsList) + }); + + expect(response.status).toBe(400); + const errorData = await response.json(); + expectErrorResponse(errorData, -32000, /Bad Request: Unsupported protocol version \(supported versions: .+\)/); + }); + + it('should accept when protocol version differs from negotiated version', async () => { + sessionId = await initializeServer(); + + // Spy on console.warn to verify warning is logged + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}); + + // Send request with different but supported protocol version + const response = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + 'mcp-session-id': sessionId, + 'mcp-protocol-version': '2024-11-05' // Different but supported version + }, + body: JSON.stringify(TEST_MESSAGES.toolsList) + }); + + // Request should still succeed + expect(response.status).toBe(200); + + warnSpy.mockRestore(); + }); + + it('should handle protocol version validation for GET requests', async () => { + sessionId = await initializeServer(); + + // GET request with unsupported protocol version + const response = await fetch(baseUrl, { + method: 'GET', + headers: { + Accept: 'text/event-stream', + 'mcp-session-id': sessionId, + 'mcp-protocol-version': 'invalid-version' + } + }); + + expect(response.status).toBe(400); + const errorData = await response.json(); + expectErrorResponse(errorData, -32000, /Bad Request: Unsupported protocol version \(supported versions: .+\)/); + }); + + it('should handle protocol version validation for DELETE requests', async () => { + sessionId = await initializeServer(); + + // DELETE request with unsupported protocol version + const response = await fetch(baseUrl, { + method: 'DELETE', + headers: { + 'mcp-session-id': sessionId, + 'mcp-protocol-version': 'invalid-version' + } + }); + + expect(response.status).toBe(400); + const errorData = await response.json(); + expectErrorResponse(errorData, -32000, /Bad Request: Unsupported protocol version \(supported versions: .+\)/); + }); + }); +}); + +describe('StreamableHTTPServerTransport with AuthInfo', () => { + let server: Server; + let transport: StreamableHTTPServerTransport; + let baseUrl: URL; + let sessionId: string; + + beforeEach(async () => { + const result = await createTestAuthServer(); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + }); + + afterEach(async () => { + await stopTestServer({ server, transport }); + }); + + async function initializeServer(): Promise { + const response = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); + + expect(response.status).toBe(200); + const newSessionId = response.headers.get('mcp-session-id'); + expect(newSessionId).toBeDefined(); + return newSessionId as string; + } + + it('should call a tool with authInfo', async () => { + sessionId = await initializeServer(); + + const toolCallMessage: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'tools/call', + params: { + name: 'profile', + arguments: { active: true } + }, + id: 'call-1' + }; + + const response = await sendPostRequest(baseUrl, toolCallMessage, sessionId, { authorization: 'Bearer test-token' }); + expect(response.status).toBe(200); + + const text = await readSSEEvent(response); + const eventLines = text.split('\n'); + const dataLine = eventLines.find(line => line.startsWith('data:')); + expect(dataLine).toBeDefined(); + + const eventData = JSON.parse(dataLine!.substring(5)); + expect(eventData).toMatchObject({ + jsonrpc: '2.0', + result: { + content: [ + { + type: 'text', + text: 'Active profile from token: test-token!' + } + ] + }, + id: 'call-1' + }); + }); + + it('should calls tool without authInfo when it is optional', async () => { + sessionId = await initializeServer(); + + const toolCallMessage: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'tools/call', + params: { + name: 'profile', + arguments: { active: false } + }, + id: 'call-1' + }; + + const response = await sendPostRequest(baseUrl, toolCallMessage, sessionId); + expect(response.status).toBe(200); + + const text = await readSSEEvent(response); + const eventLines = text.split('\n'); + const dataLine = eventLines.find(line => line.startsWith('data:')); + expect(dataLine).toBeDefined(); + + const eventData = JSON.parse(dataLine!.substring(5)); + expect(eventData).toMatchObject({ + jsonrpc: '2.0', + result: { + content: [ + { + type: 'text', + text: 'Inactive profile from token: undefined!' + } + ] + }, + id: 'call-1' + }); + }); +}); + +// Test JSON Response Mode +describe('StreamableHTTPServerTransport with JSON Response Mode', () => { + let server: Server; + let transport: StreamableHTTPServerTransport; + let baseUrl: URL; + let sessionId: string; + + beforeEach(async () => { + const result = await createTestServer({ sessionIdGenerator: () => randomUUID(), enableJsonResponse: true }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + // Initialize and get session ID + const initResponse = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); + + sessionId = initResponse.headers.get('mcp-session-id') as string; + }); + + afterEach(async () => { + await stopTestServer({ server, transport }); + }); + + it('should return JSON response for a single request', async () => { + const toolsListMessage: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'tools/list', + params: {}, + id: 'json-req-1' + }; + + const response = await sendPostRequest(baseUrl, toolsListMessage, sessionId); + + expect(response.status).toBe(200); + expect(response.headers.get('content-type')).toBe('application/json'); + + const result = await response.json(); + expect(result).toMatchObject({ + jsonrpc: '2.0', + result: expect.objectContaining({ + tools: expect.arrayContaining([expect.objectContaining({ name: 'greet' })]) + }), + id: 'json-req-1' + }); + }); + + it('should return JSON response for batch requests', async () => { + const batchMessages: JSONRPCMessage[] = [ + { jsonrpc: '2.0', method: 'tools/list', params: {}, id: 'batch-1' }, + { jsonrpc: '2.0', method: 'tools/call', params: { name: 'greet', arguments: { name: 'JSON' } }, id: 'batch-2' } + ]; + + const response = await sendPostRequest(baseUrl, batchMessages, sessionId); + + expect(response.status).toBe(200); + expect(response.headers.get('content-type')).toBe('application/json'); + + const results = await response.json(); + expect(Array.isArray(results)).toBe(true); + expect(results).toHaveLength(2); + + // Batch responses can come in any order + const listResponse = results.find((r: { id?: string }) => r.id === 'batch-1'); + const callResponse = results.find((r: { id?: string }) => r.id === 'batch-2'); + + expect(listResponse).toEqual( + expect.objectContaining({ + jsonrpc: '2.0', + id: 'batch-1', + result: expect.objectContaining({ + tools: expect.arrayContaining([expect.objectContaining({ name: 'greet' })]) + }) + }) + ); + + expect(callResponse).toEqual( + expect.objectContaining({ + jsonrpc: '2.0', + id: 'batch-2', + result: expect.objectContaining({ + content: expect.arrayContaining([expect.objectContaining({ type: 'text', text: 'Hello, JSON!' })]) + }) + }) + ); + }); +}); + +// Test pre-parsed body handling +describe('StreamableHTTPServerTransport with pre-parsed body', () => { + let server: Server; + let transport: StreamableHTTPServerTransport; + let baseUrl: URL; + let sessionId: string; + let parsedBody: unknown = null; + + beforeEach(async () => { + const result = await createTestServer({ + customRequestHandler: async (req, res) => { + try { + if (parsedBody !== null) { + await transport.handleRequest(req, res, parsedBody); + parsedBody = null; // Reset after use + } else { + await transport.handleRequest(req, res); + } + } catch (error) { + console.error('Error handling request:', error); + if (!res.headersSent) res.writeHead(500).end(); + } + }, + sessionIdGenerator: () => randomUUID() + }); + + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + // Initialize and get session ID + const initResponse = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); + sessionId = initResponse.headers.get('mcp-session-id') as string; + }); + + afterEach(async () => { + await stopTestServer({ server, transport }); + }); + + it('should accept pre-parsed request body', async () => { + // Set up the pre-parsed body + parsedBody = { + jsonrpc: '2.0', + method: 'tools/list', + params: {}, + id: 'preparsed-1' + }; + + // Send an empty body since we'll use pre-parsed body + const response = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + 'mcp-session-id': sessionId + }, + // Empty body - we're testing pre-parsed body + body: '' + }); + + expect(response.status).toBe(200); + expect(response.headers.get('content-type')).toBe('text/event-stream'); + + const reader = response.body?.getReader(); + const { value } = await reader!.read(); + const text = new TextDecoder().decode(value); + + // Verify the response used the pre-parsed body + expect(text).toContain('"id":"preparsed-1"'); + expect(text).toContain('"tools"'); + }); + + it('should handle pre-parsed batch messages', async () => { + parsedBody = [ + { jsonrpc: '2.0', method: 'tools/list', params: {}, id: 'batch-1' }, + { jsonrpc: '2.0', method: 'tools/call', params: { name: 'greet', arguments: { name: 'PreParsed' } }, id: 'batch-2' } + ]; + + const response = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + 'mcp-session-id': sessionId + }, + body: '' // Empty as we're using pre-parsed + }); + + expect(response.status).toBe(200); + + const reader = response.body?.getReader(); + const { value } = await reader!.read(); + const text = new TextDecoder().decode(value); + + expect(text).toContain('"id":"batch-1"'); + expect(text).toContain('"tools"'); + }); + + it('should prefer pre-parsed body over request body', async () => { + // Set pre-parsed to tools/list + parsedBody = { + jsonrpc: '2.0', + method: 'tools/list', + params: {}, + id: 'preparsed-wins' + }; + + // Send actual body with tools/call - should be ignored + const response = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + 'mcp-session-id': sessionId + }, + body: JSON.stringify({ + jsonrpc: '2.0', + method: 'tools/call', + params: { name: 'greet', arguments: { name: 'Ignored' } }, + id: 'ignored-id' + }) + }); + + expect(response.status).toBe(200); + + const reader = response.body?.getReader(); + const { value } = await reader!.read(); + const text = new TextDecoder().decode(value); + + // Should have processed the pre-parsed body + expect(text).toContain('"id":"preparsed-wins"'); + expect(text).toContain('"tools"'); + expect(text).not.toContain('"ignored-id"'); + }); +}); + +// Test resumability support +describe('StreamableHTTPServerTransport with resumability', () => { + let server: Server; + let transport: StreamableHTTPServerTransport; + let baseUrl: URL; + let sessionId: string; + let mcpServer: McpServer; + const storedEvents: Map = new Map(); + + // Simple implementation of EventStore + const eventStore: EventStore = { + async storeEvent(streamId: string, message: JSONRPCMessage): Promise { + const eventId = `${streamId}_${randomUUID()}`; + storedEvents.set(eventId, { eventId, message }); + return eventId; + }, + + async replayEventsAfter( + lastEventId: EventId, + { + send + }: { + send: (eventId: EventId, message: JSONRPCMessage) => Promise; + } + ): Promise { + const streamId = lastEventId.split('_')[0]; + // Extract stream ID from the event ID + // For test simplicity, just return all events with matching streamId that aren't the lastEventId + for (const [eventId, { message }] of storedEvents.entries()) { + if (eventId.startsWith(streamId) && eventId !== lastEventId) { + await send(eventId, message); + } + } + return streamId; + } + }; + + beforeEach(async () => { + storedEvents.clear(); + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + eventStore + }); + + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + mcpServer = result.mcpServer; + + // Verify resumability is enabled on the transport + expect(transport['_eventStore']).toBeDefined(); + + // Initialize the server + const initResponse = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); + sessionId = initResponse.headers.get('mcp-session-id') as string; + expect(sessionId).toBeDefined(); + }); + + afterEach(async () => { + await stopTestServer({ server, transport }); + storedEvents.clear(); + }); + + it('should store and include event IDs in server SSE messages', async () => { + // Open a standalone SSE stream + const sseResponse = await fetch(baseUrl, { + method: 'GET', + headers: { + Accept: 'text/event-stream', + 'mcp-session-id': sessionId, + 'mcp-protocol-version': '2025-03-26' + } + }); + + expect(sseResponse.status).toBe(200); + expect(sseResponse.headers.get('content-type')).toBe('text/event-stream'); + + // Send a notification that should be stored with an event ID + const notification: JSONRPCMessage = { + jsonrpc: '2.0', + method: 'notifications/message', + params: { level: 'info', data: 'Test notification with event ID' } + }; + + // Send the notification via transport + await transport.send(notification); + + // Read from the stream and verify we got the notification with an event ID + const reader = sseResponse.body?.getReader(); + const { value } = await reader!.read(); + const text = new TextDecoder().decode(value); + + // The response should contain an event ID + expect(text).toContain('id: '); + expect(text).toContain('"method":"notifications/message"'); + + // Extract the event ID + const idMatch = text.match(/id: ([^\n]+)/); + expect(idMatch).toBeTruthy(); + + // Verify the event was stored + const eventId = idMatch![1]; + expect(storedEvents.has(eventId)).toBe(true); + const storedEvent = storedEvents.get(eventId); + expect(eventId.startsWith('_GET_stream')).toBe(true); + expect(storedEvent?.message).toMatchObject(notification); + }); + + it('should store and replay MCP server tool notifications', async () => { + // Establish a standalone SSE stream + const sseResponse = await fetch(baseUrl, { + method: 'GET', + headers: { + Accept: 'text/event-stream', + 'mcp-session-id': sessionId, + 'mcp-protocol-version': '2025-03-26' + } + }); + expect(sseResponse.status).toBe(200); + + // Send a server notification through the MCP server + await mcpServer.server.sendLoggingMessage({ level: 'info', data: 'First notification from MCP server' }); + + // Read the notification from the SSE stream + const reader = sseResponse.body?.getReader(); + const { value } = await reader!.read(); + const text = new TextDecoder().decode(value); + + // Verify the notification was sent with an event ID + expect(text).toContain('id: '); + expect(text).toContain('First notification from MCP server'); + + // Extract the event ID + const idMatch = text.match(/id: ([^\n]+)/); + expect(idMatch).toBeTruthy(); + const firstEventId = idMatch![1]; + + // Send a second notification + await mcpServer.server.sendLoggingMessage({ level: 'info', data: 'Second notification from MCP server' }); + + // Close the first SSE stream to simulate a disconnect + await reader!.cancel(); + + // Reconnect with the Last-Event-ID to get missed messages + const reconnectResponse = await fetch(baseUrl, { + method: 'GET', + headers: { + Accept: 'text/event-stream', + 'mcp-session-id': sessionId, + 'mcp-protocol-version': '2025-03-26', + 'last-event-id': firstEventId + } + }); + + expect(reconnectResponse.status).toBe(200); + + // Read the replayed notification + const reconnectReader = reconnectResponse.body?.getReader(); + const reconnectData = await reconnectReader!.read(); + const reconnectText = new TextDecoder().decode(reconnectData.value); + + // Verify we received the second notification that was sent after our stored eventId + expect(reconnectText).toContain('Second notification from MCP server'); + expect(reconnectText).toContain('id: '); + }); +}); + +// Test stateless mode +describe('StreamableHTTPServerTransport in stateless mode', () => { + let server: Server; + let transport: StreamableHTTPServerTransport; + let baseUrl: URL; + + beforeEach(async () => { + const result = await createTestServer({ sessionIdGenerator: undefined }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + }); + + afterEach(async () => { + await stopTestServer({ server, transport }); + }); + + it('should operate without session ID validation', async () => { + // Initialize the server first + const initResponse = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); + + expect(initResponse.status).toBe(200); + // Should NOT have session ID header in stateless mode + expect(initResponse.headers.get('mcp-session-id')).toBeNull(); + + // Try request without session ID - should work in stateless mode + const toolsResponse = await sendPostRequest(baseUrl, TEST_MESSAGES.toolsList); + + expect(toolsResponse.status).toBe(200); + }); + + it('should handle POST requests with various session IDs in stateless mode', async () => { + await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); + + // Try with a random session ID - should be accepted + const response1 = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + 'mcp-session-id': 'random-id-1' + }, + body: JSON.stringify({ jsonrpc: '2.0', method: 'tools/list', params: {}, id: 't1' }) + }); + expect(response1.status).toBe(200); + + // Try with another random session ID - should also be accepted + const response2 = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + 'mcp-session-id': 'different-id-2' + }, + body: JSON.stringify({ jsonrpc: '2.0', method: 'tools/list', params: {}, id: 't2' }) + }); + expect(response2.status).toBe(200); + }); + + it('should reject second SSE stream even in stateless mode', async () => { + // Despite no session ID requirement, the transport still only allows + // one standalone SSE stream at a time + + // Initialize the server first + await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); + + // Open first SSE stream + const stream1 = await fetch(baseUrl, { + method: 'GET', + headers: { + Accept: 'text/event-stream', + 'mcp-protocol-version': '2025-03-26' + } + }); + expect(stream1.status).toBe(200); + + // Open second SSE stream - should still be rejected, stateless mode still only allows one + const stream2 = await fetch(baseUrl, { + method: 'GET', + headers: { + Accept: 'text/event-stream', + 'mcp-protocol-version': '2025-03-26' + } + }); + expect(stream2.status).toBe(409); // Conflict - only one stream allowed + }); +}); + +// Test onsessionclosed callback +describe('StreamableHTTPServerTransport onsessionclosed callback', () => { + it('should call onsessionclosed callback when session is closed via DELETE', async () => { + const mockCallback = vi.fn(); + + // Create server with onsessionclosed callback + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessionclosed: mockCallback + }); + + const tempServer = result.server; + const tempUrl = result.baseUrl; + + // Initialize to get a session ID + const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + const tempSessionId = initResponse.headers.get('mcp-session-id'); + expect(tempSessionId).toBeDefined(); + + // DELETE the session + const deleteResponse = await fetch(tempUrl, { + method: 'DELETE', + headers: { + 'mcp-session-id': tempSessionId || '', + 'mcp-protocol-version': '2025-03-26' + } + }); + + expect(deleteResponse.status).toBe(200); + expect(mockCallback).toHaveBeenCalledWith(tempSessionId); + expect(mockCallback).toHaveBeenCalledTimes(1); + + // Clean up + tempServer.close(); + }); + + it('should not call onsessionclosed callback when not provided', async () => { + // Create server without onsessionclosed callback + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID() + }); + + const tempServer = result.server; + const tempUrl = result.baseUrl; + + // Initialize to get a session ID + const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + const tempSessionId = initResponse.headers.get('mcp-session-id'); + + // DELETE the session - should not throw error + const deleteResponse = await fetch(tempUrl, { + method: 'DELETE', + headers: { + 'mcp-session-id': tempSessionId || '', + 'mcp-protocol-version': '2025-03-26' + } + }); + + expect(deleteResponse.status).toBe(200); + + // Clean up + tempServer.close(); + }); + + it('should not call onsessionclosed callback for invalid session DELETE', async () => { + const mockCallback = vi.fn(); + + // Create server with onsessionclosed callback + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessionclosed: mockCallback + }); + + const tempServer = result.server; + const tempUrl = result.baseUrl; + + // Initialize to get a valid session + await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + + // Try to DELETE with invalid session ID + const deleteResponse = await fetch(tempUrl, { + method: 'DELETE', + headers: { + 'mcp-session-id': 'invalid-session-id', + 'mcp-protocol-version': '2025-03-26' + } + }); + + expect(deleteResponse.status).toBe(404); + expect(mockCallback).not.toHaveBeenCalled(); + + // Clean up + tempServer.close(); + }); + + it('should call onsessionclosed callback with correct session ID when multiple sessions exist', async () => { + const mockCallback = vi.fn(); + + // Create first server + const result1 = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessionclosed: mockCallback + }); + + const server1 = result1.server; + const url1 = result1.baseUrl; + + // Create second server + const result2 = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessionclosed: mockCallback + }); + + const server2 = result2.server; + const url2 = result2.baseUrl; + + // Initialize both servers + const initResponse1 = await sendPostRequest(url1, TEST_MESSAGES.initialize); + const sessionId1 = initResponse1.headers.get('mcp-session-id'); + + const initResponse2 = await sendPostRequest(url2, TEST_MESSAGES.initialize); + const sessionId2 = initResponse2.headers.get('mcp-session-id'); + + expect(sessionId1).toBeDefined(); + expect(sessionId2).toBeDefined(); + expect(sessionId1).not.toBe(sessionId2); + + // DELETE first session + const deleteResponse1 = await fetch(url1, { + method: 'DELETE', + headers: { + 'mcp-session-id': sessionId1 || '', + 'mcp-protocol-version': '2025-03-26' + } + }); + + expect(deleteResponse1.status).toBe(200); + expect(mockCallback).toHaveBeenCalledWith(sessionId1); + expect(mockCallback).toHaveBeenCalledTimes(1); + + // DELETE second session + const deleteResponse2 = await fetch(url2, { + method: 'DELETE', + headers: { + 'mcp-session-id': sessionId2 || '', + 'mcp-protocol-version': '2025-03-26' + } + }); + + expect(deleteResponse2.status).toBe(200); + expect(mockCallback).toHaveBeenCalledWith(sessionId2); + expect(mockCallback).toHaveBeenCalledTimes(2); + + // Clean up + server1.close(); + server2.close(); + }); +}); + +// Test async callbacks for onsessioninitialized and onsessionclosed +describe('StreamableHTTPServerTransport async callbacks', () => { + it('should support async onsessioninitialized callback', async () => { + const initializationOrder: string[] = []; + + // Create server with async onsessioninitialized callback + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessioninitialized: async (sessionId: string) => { + initializationOrder.push('async-start'); + // Simulate async operation + await new Promise(resolve => setTimeout(resolve, 10)); + initializationOrder.push('async-end'); + initializationOrder.push(sessionId); + } + }); + + const tempServer = result.server; + const tempUrl = result.baseUrl; + + // Initialize to trigger the callback + const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + const tempSessionId = initResponse.headers.get('mcp-session-id'); + + // Give time for async callback to complete + await new Promise(resolve => setTimeout(resolve, 50)); + + expect(initializationOrder).toEqual(['async-start', 'async-end', tempSessionId]); + + // Clean up + tempServer.close(); + }); + + it('should support sync onsessioninitialized callback (backwards compatibility)', async () => { + const capturedSessionId: string[] = []; + + // Create server with sync onsessioninitialized callback + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessioninitialized: (sessionId: string) => { + capturedSessionId.push(sessionId); + } + }); + + const tempServer = result.server; + const tempUrl = result.baseUrl; + + // Initialize to trigger the callback + const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + const tempSessionId = initResponse.headers.get('mcp-session-id'); + + expect(capturedSessionId).toEqual([tempSessionId]); + + // Clean up + tempServer.close(); + }); + + it('should support async onsessionclosed callback', async () => { + const closureOrder: string[] = []; + + // Create server with async onsessionclosed callback + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessionclosed: async (sessionId: string) => { + closureOrder.push('async-close-start'); + // Simulate async operation + await new Promise(resolve => setTimeout(resolve, 10)); + closureOrder.push('async-close-end'); + closureOrder.push(sessionId); + } + }); + + const tempServer = result.server; + const tempUrl = result.baseUrl; + + // Initialize to get a session ID + const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + const tempSessionId = initResponse.headers.get('mcp-session-id'); + expect(tempSessionId).toBeDefined(); + + // DELETE the session + const deleteResponse = await fetch(tempUrl, { + method: 'DELETE', + headers: { + 'mcp-session-id': tempSessionId || '', + 'mcp-protocol-version': '2025-03-26' + } + }); + + expect(deleteResponse.status).toBe(200); + + // Give time for async callback to complete + await new Promise(resolve => setTimeout(resolve, 50)); + + expect(closureOrder).toEqual(['async-close-start', 'async-close-end', tempSessionId]); + + // Clean up + tempServer.close(); + }); + + it('should propagate errors from async onsessioninitialized callback', async () => { + const consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {}); + + // Create server with async onsessioninitialized callback that throws + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessioninitialized: async (_sessionId: string) => { + throw new Error('Async initialization error'); + } + }); + + const tempServer = result.server; + const tempUrl = result.baseUrl; + + // Initialize should fail when callback throws + const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + expect(initResponse.status).toBe(400); + + // Clean up + consoleErrorSpy.mockRestore(); + tempServer.close(); + }); + + it('should propagate errors from async onsessionclosed callback', async () => { + const consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => {}); + + // Create server with async onsessionclosed callback that throws + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessionclosed: async (_sessionId: string) => { + throw new Error('Async closure error'); + } + }); + + const tempServer = result.server; + const tempUrl = result.baseUrl; + + // Initialize to get a session ID + const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + const tempSessionId = initResponse.headers.get('mcp-session-id'); + + // DELETE should fail when callback throws + const deleteResponse = await fetch(tempUrl, { + method: 'DELETE', + headers: { + 'mcp-session-id': tempSessionId || '', + 'mcp-protocol-version': '2025-03-26' + } + }); + + expect(deleteResponse.status).toBe(500); + + // Clean up + consoleErrorSpy.mockRestore(); + tempServer.close(); + }); + + it('should handle both async callbacks together', async () => { + const events: string[] = []; + + // Create server with both async callbacks + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + onsessioninitialized: async (sessionId: string) => { + await new Promise(resolve => setTimeout(resolve, 5)); + events.push(`initialized:${sessionId}`); + }, + onsessionclosed: async (sessionId: string) => { + await new Promise(resolve => setTimeout(resolve, 5)); + events.push(`closed:${sessionId}`); + } + }); + + const tempServer = result.server; + const tempUrl = result.baseUrl; + + // Initialize to trigger first callback + const initResponse = await sendPostRequest(tempUrl, TEST_MESSAGES.initialize); + const tempSessionId = initResponse.headers.get('mcp-session-id'); + + // Wait for async callback + await new Promise(resolve => setTimeout(resolve, 20)); + + expect(events).toContain(`initialized:${tempSessionId}`); + + // DELETE to trigger second callback + const deleteResponse = await fetch(tempUrl, { + method: 'DELETE', + headers: { + 'mcp-session-id': tempSessionId || '', + 'mcp-protocol-version': '2025-03-26' + } + }); + + expect(deleteResponse.status).toBe(200); + + // Wait for async callback + await new Promise(resolve => setTimeout(resolve, 20)); + + expect(events).toContain(`closed:${tempSessionId}`); + expect(events).toHaveLength(2); + + // Clean up + tempServer.close(); + }); +}); + +// Test DNS rebinding protection +describe('StreamableHTTPServerTransport DNS rebinding protection', () => { + let server: Server; + let transport: StreamableHTTPServerTransport; + let baseUrl: URL; + + afterEach(async () => { + if (server && transport) { + await stopTestServer({ server, transport }); + } + }); + + describe('Host header validation', () => { + it('should accept requests with allowed host headers', async () => { + const result = await createTestServerWithDnsProtection({ + sessionIdGenerator: undefined, + allowedHosts: ['localhost'], + enableDnsRebindingProtection: true + }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + // Note: fetch() automatically sets Host header to match the URL + // Since we're connecting to localhost:3001 and that's in allowedHosts, this should work + const response = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream' + }, + body: JSON.stringify(TEST_MESSAGES.initialize) + }); + + expect(response.status).toBe(200); + }); + + it('should reject requests with disallowed host headers', async () => { + // Test DNS rebinding protection by creating a server that only allows example.com + // but we're connecting via localhost, so it should be rejected + const result = await createTestServerWithDnsProtection({ + sessionIdGenerator: undefined, + allowedHosts: ['example.com:3001'], + enableDnsRebindingProtection: true + }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + const response = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream' + }, + body: JSON.stringify(TEST_MESSAGES.initialize) + }); + + expect(response.status).toBe(403); + const body = await response.json(); + expect(body.error.message).toContain('Invalid Host header:'); + }); + + it('should reject GET requests with disallowed host headers', async () => { + const result = await createTestServerWithDnsProtection({ + sessionIdGenerator: undefined, + allowedHosts: ['example.com:3001'], + enableDnsRebindingProtection: true + }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + const response = await fetch(baseUrl, { + method: 'GET', + headers: { + Accept: 'text/event-stream' + } + }); + + expect(response.status).toBe(403); + }); + }); + + describe('Origin header validation', () => { + it('should accept requests with allowed origin headers', async () => { + const result = await createTestServerWithDnsProtection({ + sessionIdGenerator: undefined, + allowedOrigins: ['http://localhost:3000', 'https://example.com'], + enableDnsRebindingProtection: true + }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + const response = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + Origin: 'http://localhost:3000' + }, + body: JSON.stringify(TEST_MESSAGES.initialize) + }); + + expect(response.status).toBe(200); + }); + + it('should reject requests with disallowed origin headers', async () => { + const result = await createTestServerWithDnsProtection({ + sessionIdGenerator: undefined, + allowedOrigins: ['http://localhost:3000'], + enableDnsRebindingProtection: true + }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + const response = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + Origin: 'http://evil.com' + }, + body: JSON.stringify(TEST_MESSAGES.initialize) + }); + + expect(response.status).toBe(403); + const body = await response.json(); + expect(body.error.message).toBe('Invalid Origin header: http://evil.com'); + }); + }); + + describe('enableDnsRebindingProtection option', () => { + it('should skip all validations when enableDnsRebindingProtection is false', async () => { + const result = await createTestServerWithDnsProtection({ + sessionIdGenerator: undefined, + allowedHosts: ['localhost'], + allowedOrigins: ['http://localhost:3000'], + enableDnsRebindingProtection: false + }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + const response = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + Host: 'evil.com', + Origin: 'http://evil.com' + }, + body: JSON.stringify(TEST_MESSAGES.initialize) + }); + + // Should pass even with invalid headers because protection is disabled + expect(response.status).toBe(200); + }); + }); + + describe('Combined validations', () => { + it('should validate both host and origin when both are configured', async () => { + const result = await createTestServerWithDnsProtection({ + sessionIdGenerator: undefined, + allowedHosts: ['localhost'], + allowedOrigins: ['http://localhost:3001'], + enableDnsRebindingProtection: true + }); + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + + // Test with invalid origin (host will be automatically correct via fetch) + const response1 = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + Origin: 'http://evil.com' + }, + body: JSON.stringify(TEST_MESSAGES.initialize) + }); + + expect(response1.status).toBe(403); + const body1 = await response1.json(); + expect(body1.error.message).toBe('Invalid Origin header: http://evil.com'); + + // Test with valid origin + const response2 = await fetch(baseUrl, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json, text/event-stream', + Origin: 'http://localhost:3001' + }, + body: JSON.stringify(TEST_MESSAGES.initialize) + }); + + expect(response2.status).toBe(200); + }); + }); +}); + +/** + * Helper to create test server with DNS rebinding protection options + */ +async function createTestServerWithDnsProtection(config: { + sessionIdGenerator: (() => string) | undefined; + allowedHosts?: string[]; + allowedOrigins?: string[]; + enableDnsRebindingProtection?: boolean; +}): Promise<{ + server: Server; + transport: StreamableHTTPServerTransport; + mcpServer: McpServer; + baseUrl: URL; +}> { + const mcpServer = new McpServer({ name: 'test-server', version: '1.0.0' }, { capabilities: { logging: {} } }); + + const port = await getFreePort(); + + if (config.allowedHosts) { + config.allowedHosts = config.allowedHosts.map(host => { + if (host.includes(':')) { + return host; + } + return `localhost:${port}`; + }); + } + + const transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: config.sessionIdGenerator, + allowedHosts: config.allowedHosts, + allowedOrigins: config.allowedOrigins, + enableDnsRebindingProtection: config.enableDnsRebindingProtection + }); + + await mcpServer.connect(transport); + + const httpServer = createServer(async (req, res) => { + if (req.method === 'POST') { + let body = ''; + req.on('data', chunk => (body += chunk)); + req.on('end', async () => { + const parsedBody = JSON.parse(body); + await transport.handleRequest(req as IncomingMessage & { auth?: AuthInfo }, res, parsedBody); + }); + } else { + await transport.handleRequest(req as IncomingMessage & { auth?: AuthInfo }, res); + } + }); + + await new Promise(resolve => { + httpServer.listen(port, () => resolve()); + }); + + const serverUrl = new URL(`http://localhost:${port}/`); + + return { + server: httpServer, + transport, + mcpServer, + baseUrl: serverUrl + }; +} diff --git a/src/server/v3/title.v3.test.ts b/src/server/v3/title.v3.test.ts new file mode 100644 index 000000000..2d99d5316 --- /dev/null +++ b/src/server/v3/title.v3.test.ts @@ -0,0 +1,224 @@ +import { Server } from '../index.js'; +import { Client } from '../../client/index.js'; +import { InMemoryTransport } from '../../inMemory.js'; +import * as z from 'zod/v3'; +import { McpServer, ResourceTemplate } from '../mcp.js'; + +describe('Title field backwards compatibility', () => { + it('should work with tools that have title', async () => { + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const server = new McpServer({ name: 'test-server', version: '1.0.0' }, { capabilities: {} }); + + // Register tool with title + server.registerTool( + 'test-tool', + { + title: 'Test Tool Display Name', + description: 'A test tool', + inputSchema: { + value: z.string() + } + }, + async () => ({ content: [{ type: 'text', text: 'result' }] }) + ); + + const client = new Client({ name: 'test-client', version: '1.0.0' }); + + await server.server.connect(serverTransport); + await client.connect(clientTransport); + + const tools = await client.listTools(); + expect(tools.tools).toHaveLength(1); + expect(tools.tools[0].name).toBe('test-tool'); + expect(tools.tools[0].title).toBe('Test Tool Display Name'); + expect(tools.tools[0].description).toBe('A test tool'); + }); + + it('should work with tools without title', async () => { + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const server = new McpServer({ name: 'test-server', version: '1.0.0' }, { capabilities: {} }); + + // Register tool without title + server.tool('test-tool', 'A test tool', { value: z.string() }, async () => ({ content: [{ type: 'text', text: 'result' }] })); + + const client = new Client({ name: 'test-client', version: '1.0.0' }); + + await server.server.connect(serverTransport); + await client.connect(clientTransport); + + const tools = await client.listTools(); + expect(tools.tools).toHaveLength(1); + expect(tools.tools[0].name).toBe('test-tool'); + expect(tools.tools[0].title).toBeUndefined(); + expect(tools.tools[0].description).toBe('A test tool'); + }); + + it('should work with prompts that have title using update', async () => { + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const server = new McpServer({ name: 'test-server', version: '1.0.0' }, { capabilities: {} }); + + // Register prompt with title by updating after creation + const prompt = server.prompt('test-prompt', 'A test prompt', async () => ({ + messages: [{ role: 'user', content: { type: 'text', text: 'test' } }] + })); + prompt.update({ title: 'Test Prompt Display Name' }); + + const client = new Client({ name: 'test-client', version: '1.0.0' }); + + await server.server.connect(serverTransport); + await client.connect(clientTransport); + + const prompts = await client.listPrompts(); + expect(prompts.prompts).toHaveLength(1); + expect(prompts.prompts[0].name).toBe('test-prompt'); + expect(prompts.prompts[0].title).toBe('Test Prompt Display Name'); + expect(prompts.prompts[0].description).toBe('A test prompt'); + }); + + it('should work with prompts using registerPrompt', async () => { + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const server = new McpServer({ name: 'test-server', version: '1.0.0' }, { capabilities: {} }); + + // Register prompt with title using registerPrompt + server.registerPrompt( + 'test-prompt', + { + title: 'Test Prompt Display Name', + description: 'A test prompt', + argsSchema: { input: z.string() } + }, + async ({ input }) => ({ + messages: [ + { + role: 'user', + content: { type: 'text', text: `test: ${input}` } + } + ] + }) + ); + + const client = new Client({ name: 'test-client', version: '1.0.0' }); + + await server.server.connect(serverTransport); + await client.connect(clientTransport); + + const prompts = await client.listPrompts(); + expect(prompts.prompts).toHaveLength(1); + expect(prompts.prompts[0].name).toBe('test-prompt'); + expect(prompts.prompts[0].title).toBe('Test Prompt Display Name'); + expect(prompts.prompts[0].description).toBe('A test prompt'); + expect(prompts.prompts[0].arguments).toHaveLength(1); + }); + + it('should work with resources using registerResource', async () => { + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const server = new McpServer({ name: 'test-server', version: '1.0.0' }, { capabilities: {} }); + + // Register resource with title using registerResource + server.registerResource( + 'test-resource', + 'https://example.com/test', + { + title: 'Test Resource Display Name', + description: 'A test resource', + mimeType: 'text/plain' + }, + async () => ({ + contents: [ + { + uri: 'https://example.com/test', + text: 'test content' + } + ] + }) + ); + + const client = new Client({ name: 'test-client', version: '1.0.0' }); + + await server.server.connect(serverTransport); + await client.connect(clientTransport); + + const resources = await client.listResources(); + expect(resources.resources).toHaveLength(1); + expect(resources.resources[0].name).toBe('test-resource'); + expect(resources.resources[0].title).toBe('Test Resource Display Name'); + expect(resources.resources[0].description).toBe('A test resource'); + expect(resources.resources[0].mimeType).toBe('text/plain'); + }); + + it('should work with dynamic resources using registerResource', async () => { + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const server = new McpServer({ name: 'test-server', version: '1.0.0' }, { capabilities: {} }); + + // Register dynamic resource with title using registerResource + server.registerResource( + 'user-profile', + new ResourceTemplate('users://{userId}/profile', { list: undefined }), + { + title: 'User Profile', + description: 'User profile information' + }, + async (uri, { userId }, _extra) => ({ + contents: [ + { + uri: uri.href, + text: `Profile data for user ${userId}` + } + ] + }) + ); + + const client = new Client({ name: 'test-client', version: '1.0.0' }); + + await server.server.connect(serverTransport); + await client.connect(clientTransport); + + const resourceTemplates = await client.listResourceTemplates(); + expect(resourceTemplates.resourceTemplates).toHaveLength(1); + expect(resourceTemplates.resourceTemplates[0].name).toBe('user-profile'); + expect(resourceTemplates.resourceTemplates[0].title).toBe('User Profile'); + expect(resourceTemplates.resourceTemplates[0].description).toBe('User profile information'); + expect(resourceTemplates.resourceTemplates[0].uriTemplate).toBe('users://{userId}/profile'); + + // Test reading the resource + const readResult = await client.readResource({ uri: 'users://123/profile' }); + expect(readResult.contents).toHaveLength(1); + expect(readResult.contents).toEqual( + expect.arrayContaining([ + { + text: expect.stringContaining('Profile data for user 123'), + uri: 'users://123/profile' + } + ]) + ); + }); + + it('should support serverInfo with title', async () => { + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + const server = new Server( + { + name: 'test-server', + version: '1.0.0', + title: 'Test Server Display Name' + }, + { capabilities: {} } + ); + + const client = new Client({ name: 'test-client', version: '1.0.0' }); + + await server.connect(serverTransport); + await client.connect(clientTransport); + + const serverInfo = client.getServerVersion(); + expect(serverInfo?.name).toBe('test-server'); + expect(serverInfo?.version).toBe('1.0.0'); + expect(serverInfo?.title).toBe('Test Server Display Name'); + }); +}); diff --git a/src/server/zod-compat.ts b/src/server/zod-compat.ts new file mode 100644 index 000000000..956aca821 --- /dev/null +++ b/src/server/zod-compat.ts @@ -0,0 +1,280 @@ +// zod-compat.ts +// ---------------------------------------------------- +// Unified types + helpers to accept Zod v3 and v4 (Mini) +// ---------------------------------------------------- + +import type * as z3 from 'zod/v3'; +import type * as z4 from 'zod/v4/core'; + +import * as z3rt from 'zod/v3'; +import * as z4mini from 'zod/v4-mini'; + +// --- Unified schema types --- +export type AnySchema = z3.ZodTypeAny | z4.$ZodType; +export type AnyObjectSchema = z3.AnyZodObject | z4.$ZodObject | AnySchema; +export type ZodRawShapeCompat = Record; + +// --- Internal property access helpers --- +// These types help us safely access internal properties that differ between v3 and v4 +export interface ZodV3Internal { + _def?: { + typeName?: string; + value?: unknown; + values?: unknown[]; + shape?: Record | (() => Record); + description?: string; + }; + shape?: Record | (() => Record); + value?: unknown; +} + +export interface ZodV4Internal { + _zod?: { + def?: { + typeName?: string; + value?: unknown; + values?: unknown[]; + shape?: Record | (() => Record); + description?: string; + }; + }; + value?: unknown; +} + +// --- Type inference helpers --- +export type SchemaOutput = S extends z3.ZodTypeAny ? z3.infer : S extends z4.$ZodType ? z4.output : never; + +export type SchemaInput = S extends z3.ZodTypeAny ? z3.input : S extends z4.$ZodType ? z4.input : never; + +/** + * Infers the output type from a ZodRawShapeCompat (raw shape object). + * Maps over each key in the shape and infers the output type from each schema. + */ +export type ShapeOutput = { + [K in keyof Shape]: SchemaOutput; +}; + +// --- Runtime detection --- +export function isZ4Schema(s: AnySchema): s is z4.$ZodType { + // Present on Zod 4 (Classic & Mini) schemas; absent on Zod 3 + const schema = s as unknown as ZodV4Internal; + return !!schema._zod; +} + +// --- Schema construction --- +export function objectFromShape(shape: ZodRawShapeCompat): AnyObjectSchema { + const values = Object.values(shape); + if (values.length === 0) return z4mini.object({}); // default to v4 Mini + + const allV4 = values.every(isZ4Schema); + const allV3 = values.every(s => !isZ4Schema(s)); + + if (allV4) return z4mini.object(shape as Record); + if (allV3) return z3rt.object(shape as Record); + + throw new Error('Mixed Zod versions detected in object shape.'); +} + +// --- Unified parsing --- +export function safeParse( + schema: S, + data: unknown +): { success: true; data: SchemaOutput } | { success: false; error: unknown } { + if (isZ4Schema(schema)) { + // Mini exposes top-level safeParse + const result = z4mini.safeParse(schema, data); + return result as { success: true; data: SchemaOutput } | { success: false; error: unknown }; + } + const v3Schema = schema as z3.ZodTypeAny; + const result = v3Schema.safeParse(data); + return result as { success: true; data: SchemaOutput } | { success: false; error: unknown }; +} + +export async function safeParseAsync( + schema: S, + data: unknown +): Promise<{ success: true; data: SchemaOutput } | { success: false; error: unknown }> { + if (isZ4Schema(schema)) { + // Mini exposes top-level safeParseAsync + const result = await z4mini.safeParseAsync(schema, data); + return result as { success: true; data: SchemaOutput } | { success: false; error: unknown }; + } + const v3Schema = schema as z3.ZodTypeAny; + const result = await v3Schema.safeParseAsync(data); + return result as { success: true; data: SchemaOutput } | { success: false; error: unknown }; +} + +// --- Shape extraction --- +export function getObjectShape(schema: AnyObjectSchema | undefined): Record | undefined { + if (!schema) return undefined; + + // Zod v3 exposes `.shape`; Zod v4 keeps the shape on `_zod.def.shape` + let rawShape: Record | (() => Record) | undefined; + + if (isZ4Schema(schema)) { + const v4Schema = schema as unknown as ZodV4Internal; + rawShape = v4Schema._zod?.def?.shape; + } else { + const v3Schema = schema as unknown as ZodV3Internal; + rawShape = v3Schema.shape; + } + + if (!rawShape) return undefined; + + if (typeof rawShape === 'function') { + try { + return rawShape(); + } catch { + return undefined; + } + } + + return rawShape; +} + +// --- Schema normalization --- +/** + * Normalizes a schema to an object schema. Handles both: + * - Already-constructed object schemas (v3 or v4) + * - Raw shapes that need to be wrapped into object schemas + */ +export function normalizeObjectSchema(schema: AnySchema | ZodRawShapeCompat | undefined): AnyObjectSchema | undefined { + if (!schema) return undefined; + + // First check if it's a raw shape (Record) + // Raw shapes don't have _def or _zod properties and aren't schemas themselves + if (typeof schema === 'object') { + // Check if it's actually a ZodRawShapeCompat (not a schema instance) + // by checking if it lacks schema-like internal properties + const asV3 = schema as unknown as ZodV3Internal; + const asV4 = schema as unknown as ZodV4Internal; + + // If it's not a schema instance (no _def or _zod), it might be a raw shape + if (!asV3._def && !asV4._zod) { + // Check if all values are schemas (heuristic to confirm it's a raw shape) + const values = Object.values(schema); + if ( + values.length > 0 && + values.every( + v => + typeof v === 'object' && + v !== null && + ((v as unknown as ZodV3Internal)._def !== undefined || + (v as unknown as ZodV4Internal)._zod !== undefined || + typeof (v as { parse?: unknown }).parse === 'function') + ) + ) { + return objectFromShape(schema as ZodRawShapeCompat); + } + } + } + + // If we get here, it should be an AnySchema (not a raw shape) + // Check if it's already an object schema + if (isZ4Schema(schema as AnySchema)) { + // Check if it's a v4 object + const v4Schema = schema as unknown as ZodV4Internal; + const def = v4Schema._zod?.def; + if (def && (def.typeName === 'object' || def.shape !== undefined)) { + return schema as AnyObjectSchema; + } + } else { + // Check if it's a v3 object + const v3Schema = schema as unknown as ZodV3Internal; + if (v3Schema.shape !== undefined) { + return schema as AnyObjectSchema; + } + } + + return undefined; +} + +// --- Error message extraction --- +/** + * Safely extracts an error message from a parse result error. + * Zod errors can have different structures, so we handle various cases. + */ +export function getParseErrorMessage(error: unknown): string { + if (error && typeof error === 'object') { + // Try common error structures + if ('message' in error && typeof error.message === 'string') { + return error.message; + } + if ('issues' in error && Array.isArray(error.issues) && error.issues.length > 0) { + const firstIssue = error.issues[0]; + if (firstIssue && typeof firstIssue === 'object' && 'message' in firstIssue) { + return String(firstIssue.message); + } + } + // Fallback: try to stringify the error + try { + return JSON.stringify(error); + } catch { + return String(error); + } + } + return String(error); +} + +// --- Schema metadata access --- +/** + * Gets the description from a schema, if available. + * Works with both Zod v3 and v4. + */ +export function getSchemaDescription(schema: AnySchema): string | undefined { + if (isZ4Schema(schema)) { + const v4Schema = schema as unknown as ZodV4Internal; + return v4Schema._zod?.def?.description; + } + const v3Schema = schema as unknown as ZodV3Internal; + // v3 may have description on the schema itself or in _def + return (schema as { description?: string }).description ?? v3Schema._def?.description; +} + +/** + * Checks if a schema is optional. + * Works with both Zod v3 and v4. + */ +export function isSchemaOptional(schema: AnySchema): boolean { + if (isZ4Schema(schema)) { + const v4Schema = schema as unknown as ZodV4Internal; + return v4Schema._zod?.def?.typeName === 'ZodOptional'; + } + const v3Schema = schema as unknown as ZodV3Internal; + // v3 has isOptional() method + if (typeof (schema as { isOptional?: () => boolean }).isOptional === 'function') { + return (schema as { isOptional: () => boolean }).isOptional(); + } + return v3Schema._def?.typeName === 'ZodOptional'; +} + +/** + * Gets the literal value from a schema, if it's a literal schema. + * Works with both Zod v3 and v4. + * Returns undefined if the schema is not a literal or the value cannot be determined. + */ +export function getLiteralValue(schema: AnySchema): unknown { + if (isZ4Schema(schema)) { + const v4Schema = schema as unknown as ZodV4Internal; + const def = v4Schema._zod?.def; + if (def) { + // Try various ways to get the literal value + if (def.value !== undefined) return def.value; + if (Array.isArray(def.values) && def.values.length > 0) { + return def.values[0]; + } + } + } + const v3Schema = schema as unknown as ZodV3Internal; + const def = v3Schema._def; + if (def) { + if (def.value !== undefined) return def.value; + if (Array.isArray(def.values) && def.values.length > 0) { + return def.values[0]; + } + } + // Fallback: check for direct value property (some Zod versions) + const directValue = (schema as { value?: unknown }).value; + if (directValue !== undefined) return directValue; + return undefined; +} diff --git a/src/server/zod-json-schema-compat.ts b/src/server/zod-json-schema-compat.ts new file mode 100644 index 000000000..cde66b177 --- /dev/null +++ b/src/server/zod-json-schema-compat.ts @@ -0,0 +1,68 @@ +// zod-json-schema-compat.ts +// ---------------------------------------------------- +// JSON Schema conversion for both Zod v3 and Zod v4 (Mini) +// v3 uses your vendored converter; v4 uses Mini's toJSONSchema +// ---------------------------------------------------- + +import type * as z3 from 'zod/v3'; +import type * as z4c from 'zod/v4/core'; + +import * as z4mini from 'zod/v4-mini'; + +import { AnySchema, AnyObjectSchema, getObjectShape, safeParse, isZ4Schema, getLiteralValue } from './zod-compat.js'; +import { zodToJsonSchema } from 'zod-to-json-schema'; + +type JsonSchema = Record; + +// Options accepted by call sites; we map them appropriately +type CommonOpts = { + strictUnions?: boolean; + pipeStrategy?: 'input' | 'output'; + target?: 'jsonSchema7' | 'draft-7' | 'jsonSchema2019-09' | 'draft-2020-12'; +}; + +function mapMiniTarget(t: CommonOpts['target'] | undefined): 'draft-7' | 'draft-2020-12' { + if (!t) return 'draft-7'; + if (t === 'jsonSchema7' || t === 'draft-7') return 'draft-7'; + if (t === 'jsonSchema2019-09' || t === 'draft-2020-12') return 'draft-2020-12'; + return 'draft-7'; // fallback +} + +export function toJsonSchemaCompat(schema: AnyObjectSchema, opts?: CommonOpts): JsonSchema { + if (isZ4Schema(schema)) { + // v4 branch — use Mini's built-in toJSONSchema + return z4mini.toJSONSchema(schema as z4c.$ZodType, { + target: mapMiniTarget(opts?.target), + io: opts?.pipeStrategy ?? 'input' + }) as JsonSchema; + } + + // v3 branch — use vendored converter + return zodToJsonSchema(schema as z3.ZodTypeAny, { + strictUnions: opts?.strictUnions ?? true, + pipeStrategy: opts?.pipeStrategy ?? 'input' + }) as JsonSchema; +} + +export function getMethodLiteral(schema: AnyObjectSchema): string { + const shape = getObjectShape(schema); + const methodSchema = shape?.method as AnySchema | undefined; + if (!methodSchema) { + throw new Error('Schema is missing a method literal'); + } + + const value = getLiteralValue(methodSchema); + if (typeof value !== 'string') { + throw new Error('Schema method literal must be a string'); + } + + return value; +} + +export function parseWithCompat(schema: AnySchema, data: unknown): unknown { + const result = safeParse(schema, data); + if (!result.success) { + throw result.error; + } + return result.data; +} diff --git a/src/shared/auth.ts b/src/shared/auth.ts index 1274fcd61..b37a4c70c 100644 --- a/src/shared/auth.ts +++ b/src/shared/auth.ts @@ -1,10 +1,9 @@ -import { z } from 'zod'; +import * as z from 'zod/v4'; /** * Reusable URL validation that disallows javascript: scheme */ export const SafeUrlSchema = z - .string() .url() .superRefine((val, ctx) => { if (!URL.canParse(val)) { @@ -28,107 +27,102 @@ export const SafeUrlSchema = z /** * RFC 9728 OAuth Protected Resource Metadata */ -export const OAuthProtectedResourceMetadataSchema = z - .object({ - resource: z.string().url(), - authorization_servers: z.array(SafeUrlSchema).optional(), - jwks_uri: z.string().url().optional(), - scopes_supported: z.array(z.string()).optional(), - bearer_methods_supported: z.array(z.string()).optional(), - resource_signing_alg_values_supported: z.array(z.string()).optional(), - resource_name: z.string().optional(), - resource_documentation: z.string().optional(), - resource_policy_uri: z.string().url().optional(), - resource_tos_uri: z.string().url().optional(), - tls_client_certificate_bound_access_tokens: z.boolean().optional(), - authorization_details_types_supported: z.array(z.string()).optional(), - dpop_signing_alg_values_supported: z.array(z.string()).optional(), - dpop_bound_access_tokens_required: z.boolean().optional() - }) - .passthrough(); +export const OAuthProtectedResourceMetadataSchema = z.looseObject({ + resource: z.string().url(), + authorization_servers: z.array(SafeUrlSchema).optional(), + jwks_uri: z.string().url().optional(), + scopes_supported: z.array(z.string()).optional(), + bearer_methods_supported: z.array(z.string()).optional(), + resource_signing_alg_values_supported: z.array(z.string()).optional(), + resource_name: z.string().optional(), + resource_documentation: z.string().optional(), + resource_policy_uri: z.string().url().optional(), + resource_tos_uri: z.string().url().optional(), + tls_client_certificate_bound_access_tokens: z.boolean().optional(), + authorization_details_types_supported: z.array(z.string()).optional(), + dpop_signing_alg_values_supported: z.array(z.string()).optional(), + dpop_bound_access_tokens_required: z.boolean().optional() +}); /** * RFC 8414 OAuth 2.0 Authorization Server Metadata */ -export const OAuthMetadataSchema = z - .object({ - issuer: z.string(), - authorization_endpoint: SafeUrlSchema, - token_endpoint: SafeUrlSchema, - registration_endpoint: SafeUrlSchema.optional(), - scopes_supported: z.array(z.string()).optional(), - response_types_supported: z.array(z.string()), - response_modes_supported: z.array(z.string()).optional(), - grant_types_supported: z.array(z.string()).optional(), - token_endpoint_auth_methods_supported: z.array(z.string()).optional(), - token_endpoint_auth_signing_alg_values_supported: z.array(z.string()).optional(), - service_documentation: SafeUrlSchema.optional(), - revocation_endpoint: SafeUrlSchema.optional(), - revocation_endpoint_auth_methods_supported: z.array(z.string()).optional(), - revocation_endpoint_auth_signing_alg_values_supported: z.array(z.string()).optional(), - introspection_endpoint: z.string().optional(), - introspection_endpoint_auth_methods_supported: z.array(z.string()).optional(), - introspection_endpoint_auth_signing_alg_values_supported: z.array(z.string()).optional(), - code_challenge_methods_supported: z.array(z.string()).optional(), - client_id_metadata_document_supported: z.boolean().optional() - }) - .passthrough(); +export const OAuthMetadataSchema = z.looseObject({ + issuer: z.string(), + authorization_endpoint: SafeUrlSchema, + token_endpoint: SafeUrlSchema, + registration_endpoint: SafeUrlSchema.optional(), + scopes_supported: z.array(z.string()).optional(), + response_types_supported: z.array(z.string()), + response_modes_supported: z.array(z.string()).optional(), + grant_types_supported: z.array(z.string()).optional(), + token_endpoint_auth_methods_supported: z.array(z.string()).optional(), + token_endpoint_auth_signing_alg_values_supported: z.array(z.string()).optional(), + service_documentation: SafeUrlSchema.optional(), + revocation_endpoint: SafeUrlSchema.optional(), + revocation_endpoint_auth_methods_supported: z.array(z.string()).optional(), + revocation_endpoint_auth_signing_alg_values_supported: z.array(z.string()).optional(), + introspection_endpoint: z.string().optional(), + introspection_endpoint_auth_methods_supported: z.array(z.string()).optional(), + introspection_endpoint_auth_signing_alg_values_supported: z.array(z.string()).optional(), + code_challenge_methods_supported: z.array(z.string()).optional(), + client_id_metadata_document_supported: z.boolean().optional() +}); /** * OpenID Connect Discovery 1.0 Provider Metadata * see: https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata */ -export const OpenIdProviderMetadataSchema = z - .object({ - issuer: z.string(), - authorization_endpoint: SafeUrlSchema, - token_endpoint: SafeUrlSchema, - userinfo_endpoint: SafeUrlSchema.optional(), - jwks_uri: SafeUrlSchema, - registration_endpoint: SafeUrlSchema.optional(), - scopes_supported: z.array(z.string()).optional(), - response_types_supported: z.array(z.string()), - response_modes_supported: z.array(z.string()).optional(), - grant_types_supported: z.array(z.string()).optional(), - acr_values_supported: z.array(z.string()).optional(), - subject_types_supported: z.array(z.string()), - id_token_signing_alg_values_supported: z.array(z.string()), - id_token_encryption_alg_values_supported: z.array(z.string()).optional(), - id_token_encryption_enc_values_supported: z.array(z.string()).optional(), - userinfo_signing_alg_values_supported: z.array(z.string()).optional(), - userinfo_encryption_alg_values_supported: z.array(z.string()).optional(), - userinfo_encryption_enc_values_supported: z.array(z.string()).optional(), - request_object_signing_alg_values_supported: z.array(z.string()).optional(), - request_object_encryption_alg_values_supported: z.array(z.string()).optional(), - request_object_encryption_enc_values_supported: z.array(z.string()).optional(), - token_endpoint_auth_methods_supported: z.array(z.string()).optional(), - token_endpoint_auth_signing_alg_values_supported: z.array(z.string()).optional(), - display_values_supported: z.array(z.string()).optional(), - claim_types_supported: z.array(z.string()).optional(), - claims_supported: z.array(z.string()).optional(), - service_documentation: z.string().optional(), - claims_locales_supported: z.array(z.string()).optional(), - ui_locales_supported: z.array(z.string()).optional(), - claims_parameter_supported: z.boolean().optional(), - request_parameter_supported: z.boolean().optional(), - request_uri_parameter_supported: z.boolean().optional(), - require_request_uri_registration: z.boolean().optional(), - op_policy_uri: SafeUrlSchema.optional(), - op_tos_uri: SafeUrlSchema.optional(), - client_id_metadata_document_supported: z.boolean().optional() - }) - .passthrough(); +export const OpenIdProviderMetadataSchema = z.looseObject({ + issuer: z.string(), + authorization_endpoint: SafeUrlSchema, + token_endpoint: SafeUrlSchema, + userinfo_endpoint: SafeUrlSchema.optional(), + jwks_uri: SafeUrlSchema, + registration_endpoint: SafeUrlSchema.optional(), + scopes_supported: z.array(z.string()).optional(), + response_types_supported: z.array(z.string()), + response_modes_supported: z.array(z.string()).optional(), + grant_types_supported: z.array(z.string()).optional(), + acr_values_supported: z.array(z.string()).optional(), + subject_types_supported: z.array(z.string()), + id_token_signing_alg_values_supported: z.array(z.string()), + id_token_encryption_alg_values_supported: z.array(z.string()).optional(), + id_token_encryption_enc_values_supported: z.array(z.string()).optional(), + userinfo_signing_alg_values_supported: z.array(z.string()).optional(), + userinfo_encryption_alg_values_supported: z.array(z.string()).optional(), + userinfo_encryption_enc_values_supported: z.array(z.string()).optional(), + request_object_signing_alg_values_supported: z.array(z.string()).optional(), + request_object_encryption_alg_values_supported: z.array(z.string()).optional(), + request_object_encryption_enc_values_supported: z.array(z.string()).optional(), + token_endpoint_auth_methods_supported: z.array(z.string()).optional(), + token_endpoint_auth_signing_alg_values_supported: z.array(z.string()).optional(), + display_values_supported: z.array(z.string()).optional(), + claim_types_supported: z.array(z.string()).optional(), + claims_supported: z.array(z.string()).optional(), + service_documentation: z.string().optional(), + claims_locales_supported: z.array(z.string()).optional(), + ui_locales_supported: z.array(z.string()).optional(), + claims_parameter_supported: z.boolean().optional(), + request_parameter_supported: z.boolean().optional(), + request_uri_parameter_supported: z.boolean().optional(), + require_request_uri_registration: z.boolean().optional(), + op_policy_uri: SafeUrlSchema.optional(), + op_tos_uri: SafeUrlSchema.optional(), + client_id_metadata_document_supported: z.boolean().optional() +}); /** * OpenID Connect Discovery metadata that may include OAuth 2.0 fields * This schema represents the real-world scenario where OIDC providers * return a mix of OpenID Connect and OAuth 2.0 metadata fields */ -export const OpenIdProviderDiscoveryMetadataSchema = OpenIdProviderMetadataSchema.merge( - OAuthMetadataSchema.pick({ +export const OpenIdProviderDiscoveryMetadataSchema = z.object({ + ...OpenIdProviderMetadataSchema.shape, + ...OAuthMetadataSchema.pick({ code_challenge_methods_supported: true - }) -); + }).shape +}); /** * OAuth 2.1 token response diff --git a/src/shared/protocol-transport-handling.test.ts b/src/shared/protocol-transport-handling.test.ts index 83181494f..b463d6db4 100644 --- a/src/shared/protocol-transport-handling.test.ts +++ b/src/shared/protocol-transport-handling.test.ts @@ -2,7 +2,7 @@ import { describe, expect, test, beforeEach } from 'vitest'; import { Protocol } from './protocol.js'; import { Transport } from './transport.js'; import { Request, Notification, Result, JSONRPCMessage } from '../types.js'; -import { z } from 'zod'; +import * as z from 'zod/v4'; // Mock Transport class class MockTransport implements Transport { diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 5141e201c..add69163c 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -1,4 +1,4 @@ -import { ZodLiteral, ZodObject, ZodType, z } from 'zod'; +import { AnySchema, AnyObjectSchema, SchemaOutput, safeParse } from '../server/zod-compat.js'; import { CancelledNotificationSchema, ClientCapabilities, @@ -27,6 +27,7 @@ import { } from '../types.js'; import { Transport, TransportSendOptions } from './transport.js'; import { AuthInfo } from '../server/auth/types.js'; +import { getMethodLiteral, parseWithCompat } from '../server/zod-json-schema-compat.js'; /** * Callback for progress notifications. @@ -152,7 +153,7 @@ export type RequestHandlerExtra>(request: SendRequestT, resultSchema: U, options?: RequestOptions) => Promise>; + sendRequest: (request: SendRequestT, resultSchema: U, options?: RequestOptions) => Promise>; }; /** @@ -490,7 +491,7 @@ export abstract class Protocol>(request: SendRequestT, resultSchema: T, options?: RequestOptions): Promise> { + request(request: SendRequestT, resultSchema: T, options?: RequestOptions): Promise> { const { relatedRequestId, resumptionToken, onresumptiontoken } = options ?? {}; return new Promise((resolve, reject) => { @@ -555,8 +556,13 @@ export abstract class Protocol); + } } catch (error) { reject(error); } @@ -639,19 +645,19 @@ export abstract class Protocol; - }> - >( + setRequestHandler( requestSchema: T, - handler: (request: z.infer, extra: RequestHandlerExtra) => SendResultT | Promise + handler: ( + request: SchemaOutput, + extra: RequestHandlerExtra + ) => SendResultT | Promise ): void { - const method = requestSchema.shape.method.value; + const method = getMethodLiteral(requestSchema); this.assertRequestHandlerCapability(method); this._requestHandlers.set(method, (request, extra) => { - return Promise.resolve(handler(requestSchema.parse(request), extra)); + const parsed = parseWithCompat(requestSchema, request) as SchemaOutput; + return Promise.resolve(handler(parsed, extra)); }); } @@ -676,14 +682,15 @@ export abstract class Protocol; - }> - >(notificationSchema: T, handler: (notification: z.infer) => void | Promise): void { - this._notificationHandlers.set(notificationSchema.shape.method.value, notification => - Promise.resolve(handler(notificationSchema.parse(notification))) - ); + setNotificationHandler( + notificationSchema: T, + handler: (notification: SchemaOutput) => void | Promise + ): void { + const method = getMethodLiteral(notificationSchema); + this._notificationHandlers.set(method, notification => { + const parsed = parseWithCompat(notificationSchema, notification) as SchemaOutput; + return Promise.resolve(handler(parsed)); + }); } /** diff --git a/src/spec.types.test.ts b/src/spec.types.test.ts index 3df41bfc5..1c0b6ab5d 100644 --- a/src/spec.types.test.ts +++ b/src/spec.types.test.ts @@ -76,88 +76,97 @@ type FixSpecInitializeRequest = T extends { params: infer P } ? Omit = T extends { params: infer P } ? Omit & { params: FixSpecInitializeRequestParams

} : T; const sdkTypeChecks = { - RequestParams: (sdk: SDKTypes.RequestParams, spec: SpecTypes.RequestParams) => { + RequestParams: (sdk: RemovePassthrough, spec: SpecTypes.RequestParams) => { sdk = spec; spec = sdk; }, - NotificationParams: (sdk: SDKTypes.NotificationParams, spec: SpecTypes.NotificationParams) => { + NotificationParams: (sdk: RemovePassthrough, spec: SpecTypes.NotificationParams) => { sdk = spec; spec = sdk; }, - CancelledNotificationParams: (sdk: SDKTypes.CancelledNotificationParams, spec: SpecTypes.CancelledNotificationParams) => { + CancelledNotificationParams: ( + sdk: RemovePassthrough, + spec: SpecTypes.CancelledNotificationParams + ) => { sdk = spec; spec = sdk; }, InitializeRequestParams: ( - sdk: SDKTypes.InitializeRequestParams, + sdk: RemovePassthrough, spec: FixSpecInitializeRequestParams ) => { sdk = spec; spec = sdk; }, - ProgressNotificationParams: (sdk: SDKTypes.ProgressNotificationParams, spec: SpecTypes.ProgressNotificationParams) => { + ProgressNotificationParams: ( + sdk: RemovePassthrough, + spec: SpecTypes.ProgressNotificationParams + ) => { sdk = spec; spec = sdk; }, - ResourceRequestParams: (sdk: SDKTypes.ResourceRequestParams, spec: SpecTypes.ResourceRequestParams) => { + ResourceRequestParams: (sdk: RemovePassthrough, spec: SpecTypes.ResourceRequestParams) => { sdk = spec; spec = sdk; }, - ReadResourceRequestParams: (sdk: SDKTypes.ReadResourceRequestParams, spec: SpecTypes.ReadResourceRequestParams) => { + ReadResourceRequestParams: (sdk: RemovePassthrough, spec: SpecTypes.ReadResourceRequestParams) => { sdk = spec; spec = sdk; }, - SubscribeRequestParams: (sdk: SDKTypes.SubscribeRequestParams, spec: SpecTypes.SubscribeRequestParams) => { + SubscribeRequestParams: (sdk: RemovePassthrough, spec: SpecTypes.SubscribeRequestParams) => { sdk = spec; spec = sdk; }, - UnsubscribeRequestParams: (sdk: SDKTypes.UnsubscribeRequestParams, spec: SpecTypes.UnsubscribeRequestParams) => { + UnsubscribeRequestParams: (sdk: RemovePassthrough, spec: SpecTypes.UnsubscribeRequestParams) => { sdk = spec; spec = sdk; }, ResourceUpdatedNotificationParams: ( - sdk: SDKTypes.ResourceUpdatedNotificationParams, + sdk: RemovePassthrough, spec: SpecTypes.ResourceUpdatedNotificationParams ) => { sdk = spec; spec = sdk; }, - GetPromptRequestParams: (sdk: SDKTypes.GetPromptRequestParams, spec: SpecTypes.GetPromptRequestParams) => { + GetPromptRequestParams: (sdk: RemovePassthrough, spec: SpecTypes.GetPromptRequestParams) => { sdk = spec; spec = sdk; }, - CallToolRequestParams: (sdk: SDKTypes.CallToolRequestParams, spec: SpecTypes.CallToolRequestParams) => { + CallToolRequestParams: (sdk: RemovePassthrough, spec: SpecTypes.CallToolRequestParams) => { sdk = spec; spec = sdk; }, - SetLevelRequestParams: (sdk: SDKTypes.SetLevelRequestParams, spec: SpecTypes.SetLevelRequestParams) => { + SetLevelRequestParams: (sdk: RemovePassthrough, spec: SpecTypes.SetLevelRequestParams) => { sdk = spec; spec = sdk; }, LoggingMessageNotificationParams: ( - sdk: MakeUnknownsNotOptional, + sdk: MakeUnknownsNotOptional>, spec: SpecTypes.LoggingMessageNotificationParams ) => { sdk = spec; spec = sdk; }, - CreateMessageRequestParams: (sdk: SDKTypes.CreateMessageRequestParams, spec: SpecTypes.CreateMessageRequestParams) => { + CreateMessageRequestParams: ( + sdk: RemovePassthrough, + spec: SpecTypes.CreateMessageRequestParams + ) => { sdk = spec; spec = sdk; }, - CompleteRequestParams: (sdk: SDKTypes.CompleteRequestParams, spec: SpecTypes.CompleteRequestParams) => { + CompleteRequestParams: (sdk: RemovePassthrough, spec: SpecTypes.CompleteRequestParams) => { sdk = spec; spec = sdk; }, - ElicitRequestParams: (sdk: SDKTypes.ElicitRequestParams, spec: SpecTypes.ElicitRequestParams) => { + ElicitRequestParams: (sdk: RemovePassthrough, spec: SpecTypes.ElicitRequestParams) => { sdk = spec; spec = sdk; }, - ElicitRequestFormParams: (sdk: SDKTypes.ElicitRequestFormParams, spec: SpecTypes.ElicitRequestFormParams) => { + ElicitRequestFormParams: (sdk: RemovePassthrough, spec: SpecTypes.ElicitRequestFormParams) => { sdk = spec; spec = sdk; }, - ElicitRequestURLParams: (sdk: SDKTypes.ElicitRequestURLParams, spec: SpecTypes.ElicitRequestURLParams) => { + ElicitRequestURLParams: (sdk: RemovePassthrough, spec: SpecTypes.ElicitRequestURLParams) => { sdk = spec; spec = sdk; }, @@ -168,11 +177,11 @@ const sdkTypeChecks = { sdk = spec; spec = sdk; }, - PaginatedRequestParams: (sdk: SDKTypes.PaginatedRequestParams, spec: SpecTypes.PaginatedRequestParams) => { + PaginatedRequestParams: (sdk: RemovePassthrough, spec: SpecTypes.PaginatedRequestParams) => { sdk = spec; spec = sdk; }, - CancelledNotification: (sdk: WithJSONRPC, spec: SpecTypes.CancelledNotification) => { + CancelledNotification: (sdk: RemovePassthrough>, spec: SpecTypes.CancelledNotification) => { sdk = spec; spec = sdk; }, @@ -216,7 +225,7 @@ const sdkTypeChecks = { sdk = spec; spec = sdk; }, - ElicitRequest: (sdk: WithJSONRPCRequest, spec: SpecTypes.ElicitRequest) => { + ElicitRequest: (sdk: RemovePassthrough>, spec: SpecTypes.ElicitRequest) => { sdk = spec; spec = sdk; }, @@ -224,7 +233,7 @@ const sdkTypeChecks = { sdk = spec; spec = sdk; }, - CompleteRequest: (sdk: WithJSONRPCRequest, spec: SpecTypes.CompleteRequest) => { + CompleteRequest: (sdk: RemovePassthrough>, spec: SpecTypes.CompleteRequest) => { sdk = spec; spec = sdk; }, @@ -351,11 +360,11 @@ const sdkTypeChecks = { sdk = spec; spec = sdk; }, - SamplingMessage: (sdk: SDKTypes.SamplingMessage, spec: SpecTypes.SamplingMessage) => { + SamplingMessage: (sdk: RemovePassthrough, spec: SpecTypes.SamplingMessage) => { sdk = spec; spec = sdk; }, - CreateMessageResult: (sdk: SDKTypes.CreateMessageResult, spec: SpecTypes.CreateMessageResult) => { + CreateMessageResult: (sdk: RemovePassthrough, spec: SpecTypes.CreateMessageResult) => { sdk = spec; spec = sdk; }, @@ -535,12 +544,15 @@ const sdkTypeChecks = { sdk = spec; spec = sdk; }, - CreateMessageRequest: (sdk: WithJSONRPCRequest, spec: SpecTypes.CreateMessageRequest) => { + CreateMessageRequest: ( + sdk: RemovePassthrough>, + spec: SpecTypes.CreateMessageRequest + ) => { sdk = spec; spec = sdk; }, InitializeRequest: ( - sdk: WithJSONRPCRequest, + sdk: RemovePassthrough>, spec: FixSpecInitializeRequest ) => { sdk = spec; @@ -602,6 +614,25 @@ const sdkTypeChecks = { ModelPreferences: (sdk: SDKTypes.ModelPreferences, spec: SpecTypes.ModelPreferences) => { sdk = spec; spec = sdk; + }, + ToolChoice: (sdk: SDKTypes.ToolChoice, spec: SpecTypes.ToolChoice) => { + sdk = spec; + spec = sdk; + }, + ToolUseContent: (sdk: RemovePassthrough, spec: SpecTypes.ToolUseContent) => { + sdk = spec; + spec = sdk; + }, + ToolResultContent: (sdk: RemovePassthrough, spec: SpecTypes.ToolResultContent) => { + sdk = spec; + spec = sdk; + }, + SamplingMessageContentBlock: ( + sdk: RemovePassthrough, + spec: SpecTypes.SamplingMessageContentBlock + ) => { + sdk = spec; + spec = sdk; } }; @@ -631,7 +662,7 @@ describe('Spec Types', () => { it('should define some expected types', () => { expect(specTypes).toContain('JSONRPCNotification'); expect(specTypes).toContain('ElicitResult'); - expect(specTypes).toHaveLength(123); + expect(specTypes).toHaveLength(127); }); it('should have up to date list of missing sdk types', () => { diff --git a/src/spec.types.ts b/src/spec.types.ts index 307884fa0..6ce24059e 100644 --- a/src/spec.types.ts +++ b/src/spec.types.ts @@ -307,7 +307,17 @@ export interface ClientCapabilities { /** * Present if the client supports sampling from an LLM. */ - sampling?: object; + sampling?: { + /** + * Whether the client supports context inclusion via includeContext parameter. + * If not declared, servers SHOULD only use `includeContext: "none"` (or omit it). + */ + context?: object; + /** + * Whether the client supports tool use via tools and toolChoice parameters. + */ + tools?: object; + }; /** * Present if the client supports elicitation from the server. */ @@ -1255,7 +1265,11 @@ export interface CreateMessageRequestParams extends RequestParams { */ systemPrompt?: string; /** - * A request to include context from one or more MCP servers (including the caller), to be attached to the prompt. The client MAY ignore this request. + * A request to include context from one or more MCP servers (including the caller), to be attached to the prompt. + * The client MAY ignore this request. + * + * Default is "none". Values "thisServer" and "allServers" are soft-deprecated. Servers SHOULD only use these values if the client + * declares ClientCapabilities.sampling.context. These values may be removed in future spec releases. */ includeContext?: "none" | "thisServer" | "allServers"; /** @@ -1273,6 +1287,32 @@ export interface CreateMessageRequestParams extends RequestParams { * Optional metadata to pass through to the LLM provider. The format of this metadata is provider-specific. */ metadata?: object; + /** + * Tools that the model may use during generation. + * The client MUST return an error if this field is provided but ClientCapabilities.sampling.tools is not declared. + */ + tools?: Tool[]; + /** + * Controls how the model uses tools. + * The client MUST return an error if this field is provided but ClientCapabilities.sampling.tools is not declared. + * Default is `{ mode: "auto" }`. + */ + toolChoice?: ToolChoice; +} + +/** + * Controls tool selection behavior for sampling requests. + * + * @category `sampling/createMessage` + */ +export interface ToolChoice { + /** + * Controls the tool use ability of the model: + * - "auto": Model decides whether to use tools (default) + * - "required": Model MUST use at least one tool before completing + * - "none": Model MUST NOT use any tools + */ + mode?: "auto" | "required" | "none"; } /** @@ -1295,10 +1335,19 @@ export interface CreateMessageResult extends Result, SamplingMessage { * The name of the model that generated the message. */ model: string; + /** * The reason why sampling stopped, if known. + * + * Standard values: + * - "endTurn": Natural end of the assistant's turn + * - "stopSequence": A stop sequence was encountered + * - "maxTokens": Maximum token limit was reached + * - "toolUse": The model wants to use one or more tools + * + * This field is an open string to allow for provider-specific stop reasons. */ - stopReason?: "endTurn" | "stopSequence" | "maxTokens" | string; + stopReason?: "endTurn" | "stopSequence" | "maxTokens" | "toolUse" | string; } /** @@ -1308,8 +1357,18 @@ export interface CreateMessageResult extends Result, SamplingMessage { */ export interface SamplingMessage { role: Role; - content: TextContent | ImageContent | AudioContent; + content: SamplingMessageContentBlock | SamplingMessageContentBlock[]; + /** + * See [General fields: `_meta`](/specification/draft/basic/index#meta) for notes on `_meta` usage. + */ + _meta?: { [key: string]: unknown }; } +export type SamplingMessageContentBlock = + | TextContent + | ImageContent + | AudioContent + | ToolUseContent + | ToolResultContent; /** * Optional annotations for the client. The client can use annotations to inform how objects are used or displayed @@ -1444,6 +1503,87 @@ export interface AudioContent { _meta?: { [key: string]: unknown }; } +/** + * A request from the assistant to call a tool. + * + * @category `sampling/createMessage` + */ +export interface ToolUseContent { + type: "tool_use"; + + /** + * A unique identifier for this tool use. + * + * This ID is used to match tool results to their corresponding tool uses. + */ + id: string; + + /** + * The name of the tool to call. + */ + name: string; + + /** + * The arguments to pass to the tool, conforming to the tool's input schema. + */ + input: { [key: string]: unknown }; + + /** + * Optional metadata about the tool use. Clients SHOULD preserve this field when + * including tool uses in subsequent sampling requests to enable caching optimizations. + * + * See [General fields: `_meta`](/specification/draft/basic/index#meta) for notes on `_meta` usage. + */ + _meta?: { [key: string]: unknown }; +} + +/** + * The result of a tool use, provided by the user back to the assistant. + * + * @category `sampling/createMessage` + */ +export interface ToolResultContent { + type: "tool_result"; + + /** + * The ID of the tool use this result corresponds to. + * + * This MUST match the ID from a previous ToolUseContent. + */ + toolUseId: string; + + /** + * The unstructured result content of the tool use. + * + * This has the same format as CallToolResult.content and can include text, images, + * audio, resource links, and embedded resources. + */ + content: ContentBlock[]; + + /** + * An optional structured result object. + * + * If the tool defined an outputSchema, this SHOULD conform to that schema. + */ + structuredContent?: { [key: string]: unknown }; + + /** + * Whether the tool use resulted in an error. + * + * If true, the content typically describes the error that occurred. + * Default: false + */ + isError?: boolean; + + /** + * Optional metadata about the tool result. Clients SHOULD preserve this field when + * including tool results in subsequent sampling requests to enable caching optimizations. + * + * See [General fields: `_meta`](/specification/draft/basic/index#meta) for notes on `_meta` usage. + */ + _meta?: { [key: string]: unknown }; +} + /** * The server's preferences for model selection, requested of the client during sampling. * @@ -1762,7 +1902,6 @@ export interface ElicitRequest extends JSONRPCRequest { params: ElicitRequestParams; } -/** /** * Restricted schema definitions that only allow primitive types * without nested objects or arrays. diff --git a/src/types.test.ts b/src/types.test.ts index cd8cc0711..3f6f83a14 100644 --- a/src/types.test.ts +++ b/src/types.test.ts @@ -5,7 +5,14 @@ import { ContentBlockSchema, PromptMessageSchema, CallToolResultSchema, - CompleteRequestSchema + CompleteRequestSchema, + ToolUseContentSchema, + ToolResultContentSchema, + ToolChoiceSchema, + SamplingMessageSchema, + CreateMessageRequestSchema, + CreateMessageResultSchema, + ClientCapabilitiesSchema } from './types.js'; describe('Types', () => { @@ -311,4 +318,475 @@ describe('Types', () => { } }); }); + + describe('ToolUseContent', () => { + test('should validate a tool call content', () => { + const toolCall = { + type: 'tool_use', + id: 'call_123', + name: 'get_weather', + input: { city: 'San Francisco', units: 'celsius' } + }; + + const result = ToolUseContentSchema.safeParse(toolCall); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.type).toBe('tool_use'); + expect(result.data.id).toBe('call_123'); + expect(result.data.name).toBe('get_weather'); + expect(result.data.input).toEqual({ city: 'San Francisco', units: 'celsius' }); + } + }); + + test('should validate tool call with _meta', () => { + const toolCall = { + type: 'tool_use', + id: 'call_456', + name: 'search', + input: { query: 'test' }, + _meta: { custom: 'data' } + }; + + const result = ToolUseContentSchema.safeParse(toolCall); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data._meta).toEqual({ custom: 'data' }); + } + }); + + test('should fail validation for missing required fields', () => { + const invalidToolCall = { + type: 'tool_use', + name: 'test' + // missing id and input + }; + + const result = ToolUseContentSchema.safeParse(invalidToolCall); + expect(result.success).toBe(false); + }); + }); + + describe('ToolResultContent', () => { + test('should validate a tool result content', () => { + const toolResult = { + type: 'tool_result', + toolUseId: 'call_123', + structuredContent: { temperature: 72, condition: 'sunny' } + }; + + const result = ToolResultContentSchema.safeParse(toolResult); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.type).toBe('tool_result'); + expect(result.data.toolUseId).toBe('call_123'); + expect(result.data.structuredContent).toEqual({ temperature: 72, condition: 'sunny' }); + } + }); + + test('should validate tool result with error in content', () => { + const toolResult = { + type: 'tool_result', + toolUseId: 'call_456', + structuredContent: { error: 'API_ERROR', message: 'Service unavailable' }, + isError: true + }; + + const result = ToolResultContentSchema.safeParse(toolResult); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.structuredContent).toEqual({ error: 'API_ERROR', message: 'Service unavailable' }); + expect(result.data.isError).toBe(true); + } + }); + + test('should fail validation for missing required fields', () => { + const invalidToolResult = { + type: 'tool_result', + content: { data: 'test' } + // missing toolUseId + }; + + const result = ToolResultContentSchema.safeParse(invalidToolResult); + expect(result.success).toBe(false); + }); + }); + + describe('ToolChoice', () => { + test('should validate tool choice with mode auto', () => { + const toolChoice = { + mode: 'auto' + }; + + const result = ToolChoiceSchema.safeParse(toolChoice); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.mode).toBe('auto'); + } + }); + + test('should validate tool choice with mode required', () => { + const toolChoice = { + mode: 'required' + }; + + const result = ToolChoiceSchema.safeParse(toolChoice); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.mode).toBe('required'); + } + }); + + test('should validate empty tool choice', () => { + const toolChoice = {}; + + const result = ToolChoiceSchema.safeParse(toolChoice); + expect(result.success).toBe(true); + }); + + test('should fail validation for invalid mode', () => { + const invalidToolChoice = { + mode: 'invalid' + }; + + const result = ToolChoiceSchema.safeParse(invalidToolChoice); + expect(result.success).toBe(false); + }); + }); + + describe('SamplingMessage content types', () => { + test('should validate user message with text', () => { + const userMessage = { + role: 'user', + content: { type: 'text', text: "What's the weather?" } + }; + + const result = SamplingMessageSchema.safeParse(userMessage); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.role).toBe('user'); + if (!Array.isArray(result.data.content)) { + expect(result.data.content.type).toBe('text'); + } + } + }); + + test('should validate user message with tool result', () => { + const userMessage = { + role: 'user', + content: { + type: 'tool_result', + toolUseId: 'call_123', + content: [] + } + }; + + const result = SamplingMessageSchema.safeParse(userMessage); + expect(result.success).toBe(true); + if (result.success && !Array.isArray(result.data.content)) { + expect(result.data.content.type).toBe('tool_result'); + } + }); + + test('should validate assistant message with text', () => { + const assistantMessage = { + role: 'assistant', + content: { type: 'text', text: "I'll check the weather for you." } + }; + + const result = SamplingMessageSchema.safeParse(assistantMessage); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.role).toBe('assistant'); + } + }); + + test('should validate assistant message with tool call', () => { + const assistantMessage = { + role: 'assistant', + content: { + type: 'tool_use', + id: 'call_123', + name: 'get_weather', + input: { city: 'SF' } + } + }; + + const result = SamplingMessageSchema.safeParse(assistantMessage); + expect(result.success).toBe(true); + if (result.success && !Array.isArray(result.data.content)) { + expect(result.data.content.type).toBe('tool_use'); + } + }); + + test('should validate any content type for any role', () => { + // The simplified schema allows any content type for any role + const assistantWithToolResult = { + role: 'assistant', + content: { + type: 'tool_result', + toolUseId: 'call_123', + content: [] + } + }; + + const result1 = SamplingMessageSchema.safeParse(assistantWithToolResult); + expect(result1.success).toBe(true); + + const userWithToolUse = { + role: 'user', + content: { + type: 'tool_use', + id: 'call_123', + name: 'test', + input: {} + } + }; + + const result2 = SamplingMessageSchema.safeParse(userWithToolUse); + expect(result2.success).toBe(true); + }); + }); + + describe('SamplingMessage', () => { + test('should validate user message via discriminated union', () => { + const message = { + role: 'user', + content: { type: 'text', text: 'Hello' } + }; + + const result = SamplingMessageSchema.safeParse(message); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.role).toBe('user'); + } + }); + + test('should validate assistant message via discriminated union', () => { + const message = { + role: 'assistant', + content: { type: 'text', text: 'Hi there!' } + }; + + const result = SamplingMessageSchema.safeParse(message); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.role).toBe('assistant'); + } + }); + }); + + describe('CreateMessageRequest', () => { + test('should validate request without tools', () => { + const request = { + method: 'sampling/createMessage', + params: { + messages: [{ role: 'user', content: { type: 'text', text: 'Hello' } }], + maxTokens: 1000 + } + }; + + const result = CreateMessageRequestSchema.safeParse(request); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.params.tools).toBeUndefined(); + } + }); + + test('should validate request with tools', () => { + const request = { + method: 'sampling/createMessage', + params: { + messages: [{ role: 'user', content: { type: 'text', text: "What's the weather?" } }], + maxTokens: 1000, + tools: [ + { + name: 'get_weather', + description: 'Get weather for a location', + inputSchema: { + type: 'object', + properties: { + location: { type: 'string' } + }, + required: ['location'] + } + } + ], + toolChoice: { + mode: 'auto' + } + } + }; + + const result = CreateMessageRequestSchema.safeParse(request); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.params.tools).toHaveLength(1); + expect(result.data.params.toolChoice?.mode).toBe('auto'); + } + }); + + test('should validate request with includeContext (soft-deprecated)', () => { + const request = { + method: 'sampling/createMessage', + params: { + messages: [{ role: 'user', content: { type: 'text', text: 'Help' } }], + maxTokens: 1000, + includeContext: 'thisServer' + } + }; + + const result = CreateMessageRequestSchema.safeParse(request); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.params.includeContext).toBe('thisServer'); + } + }); + }); + + describe('CreateMessageResult', () => { + test('should validate result with text content', () => { + const result = { + model: 'claude-3-5-sonnet-20241022', + role: 'assistant', + content: { type: 'text', text: "Here's the answer." }, + stopReason: 'endTurn' + }; + + const parseResult = CreateMessageResultSchema.safeParse(result); + expect(parseResult.success).toBe(true); + if (parseResult.success) { + expect(parseResult.data.role).toBe('assistant'); + expect(parseResult.data.stopReason).toBe('endTurn'); + } + }); + + test('should validate result with tool call', () => { + const result = { + model: 'claude-3-5-sonnet-20241022', + role: 'assistant', + content: { + type: 'tool_use', + id: 'call_123', + name: 'get_weather', + input: { city: 'SF' } + }, + stopReason: 'toolUse' + }; + + const parseResult = CreateMessageResultSchema.safeParse(result); + expect(parseResult.success).toBe(true); + if (parseResult.success) { + expect(parseResult.data.stopReason).toBe('toolUse'); + const content = parseResult.data.content; + expect(Array.isArray(content)).toBe(false); + if (!Array.isArray(content)) { + expect(content.type).toBe('tool_use'); + } + } + }); + + test('should validate result with array content', () => { + const result = { + model: 'claude-3-5-sonnet-20241022', + role: 'assistant', + content: [ + { type: 'text', text: 'Let me check the weather.' }, + { + type: 'tool_use', + id: 'call_123', + name: 'get_weather', + input: { city: 'SF' } + } + ], + stopReason: 'toolUse' + }; + + const parseResult = CreateMessageResultSchema.safeParse(result); + expect(parseResult.success).toBe(true); + if (parseResult.success) { + expect(parseResult.data.stopReason).toBe('toolUse'); + const content = parseResult.data.content; + expect(Array.isArray(content)).toBe(true); + if (Array.isArray(content)) { + expect(content).toHaveLength(2); + expect(content[0].type).toBe('text'); + expect(content[1].type).toBe('tool_use'); + } + } + }); + + test('should validate all new stop reasons', () => { + const stopReasons = ['endTurn', 'stopSequence', 'maxTokens', 'toolUse', 'refusal', 'other']; + + stopReasons.forEach(stopReason => { + const result = { + model: 'test', + role: 'assistant', + content: { type: 'text', text: 'test' }, + stopReason + }; + + const parseResult = CreateMessageResultSchema.safeParse(result); + expect(parseResult.success).toBe(true); + }); + }); + + test('should allow custom stop reason string', () => { + const result = { + model: 'test', + role: 'assistant', + content: { type: 'text', text: 'test' }, + stopReason: 'custom_provider_reason' + }; + + const parseResult = CreateMessageResultSchema.safeParse(result); + expect(parseResult.success).toBe(true); + }); + }); + + describe('ClientCapabilities with sampling', () => { + test('should validate capabilities with sampling.tools', () => { + const capabilities = { + sampling: { + tools: {} + } + }; + + const result = ClientCapabilitiesSchema.safeParse(capabilities); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.sampling?.tools).toBeDefined(); + } + }); + + test('should validate capabilities with sampling.context', () => { + const capabilities = { + sampling: { + context: {} + } + }; + + const result = ClientCapabilitiesSchema.safeParse(capabilities); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.sampling?.context).toBeDefined(); + } + }); + + test('should validate capabilities with both', () => { + const capabilities = { + sampling: { + context: {}, + tools: {} + } + }; + + const result = ClientCapabilitiesSchema.safeParse(capabilities); + expect(result.success).toBe(true); + if (result.success) { + expect(result.data.sampling?.context).toBeDefined(); + expect(result.data.sampling?.tools).toBeDefined(); + } + }); + }); }); diff --git a/src/types.ts b/src/types.ts index 78fa81d54..64153094d 100644 --- a/src/types.ts +++ b/src/types.ts @@ -1,4 +1,4 @@ -import { z, ZodTypeAny } from 'zod'; +import * as z from 'zod/v4'; import { AuthInfo } from './server/auth/types.js'; export const LATEST_PROTOCOL_VERSION = '2025-06-18'; @@ -28,22 +28,17 @@ export const ProgressTokenSchema = z.union([z.string(), z.number().int()]); */ export const CursorSchema = z.string(); -const RequestMetaSchema = z - .object({ - /** - * If specified, the caller is requesting out-of-band progress notifications for this request (as represented by notifications/progress). The value of this parameter is an opaque token that will be attached to any subsequent notifications. The receiver is not obligated to provide these notifications. - */ - progressToken: ProgressTokenSchema.optional() - }) +const RequestMetaSchema = z.looseObject({ /** - * Passthrough required here because we want to allow any additional fields to be added to the request meta. + * If specified, the caller is requesting out-of-band progress notifications for this request (as represented by notifications/progress). The value of this parameter is an opaque token that will be attached to any subsequent notifications. The receiver is not obligated to provide these notifications. */ - .passthrough(); + progressToken: ProgressTokenSchema.optional() +}); /** * Common params for any request. */ -const BaseRequestParamsSchema = z.object({ +const BaseRequestParamsSchema = z.looseObject({ /** * See [General fields: `_meta`](/specification/draft/basic/index#meta) for notes on `_meta` usage. */ @@ -52,10 +47,10 @@ const BaseRequestParamsSchema = z.object({ export const RequestSchema = z.object({ method: z.string(), - params: BaseRequestParamsSchema.passthrough().optional() + params: BaseRequestParamsSchema.optional() }); -const NotificationsParamsSchema = z.object({ +const NotificationsParamsSchema = z.looseObject({ /** * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) * for notes on _meta usage. @@ -65,21 +60,16 @@ const NotificationsParamsSchema = z.object({ export const NotificationSchema = z.object({ method: z.string(), - params: NotificationsParamsSchema.passthrough().optional() + params: NotificationsParamsSchema.optional() }); -export const ResultSchema = z - .object({ - /** - * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) - * for notes on _meta usage. - */ - _meta: z.record(z.string(), z.unknown()).optional() - }) +export const ResultSchema = z.looseObject({ /** - * Passthrough required here because we want to allow any additional fields to be added to the result. + * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + * for notes on _meta usage. */ - .passthrough(); + _meta: z.record(z.string(), z.unknown()).optional() +}); /** * A uniquely identifying ID for a request in JSON-RPC. @@ -92,9 +82,9 @@ export const RequestIdSchema = z.union([z.string(), z.number().int()]); export const JSONRPCRequestSchema = z .object({ jsonrpc: z.literal(JSONRPC_VERSION), - id: RequestIdSchema + id: RequestIdSchema, + ...RequestSchema.shape }) - .merge(RequestSchema) .strict(); export const isJSONRPCRequest = (value: unknown): value is JSONRPCRequest => JSONRPCRequestSchema.safeParse(value).success; @@ -104,9 +94,9 @@ export const isJSONRPCRequest = (value: unknown): value is JSONRPCRequest => JSO */ export const JSONRPCNotificationSchema = z .object({ - jsonrpc: z.literal(JSONRPC_VERSION) + jsonrpc: z.literal(JSONRPC_VERSION), + ...NotificationSchema.shape }) - .merge(NotificationSchema) .strict(); export const isJSONRPCNotification = (value: unknown): value is JSONRPCNotification => JSONRPCNotificationSchema.safeParse(value).success; @@ -267,12 +257,14 @@ export const BaseMetadataSchema = z.object({ * Describes the name and version of an MCP implementation. */ export const ImplementationSchema = BaseMetadataSchema.extend({ + ...BaseMetadataSchema.shape, + ...IconsSchema.shape, version: z.string(), /** * An optional URL of the website for this implementation. */ websiteUrl: z.string().optional() -}).merge(IconsSchema); +}); const FormElicitationCapabilitySchema = z.intersection( z.object({ @@ -310,7 +302,19 @@ export const ClientCapabilitiesSchema = z.object({ /** * Present if the client supports sampling from an LLM. */ - sampling: AssertObjectSchema.optional(), + sampling: z + .object({ + /** + * Present if the client supports context inclusion via includeContext parameter. + * If not declared, servers SHOULD only use `includeContext: "none"` (or omit it). + */ + context: AssertObjectSchema.optional(), + /** + * Present if the client supports tool use via tools and toolChoice parameters. + */ + tools: AssertObjectSchema.optional() + }) + .optional(), /** * Present if the client supports eliciting user input. */ @@ -454,7 +458,9 @@ export const ProgressSchema = z.object({ message: z.optional(z.string()) }); -export const ProgressNotificationParamsSchema = NotificationsParamsSchema.merge(ProgressSchema).extend({ +export const ProgressNotificationParamsSchema = z.object({ + ...NotificationsParamsSchema.shape, + ...ProgressSchema.shape, /** * The progress token which was given in the initial request, used to associate this notification with the request that is proceeding. */ @@ -547,7 +553,9 @@ export const BlobResourceContentsSchema = ResourceContentsSchema.extend({ /** * A known resource that the server is capable of reading. */ -export const ResourceSchema = BaseMetadataSchema.extend({ +export const ResourceSchema = z.object({ + ...BaseMetadataSchema.shape, + ...IconsSchema.shape, /** * The URI of this resource. */ @@ -569,13 +577,15 @@ export const ResourceSchema = BaseMetadataSchema.extend({ * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) * for notes on _meta usage. */ - _meta: z.optional(z.object({}).passthrough()) -}).merge(IconsSchema); + _meta: z.optional(z.looseObject({})) +}); /** * A template description for resources available on the server. */ -export const ResourceTemplateSchema = BaseMetadataSchema.extend({ +export const ResourceTemplateSchema = z.object({ + ...BaseMetadataSchema.shape, + ...IconsSchema.shape, /** * A URI template (according to RFC 6570) that can be used to construct resource URIs. */ @@ -597,8 +607,8 @@ export const ResourceTemplateSchema = BaseMetadataSchema.extend({ * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) * for notes on _meta usage. */ - _meta: z.optional(z.object({}).passthrough()) -}).merge(IconsSchema); + _meta: z.optional(z.looseObject({})) +}); /** * Sent from the client to request a list of resources the server has. @@ -722,7 +732,9 @@ export const PromptArgumentSchema = z.object({ /** * A prompt or prompt template that the server offers. */ -export const PromptSchema = BaseMetadataSchema.extend({ +export const PromptSchema = z.object({ + ...BaseMetadataSchema.shape, + ...IconsSchema.shape, /** * An optional description of what this prompt provides */ @@ -735,8 +747,8 @@ export const PromptSchema = BaseMetadataSchema.extend({ * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) * for notes on _meta usage. */ - _meta: z.optional(z.object({}).passthrough()) -}).merge(IconsSchema); + _meta: z.optional(z.looseObject({})) +}); /** * Sent from the client to request a list of prompts and prompt templates the server has. @@ -832,6 +844,36 @@ export const AudioContentSchema = z.object({ _meta: z.record(z.string(), z.unknown()).optional() }); +/** + * A tool call request from an assistant (LLM). + * Represents the assistant's request to use a tool. + */ +export const ToolUseContentSchema = z + .object({ + type: z.literal('tool_use'), + /** + * The name of the tool to invoke. + * Must match a tool name from the request's tools array. + */ + name: z.string(), + /** + * Unique identifier for this tool call. + * Used to correlate with ToolResultContent in subsequent messages. + */ + id: z.string(), + /** + * Arguments to pass to the tool. + * Must conform to the tool's inputSchema. + */ + input: z.object({}).passthrough(), + /** + * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + * for notes on _meta usage. + */ + _meta: z.optional(z.object({}).passthrough()) + }) + .passthrough(); + /** * The contents of a resource, embedded into a prompt or tool call result. */ @@ -949,7 +991,9 @@ export const ToolAnnotationsSchema = z.object({ /** * Definition for a tool the client can call. */ -export const ToolSchema = BaseMetadataSchema.extend({ +export const ToolSchema = z.object({ + ...BaseMetadataSchema.shape, + ...IconsSchema.shape, /** * A human-readable description of the tool. */ @@ -971,9 +1015,6 @@ export const ToolSchema = BaseMetadataSchema.extend({ type: z.literal('object'), properties: z.record(z.string(), AssertObjectSchema).optional(), required: z.optional(z.array(z.string())), - /** - * Not in the MCP specification, but added to support the Ajv validator while removing .passthrough() which previously allowed additionalProperties to be passed through. - */ additionalProperties: z.optional(z.boolean()) }) .optional(), @@ -987,7 +1028,7 @@ export const ToolSchema = BaseMetadataSchema.extend({ * for notes on _meta usage. */ _meta: z.record(z.string(), z.unknown()).optional() -}).merge(IconsSchema); +}); /** * Sent from the client to request a list of tools the server has. @@ -1159,13 +1200,65 @@ export const ModelPreferencesSchema = z.object({ }); /** - * Describes a message issued to or received from an LLM API. + * Controls tool usage behavior in sampling requests. */ -export const SamplingMessageSchema = z.object({ - role: z.enum(['user', 'assistant']), - content: z.union([TextContentSchema, ImageContentSchema, AudioContentSchema]) +export const ToolChoiceSchema = z.object({ + /** + * Controls when tools are used: + * - "auto": Model decides whether to use tools (default) + * - "required": Model MUST use at least one tool before completing + * - "none": Model MUST NOT use any tools + */ + mode: z.optional(z.enum(['auto', 'required', 'none'])) }); +/** + * The result of a tool execution, provided by the user (server). + * Represents the outcome of invoking a tool requested via ToolUseContent. + */ +export const ToolResultContentSchema = z + .object({ + type: z.literal('tool_result'), + toolUseId: z.string().describe('The unique identifier for the corresponding tool call.'), + content: z.array(ContentBlockSchema).default([]), + structuredContent: z.object({}).passthrough().optional(), + isError: z.optional(z.boolean()), + + /** + * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + * for notes on _meta usage. + */ + _meta: z.optional(z.object({}).passthrough()) + }) + .passthrough(); + +/** + * Content block types allowed in sampling messages. + * This includes text, image, audio, tool use requests, and tool results. + */ +export const SamplingMessageContentBlockSchema = z.discriminatedUnion('type', [ + TextContentSchema, + ImageContentSchema, + AudioContentSchema, + ToolUseContentSchema, + ToolResultContentSchema +]); + +/** + * Describes a message issued to or received from an LLM API. + */ +export const SamplingMessageSchema = z + .object({ + role: z.enum(['user', 'assistant']), + content: z.union([SamplingMessageContentBlockSchema, z.array(SamplingMessageContentBlockSchema)]), + /** + * See [MCP specification](https://github.com/modelcontextprotocol/modelcontextprotocol/blob/47339c03c143bb4ec01a26e721a1b8fe66634ebe/docs/specification/draft/basic/index.mdx#general-fields) + * for notes on _meta usage. + */ + _meta: z.optional(z.object({}).passthrough()) + }) + .passthrough(); + /** * Parameters for a `sampling/createMessage` request. */ @@ -1180,7 +1273,11 @@ export const CreateMessageRequestParamsSchema = BaseRequestParamsSchema.extend({ */ systemPrompt: z.string().optional(), /** - * A request to include context from one or more MCP servers (including the caller), to be attached to the prompt. The client MAY ignore this request. + * A request to include context from one or more MCP servers (including the caller), to be attached to the prompt. + * The client MAY ignore this request. + * + * Default is "none". Values "thisServer" and "allServers" are soft-deprecated. Servers SHOULD only use these values if the client + * declares ClientCapabilities.sampling.context. These values may be removed in future spec releases. */ includeContext: z.enum(['none', 'thisServer', 'allServers']).optional(), temperature: z.number().optional(), @@ -1194,7 +1291,18 @@ export const CreateMessageRequestParamsSchema = BaseRequestParamsSchema.extend({ /** * Optional metadata to pass through to the LLM provider. The format of this metadata is provider-specific. */ - metadata: AssertObjectSchema.optional() + metadata: AssertObjectSchema.optional(), + /** + * Tools that the model may use during generation. + * The client MUST return an error if this field is provided but ClientCapabilities.sampling.tools is not declared. + */ + tools: z.optional(z.array(ToolSchema)), + /** + * Controls how the model uses tools. + * The client MUST return an error if this field is provided but ClientCapabilities.sampling.tools is not declared. + * Default is `{ mode: "auto" }`. + */ + toolChoice: z.optional(ToolChoiceSchema) }); /** * A request from the server to sample an LLM via the client. The client has full discretion over which model to select. The client should also inform the user before beginning sampling, to allow them to inspect the request (human in the loop) and decide whether to approve it. @@ -1213,11 +1321,22 @@ export const CreateMessageResultSchema = ResultSchema.extend({ */ model: z.string(), /** - * The reason why sampling stopped. + * The reason why sampling stopped, if known. + * + * Standard values: + * - "endTurn": Natural end of the assistant's turn + * - "stopSequence": A stop sequence was encountered + * - "maxTokens": Maximum token limit was reached + * - "toolUse": The model wants to use one or more tools + * + * This field is an open string to allow for provider-specific stop reasons. */ - stopReason: z.optional(z.enum(['endTurn', 'stopSequence', 'maxTokens']).or(z.string())), + stopReason: z.optional(z.enum(['endTurn', 'stopSequence', 'maxTokens', 'toolUse']).or(z.string())), role: z.enum(['user', 'assistant']), - content: z.discriminatedUnion('type', [TextContentSchema, ImageContentSchema, AudioContentSchema]) + /** + * Response content. May be ToolUseContent if stopReason is "toolUse". + */ + content: z.union([SamplingMessageContentBlockSchema, z.array(SamplingMessageContentBlockSchema)]) }); /* Elicitation */ @@ -1448,7 +1567,7 @@ export const ElicitResultSchema = ResultSchema.extend({ * The submitted form data, only present when action is "accept". * Contains values matching the requested schema. */ - content: z.record(z.union([z.string(), z.number(), z.boolean(), z.array(z.string())])).optional() + content: z.record(z.string(), z.union([z.string(), z.number(), z.boolean(), z.array(z.string())])).optional() }); /* Autocomplete */ @@ -1518,34 +1637,34 @@ export function assertCompleteRequestPrompt(request: CompleteRequest): asserts r if (request.params.ref.type !== 'ref/prompt') { throw new TypeError(`Expected CompleteRequestPrompt, but got ${request.params.ref.type}`); } + void (request as CompleteRequestPrompt); } export function assertCompleteRequestResourceTemplate(request: CompleteRequest): asserts request is CompleteRequestResourceTemplate { if (request.params.ref.type !== 'ref/resource') { throw new TypeError(`Expected CompleteRequestResourceTemplate, but got ${request.params.ref.type}`); } + void (request as CompleteRequestResourceTemplate); } /** * The server's response to a completion/complete request */ export const CompleteResultSchema = ResultSchema.extend({ - completion: z - .object({ - /** - * An array of completion values. Must not exceed 100 items. - */ - values: z.array(z.string()).max(100), - /** - * The total number of completion options available. This can exceed the number of values actually sent in the response. - */ - total: z.optional(z.number().int()), - /** - * Indicates whether there are additional completion options beyond those provided in the current response, even if the exact total is unknown. - */ - hasMore: z.optional(z.boolean()) - }) - .passthrough() + completion: z.looseObject({ + /** + * An array of completion values. Must not exceed 100 items. + */ + values: z.array(z.string()).max(100), + /** + * The total number of completion options available. This can exceed the number of values actually sent in the response. + */ + total: z.optional(z.number().int()), + /** + * Indicates whether there are additional completion options beyond those provided in the current response, even if the exact total is unknown. + */ + hasMore: z.optional(z.boolean()) + }) }); /* Roots */ @@ -1699,7 +1818,7 @@ type Flatten = T extends Primitive ? { [K in keyof T]: Flatten } : T; -type Infer = Flatten>; +type Infer = Flatten>; /** * Headers that are compatible with both Node.js and the browser. @@ -1813,6 +1932,8 @@ export type GetPromptRequest = Infer; export type TextContent = Infer; export type ImageContent = Infer; export type AudioContent = Infer; +export type ToolUseContent = Infer; +export type ToolResultContent = Infer; export type EmbeddedResource = Infer; export type ResourceLink = Infer; export type ContentBlock = Infer; @@ -1839,8 +1960,10 @@ export type LoggingMessageNotificationParams = Infer; /* Sampling */ +export type ToolChoice = Infer; export type ModelHint = Infer; export type ModelPreferences = Infer; +export type SamplingMessageContentBlock = Infer; export type SamplingMessage = Infer; export type CreateMessageRequestParams = Infer; export type CreateMessageRequest = Infer; @@ -1879,11 +2002,9 @@ export type PromptReference = Infer; export type CompleteRequestParams = Infer; export type CompleteRequest = Infer; export type CompleteRequestResourceTemplate = ExpandRecursively< - Omit & { params: Omit & { ref: ResourceTemplateReference } } ->; -export type CompleteRequestPrompt = ExpandRecursively< - Omit & { params: Omit & { ref: PromptReference } } + CompleteRequest & { params: CompleteRequestParams & { ref: ResourceTemplateReference } } >; +export type CompleteRequestPrompt = ExpandRecursively; export type CompleteResult = Infer; /* Roots */