|
1 | | -import type { JSONRPCMessage } from '@modelcontextprotocol/core'; |
| 1 | +import type { IncomingMessage, Server } from 'node:http'; |
| 2 | +import { createServer } from 'node:http'; |
| 3 | + |
| 4 | +import type { JSONRPCMessage, OAuthClientInformation, OAuthClientMetadata, OAuthTokens } from '@modelcontextprotocol/core'; |
2 | 5 | import { SdkError, SdkErrorCode } from '@modelcontextprotocol/core'; |
| 6 | +import { listenOnRandomPort } from '@modelcontextprotocol/test-helpers'; |
3 | 7 | import type { Mock } from 'vitest'; |
4 | 8 |
|
5 | | -import type { AuthProvider } from '../../src/client/auth.js'; |
| 9 | +import type { AuthProvider, OAuthClientProvider } from '../../src/client/auth.js'; |
6 | 10 | import { UnauthorizedError } from '../../src/client/auth.js'; |
7 | 11 | import { StreamableHTTPClientTransport } from '../../src/client/streamableHttp.js'; |
8 | 12 |
|
@@ -180,3 +184,132 @@ describe('StreamableHTTPClientTransport with AuthProvider', () => { |
180 | 184 | expect(retryInit.headers.get('Authorization')).toBe('Bearer new-token'); |
181 | 185 | }); |
182 | 186 | }); |
| 187 | + |
| 188 | +describe('AuthProvider integration — both modes against a real server', () => { |
| 189 | + let server: Server; |
| 190 | + let serverUrl: URL; |
| 191 | + let capturedRequests: IncomingMessage[]; |
| 192 | + let transport: StreamableHTTPClientTransport; |
| 193 | + |
| 194 | + const message: JSONRPCMessage = { jsonrpc: '2.0', method: 'ping', params: {}, id: '1' }; |
| 195 | + |
| 196 | + beforeEach(async () => { |
| 197 | + capturedRequests = []; |
| 198 | + server = createServer((req, res) => { |
| 199 | + capturedRequests.push(req); |
| 200 | + if (req.method === 'POST') { |
| 201 | + // Consume body then respond 202 Accepted |
| 202 | + req.on('data', () => {}); |
| 203 | + req.on('end', () => res.writeHead(202).end()); |
| 204 | + } else { |
| 205 | + // GET SSE — reject so the transport skips it |
| 206 | + res.writeHead(405).end(); |
| 207 | + } |
| 208 | + }); |
| 209 | + serverUrl = await listenOnRandomPort(server); |
| 210 | + }); |
| 211 | + |
| 212 | + afterEach(async () => { |
| 213 | + await transport?.close().catch(() => {}); |
| 214 | + await new Promise<void>(resolve => server.close(() => resolve())); |
| 215 | + }); |
| 216 | + |
| 217 | + it('MODE A: minimal AuthProvider { token } sends Authorization header', async () => { |
| 218 | + const authProvider: AuthProvider = { token: async () => 'mode-a-token' }; |
| 219 | + transport = new StreamableHTTPClientTransport(serverUrl, { authProvider }); |
| 220 | + |
| 221 | + await transport.send(message); |
| 222 | + |
| 223 | + expect(capturedRequests).toHaveLength(1); |
| 224 | + expect(capturedRequests[0]!.headers.authorization).toBe('Bearer mode-a-token'); |
| 225 | + }); |
| 226 | + |
| 227 | + it('MODE A: onUnauthorized signals and throws — caller sees the error', async () => { |
| 228 | + const uiSignal = vi.fn(); |
| 229 | + const authProvider: AuthProvider = { |
| 230 | + token: async () => 'rejected-token', |
| 231 | + onUnauthorized: async () => { |
| 232 | + uiSignal('show-reauth-prompt'); |
| 233 | + throw new UnauthorizedError('user action required'); |
| 234 | + } |
| 235 | + }; |
| 236 | + |
| 237 | + // Server that rejects with 401 |
| 238 | + await new Promise<void>(resolve => server.close(() => resolve())); |
| 239 | + server = createServer((req, res) => { |
| 240 | + capturedRequests.push(req); |
| 241 | + req.on('data', () => {}); |
| 242 | + req.on('end', () => res.writeHead(401).end()); |
| 243 | + }); |
| 244 | + serverUrl = await listenOnRandomPort(server); |
| 245 | + |
| 246 | + transport = new StreamableHTTPClientTransport(serverUrl, { authProvider }); |
| 247 | + |
| 248 | + await expect(transport.send(message)).rejects.toThrow('user action required'); |
| 249 | + expect(uiSignal).toHaveBeenCalledWith('show-reauth-prompt'); |
| 250 | + }); |
| 251 | + |
| 252 | + it('MODE B: OAuthClientProvider is adapted — tokens() becomes token() on the wire', async () => { |
| 253 | + // Minimal OAuthClientProvider — the transport should adapt it via adaptOAuthProvider |
| 254 | + const oauthProvider: OAuthClientProvider = { |
| 255 | + get redirectUrl() { |
| 256 | + return undefined; |
| 257 | + }, |
| 258 | + get clientMetadata(): OAuthClientMetadata { |
| 259 | + return { redirect_uris: [], grant_types: ['client_credentials'] }; |
| 260 | + }, |
| 261 | + clientInformation(): OAuthClientInformation { |
| 262 | + return { client_id: 'test-client' }; |
| 263 | + }, |
| 264 | + tokens(): OAuthTokens { |
| 265 | + return { access_token: 'mode-b-oauth-token', token_type: 'bearer' }; |
| 266 | + }, |
| 267 | + saveTokens() {}, |
| 268 | + redirectToAuthorization() { |
| 269 | + throw new Error('not used'); |
| 270 | + }, |
| 271 | + saveCodeVerifier() {}, |
| 272 | + codeVerifier() { |
| 273 | + throw new Error('not used'); |
| 274 | + } |
| 275 | + }; |
| 276 | + |
| 277 | + transport = new StreamableHTTPClientTransport(serverUrl, { authProvider: oauthProvider }); |
| 278 | + |
| 279 | + await transport.send(message); |
| 280 | + |
| 281 | + expect(capturedRequests).toHaveLength(1); |
| 282 | + expect(capturedRequests[0]!.headers.authorization).toBe('Bearer mode-b-oauth-token'); |
| 283 | + }); |
| 284 | + |
| 285 | + it('both modes use the same option slot and same send() call', async () => { |
| 286 | + // Mode A |
| 287 | + const transportA = new StreamableHTTPClientTransport(serverUrl, { |
| 288 | + authProvider: { token: async () => 'a-token' } |
| 289 | + }); |
| 290 | + await transportA.send(message); |
| 291 | + await transportA.close(); |
| 292 | + |
| 293 | + // Mode B — same constructor, same option name, different shape |
| 294 | + const transportB = new StreamableHTTPClientTransport(serverUrl, { |
| 295 | + authProvider: { |
| 296 | + get redirectUrl() { |
| 297 | + return undefined; |
| 298 | + }, |
| 299 | + get clientMetadata(): OAuthClientMetadata { |
| 300 | + return { redirect_uris: [] }; |
| 301 | + }, |
| 302 | + clientInformation: () => ({ client_id: 'x' }), |
| 303 | + tokens: () => ({ access_token: 'b-token', token_type: 'bearer' }), |
| 304 | + saveTokens() {}, |
| 305 | + redirectToAuthorization() {}, |
| 306 | + saveCodeVerifier() {}, |
| 307 | + codeVerifier: () => '' |
| 308 | + } satisfies OAuthClientProvider |
| 309 | + }); |
| 310 | + await transportB.send(message); |
| 311 | + await transportB.close(); |
| 312 | + |
| 313 | + expect(capturedRequests.map(r => r.headers.authorization)).toEqual(['Bearer a-token', 'Bearer b-token']); |
| 314 | + }); |
| 315 | +}); |
0 commit comments