Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 206 additions & 0 deletions src/client/streamableHttp.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,70 @@ describe('StreamableHTTPClientTransport', () => {
expect(fetchMock).toHaveBeenCalledTimes(1);
expect(fetchMock.mock.calls[0][1]?.method).toBe('POST');
});

it('should reconnect a POST-initiated stream after receiving a priming event', async () => {
// ARRANGE
transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), {
reconnectionOptions: {
initialReconnectionDelay: 10,
maxRetries: 1,
maxReconnectionDelay: 1000,
reconnectionDelayGrowFactor: 1
}
});

const errorSpy = vi.fn();
transport.onerror = errorSpy;

// Create a stream that sends a priming event (with ID) then closes
const streamWithPrimingEvent = new ReadableStream({
start(controller) {
// Send a priming event with an ID - this enables reconnection
controller.enqueue(
new TextEncoder().encode('id: event-123\ndata: {"jsonrpc":"2.0","method":"notifications/message","params":{}}\n\n')
);
// Then close the stream (simulating server disconnect)
controller.close();
}
});

const fetchMock = global.fetch as Mock;
// First call: POST returns streaming response with priming event
fetchMock.mockResolvedValueOnce({
ok: true,
status: 200,
headers: new Headers({ 'content-type': 'text/event-stream' }),
body: streamWithPrimingEvent
});
// Second call: GET reconnection - return 405 to stop further reconnection
fetchMock.mockResolvedValueOnce({
ok: false,
status: 405,
headers: new Headers()
});

const requestMessage: JSONRPCRequest = {
jsonrpc: '2.0',
method: 'long_running_tool',
id: 'request-1',
params: {}
};

// ACT
await transport.start();
await transport.send(requestMessage);
// Wait for stream to process and reconnection to be scheduled
await vi.advanceTimersByTimeAsync(50);

// ASSERT
// THE KEY ASSERTION: Fetch was called TWICE - POST then GET reconnection
expect(fetchMock).toHaveBeenCalledTimes(2);
expect(fetchMock.mock.calls[0][1]?.method).toBe('POST');
expect(fetchMock.mock.calls[1][1]?.method).toBe('GET');
// Verify Last-Event-ID header was sent for reconnection
const reconnectHeaders = fetchMock.mock.calls[1][1]?.headers as Headers;
expect(reconnectHeaders.get('last-event-id')).toBe('event-123');
});
});

it('invalidates all credentials on InvalidClientError during auth', async () => {
Expand Down Expand Up @@ -1102,6 +1166,148 @@ describe('StreamableHTTPClientTransport', () => {
});
});

describe('SSE retry field handling', () => {
beforeEach(() => {
vi.useFakeTimers();
(global.fetch as Mock).mockReset();
});
afterEach(() => vi.useRealTimers());

it('should use server-provided retry value for reconnection delay', async () => {
transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), {
reconnectionOptions: {
initialReconnectionDelay: 100,
maxReconnectionDelay: 5000,
reconnectionDelayGrowFactor: 2,
maxRetries: 3
}
});

// Create a stream that sends a retry field
const encoder = new TextEncoder();
const stream = new ReadableStream({
start(controller) {
// Send SSE event with retry field
const event =
'retry: 3000\nevent: message\nid: evt-1\ndata: {"jsonrpc": "2.0", "method": "notification", "params": {}}\n\n';
controller.enqueue(encoder.encode(event));
// Close stream to trigger reconnection
controller.close();
}
});

const fetchMock = global.fetch as Mock;
fetchMock.mockResolvedValueOnce({
ok: true,
status: 200,
headers: new Headers({ 'content-type': 'text/event-stream' }),
body: stream
});

// Second request for reconnection
fetchMock.mockResolvedValueOnce({
ok: true,
status: 200,
headers: new Headers({ 'content-type': 'text/event-stream' }),
body: new ReadableStream()
});

await transport.start();
await transport['_startOrAuthSse']({});

// Wait for stream to close and reconnection to be scheduled
await vi.advanceTimersByTimeAsync(100);

// Verify the server retry value was captured
const transportInternal = transport as unknown as { _serverRetryMs?: number };
expect(transportInternal._serverRetryMs).toBe(3000);

// Verify the delay calculation uses server retry value
const getDelay = transport['_getNextReconnectionDelay'].bind(transport);
expect(getDelay(0)).toBe(3000); // Should use server value, not 100ms initial
expect(getDelay(5)).toBe(3000); // Should still use server value for any attempt
});

it('should fall back to exponential backoff when no server retry value', () => {
transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), {
reconnectionOptions: {
initialReconnectionDelay: 100,
maxReconnectionDelay: 5000,
reconnectionDelayGrowFactor: 2,
maxRetries: 3
}
});

// Without any SSE stream, _serverRetryMs should be undefined
const transportInternal = transport as unknown as { _serverRetryMs?: number };
expect(transportInternal._serverRetryMs).toBeUndefined();

// Should use exponential backoff
const getDelay = transport['_getNextReconnectionDelay'].bind(transport);
expect(getDelay(0)).toBe(100); // 100 * 2^0
expect(getDelay(1)).toBe(200); // 100 * 2^1
expect(getDelay(2)).toBe(400); // 100 * 2^2
expect(getDelay(10)).toBe(5000); // capped at max
});

it('should reconnect on graceful stream close', async () => {
transport = new StreamableHTTPClientTransport(new URL('http://localhost:1234/mcp'), {
reconnectionOptions: {
initialReconnectionDelay: 10,
maxReconnectionDelay: 1000,
reconnectionDelayGrowFactor: 1,
maxRetries: 1
}
});

// Create a stream that closes gracefully after sending an event with ID
const encoder = new TextEncoder();
const stream = new ReadableStream({
start(controller) {
// Send priming event with ID and retry field
const event = 'id: evt-1\nretry: 100\ndata: \n\n';
controller.enqueue(encoder.encode(event));
// Graceful close
controller.close();
}
});

const fetchMock = global.fetch as Mock;
fetchMock.mockResolvedValueOnce({
ok: true,
status: 200,
headers: new Headers({ 'content-type': 'text/event-stream' }),
body: stream
});

// Second request for reconnection
fetchMock.mockResolvedValueOnce({
ok: true,
status: 200,
headers: new Headers({ 'content-type': 'text/event-stream' }),
body: new ReadableStream()
});

await transport.start();
await transport['_startOrAuthSse']({});

// Wait for stream to process and close
await vi.advanceTimersByTimeAsync(50);

// Wait for reconnection delay (100ms from retry field)
await vi.advanceTimersByTimeAsync(150);

// Should have attempted reconnection
expect(fetchMock).toHaveBeenCalledTimes(2);
expect(fetchMock.mock.calls[0][1]?.method).toBe('GET');
expect(fetchMock.mock.calls[1][1]?.method).toBe('GET');

// Second call should include Last-Event-ID
const secondCallHeaders = fetchMock.mock.calls[1][1]?.headers;
expect(secondCallHeaders?.get('last-event-id')).toBe('evt-1');
});
});

describe('prevent infinite recursion when server returns 401 after successful auth', () => {
it('should throw error when server returns 401 after successful auth', async () => {
const message: JSONRPCMessage = {
Expand Down
58 changes: 54 additions & 4 deletions src/client/streamableHttp.ts
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ export class StreamableHTTPClientTransport implements Transport {
private _protocolVersion?: string;
private _hasCompletedAuthFlow = false; // Circuit breaker: detect auth success followed by immediate 401
private _lastUpscopingHeader?: string; // Track last upscoping header to prevent infinite upscoping.
private _serverRetryMs?: number; // Server-provided retry delay from SSE retry field

onclose?: () => void;
onerror?: (error: Error) => void;
Expand Down Expand Up @@ -203,6 +204,7 @@ export class StreamableHTTPClientTransport implements Transport {

private async _startOrAuthSse(options: StartSSEOptions): Promise<void> {
const { resumptionToken } = options;

try {
// Try to open an initial SSE stream with GET to listen for server messages
// This is optional according to the spec - server may not support it
Expand Down Expand Up @@ -249,7 +251,12 @@ export class StreamableHTTPClientTransport implements Transport {
* @returns Time to wait in milliseconds before next reconnection attempt
*/
private _getNextReconnectionDelay(attempt: number): number {
// Access default values directly, ensuring they're never undefined
// Use server-provided retry value if available
if (this._serverRetryMs !== undefined) {
return this._serverRetryMs;
}

// Fall back to exponential backoff
const initialDelay = this._reconnectionOptions.initialReconnectionDelay;
const growFactor = this._reconnectionOptions.reconnectionDelayGrowFactor;
const maxDelay = this._reconnectionOptions.maxReconnectionDelay;
Expand All @@ -259,7 +266,7 @@ export class StreamableHTTPClientTransport implements Transport {
}

/**
* Schedule a reconnection attempt with exponential backoff
* Schedule a reconnection attempt using server-provided retry interval or backoff
*
* @param lastEventId The ID of the last received event for resumability
* @param attemptCount Current reconnection attempt count for this specific stream
Expand Down Expand Up @@ -295,14 +302,24 @@ export class StreamableHTTPClientTransport implements Transport {
const { onresumptiontoken, replayMessageId } = options;

let lastEventId: string | undefined;
// Track whether we've received a priming event (event with ID)
// Per spec, server SHOULD send a priming event with ID before closing
let hasPrimingEvent = false;
const processStream = async () => {
// this is the closest we can get to trying to catch network errors
// if something happens reader will throw
try {
// Create a pipeline: binary stream -> text decoder -> SSE parser
const reader = stream
.pipeThrough(new TextDecoderStream() as ReadableWritablePair<string, Uint8Array>)
.pipeThrough(new EventSourceParserStream())
.pipeThrough(
new EventSourceParserStream({
onRetry: (retryMs: number) => {
// Capture server-provided retry value for reconnection timing
this._serverRetryMs = retryMs;
}
})
)
.getReader();

while (true) {
Expand All @@ -314,6 +331,8 @@ export class StreamableHTTPClientTransport implements Transport {
// Update last event ID if provided
if (event.id) {
lastEventId = event.id;
// Mark that we've received a priming event - stream is now resumable
hasPrimingEvent = true;
onresumptiontoken?.(event.id);
}

Expand All @@ -329,12 +348,29 @@ export class StreamableHTTPClientTransport implements Transport {
}
}
}

// Handle graceful server-side disconnect
// Server may close connection after sending event ID and retry field
// Reconnect if: already reconnectable (GET stream) OR received a priming event (POST stream with event ID)
const canResume = isReconnectable || hasPrimingEvent;
if (canResume && this._abortController && !this._abortController.signal.aborted) {
this._scheduleReconnection(
{
resumptionToken: lastEventId,
onresumptiontoken,
replayMessageId
},
0
);
}
} catch (error) {
// Handle stream errors - likely a network disconnect
this.onerror?.(new Error(`SSE stream disconnected: ${error}`));

// Attempt to reconnect if the stream disconnects unexpectedly and we aren't closing
if (isReconnectable && this._abortController && !this._abortController.signal.aborted) {
// Reconnect if: already reconnectable (GET stream) OR received a priming event (POST stream with event ID)
const canResume = isReconnectable || hasPrimingEvent;
if (canResume && this._abortController && !this._abortController.signal.aborted) {
// Use the exponential backoff reconnection strategy
try {
this._scheduleReconnection(
Expand Down Expand Up @@ -593,4 +629,18 @@ export class StreamableHTTPClientTransport implements Transport {
get protocolVersion(): string | undefined {
return this._protocolVersion;
}

/**
* Resume an SSE stream from a previous event ID.
* Opens a GET SSE connection with Last-Event-ID header to replay missed events.
*
* @param lastEventId The event ID to resume from
* @param options Optional callback to receive new resumption tokens
*/
async resumeStream(lastEventId: string, options?: { onresumptiontoken?: (token: string) => void }): Promise<void> {
await this._startOrAuthSse({
resumptionToken: lastEventId,
onresumptiontoken: options?.onresumptiontoken
});
}
}
Loading
Loading