diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/streamablehttp/StreamableHttpSseReconnectTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/streamablehttp/StreamableHttpSseReconnectTest.kt new file mode 100644 index 00000000..751f8dd5 --- /dev/null +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/streamablehttp/StreamableHttpSseReconnectTest.kt @@ -0,0 +1,171 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.streamablehttp + +import io.kotest.matchers.nulls.shouldNotBeNull +import io.kotest.matchers.shouldBe +import io.ktor.client.HttpClient +import io.ktor.client.plugins.sse.SSE +import io.ktor.client.request.header +import io.ktor.client.request.post +import io.ktor.client.request.prepareGet +import io.ktor.client.request.setBody +import io.ktor.client.statement.bodyAsChannel +import io.ktor.http.ContentType +import io.ktor.http.HttpHeaders +import io.ktor.http.HttpStatusCode +import io.ktor.http.contentType +import io.ktor.utils.io.readUTF8Line +import io.modelcontextprotocol.kotlin.sdk.types.ClientCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import io.modelcontextprotocol.kotlin.sdk.types.InitializeRequest +import io.modelcontextprotocol.kotlin.sdk.types.InitializeRequestParams +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.types.LATEST_PROTOCOL_VERSION +import io.modelcontextprotocol.kotlin.sdk.types.toJSON +import io.modelcontextprotocol.kotlin.test.utils.actualPort +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.runBlocking +import kotlinx.serialization.json.Json +import kotlin.test.Test +import kotlin.test.assertNotNull +import io.ktor.client.engine.cio.CIO as ClientCIO + +private const val SESSION_ID_HEADER = "mcp-session-id" +private const val PROTOCOL_VERSION_HEADER = "mcp-protocol-version" + +/** + * Integration tests for GET SSE stream reconnection using a real embedded CIO server. + * + * Verifies that the transport correctly evicts stale STANDALONE_SSE_STREAM_ID + * entries when a client reconnects after a disconnect, rather than silently + * rejecting the new stream. + */ +class StreamableHttpSseReconnectTest : AbstractStreamableHttpIntegrationTest() { + + /** + * Verifies that after a GET SSE stream disconnects and the client + * immediately reconnects, the server evicts the stale stream mapping + * and allows the new stream to succeed. + */ + @Test + fun `GET SSE reconnect after disconnect should succeed`(): Unit = runBlocking(Dispatchers.IO) { + var server: StreamableHttpTestServer? = null + var httpClient: HttpClient? = null + + try { + server = initTestServer("reconnect-test") + val port = server.ktorServer.actualPort() + val mcpUrl = "http://$URL:$port/mcp" + + httpClient = HttpClient(ClientCIO) { install(SSE) } + + // Step 1: Initialize session via POST + val initResponse = httpClient.post(mcpUrl) { + contentType(ContentType.Application.Json) + header( + HttpHeaders.Accept, + "${ContentType.Application.Json}, ${ContentType.Text.EventStream}", + ) + setBody(Json.encodeToString(buildInitPayload())) + } + initResponse.status shouldBe HttpStatusCode.OK + val sessionId = assertNotNull(initResponse.headers[SESSION_ID_HEADER]) + + // Step 2: Open GET SSE stream, consume the flush event, then disconnect + httpClient.prepareGet(mcpUrl) { + header(HttpHeaders.Accept, ContentType.Text.EventStream.toString()) + header(SESSION_ID_HEADER, sessionId) + header(PROTOCOL_VERSION_HEADER, LATEST_PROTOCOL_VERSION) + }.execute { response -> + response.status shouldBe HttpStatusCode.OK + response.bodyAsChannel().readUTF8Line() + } + + // Step 3: Immediately reconnect. The transport detects that the + // previous stream's coroutine is no longer active and evicts the + // stale mapping, allowing the new stream to succeed. + httpClient.prepareGet(mcpUrl) { + header(HttpHeaders.Accept, ContentType.Text.EventStream.toString()) + header(SESSION_ID_HEADER, sessionId) + header(PROTOCOL_VERSION_HEADER, LATEST_PROTOCOL_VERSION) + }.execute { response -> + response.status shouldBe HttpStatusCode.OK + response.headers[SESSION_ID_HEADER] shouldBe sessionId + + val channel = response.bodyAsChannel() + val firstLine = channel.readUTF8Line() + firstLine.shouldNotBeNull() + channel.isClosedForRead shouldBe false + } + } finally { + httpClient?.close() + server?.ktorServer?.stopSuspend(1000, 2000) + } + } + + /** + * Verifies that a second concurrent GET SSE stream on the same session + * closes the old stream and takes over. The new stream should be live. + */ + @Test + fun `concurrent GET SSE stream closes old stream and takes over`(): Unit = runBlocking(Dispatchers.IO) { + var server: StreamableHttpTestServer? = null + var httpClient: HttpClient? = null + + try { + server = initTestServer("takeover-test") + val port = server.ktorServer.actualPort() + val mcpUrl = "http://$URL:$port/mcp" + + httpClient = HttpClient(ClientCIO) { install(SSE) } + + // Step 1: Initialize session via POST + val initResponse = httpClient.post(mcpUrl) { + contentType(ContentType.Application.Json) + header( + HttpHeaders.Accept, + "${ContentType.Application.Json}, ${ContentType.Text.EventStream}", + ) + setBody(Json.encodeToString(buildInitPayload())) + } + initResponse.status shouldBe HttpStatusCode.OK + val sessionId = assertNotNull(initResponse.headers[SESSION_ID_HEADER]) + + // Step 2: Open first GET SSE stream and keep it open + httpClient.prepareGet(mcpUrl) { + header(HttpHeaders.Accept, ContentType.Text.EventStream.toString()) + header(SESSION_ID_HEADER, sessionId) + header(PROTOCOL_VERSION_HEADER, LATEST_PROTOCOL_VERSION) + }.execute { firstResponse -> + firstResponse.status shouldBe HttpStatusCode.OK + firstResponse.bodyAsChannel().readUTF8Line() + + // Step 3: Open a second GET — closes old stream, new one takes over + httpClient.prepareGet(mcpUrl) { + header(HttpHeaders.Accept, ContentType.Text.EventStream.toString()) + header(SESSION_ID_HEADER, sessionId) + header(PROTOCOL_VERSION_HEADER, LATEST_PROTOCOL_VERSION) + }.execute { secondResponse -> + secondResponse.status shouldBe HttpStatusCode.OK + secondResponse.headers[SESSION_ID_HEADER] shouldBe sessionId + + // New stream is alive + val secondChannel = secondResponse.bodyAsChannel() + val firstLine = secondChannel.readUTF8Line() + firstLine.shouldNotBeNull() + secondChannel.isClosedForRead shouldBe false + } + } + } finally { + httpClient?.close() + server?.ktorServer?.stopSuspend(1000, 2000) + } + } + + private fun buildInitPayload(): JSONRPCRequest = InitializeRequest( + InitializeRequestParams( + protocolVersion = LATEST_PROTOCOL_VERSION, + capabilities = ClientCapabilities(), + clientInfo = Implementation(name = "reconnect-test-client", version = "1.0.0"), + ), + ).toJSON() +} diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt index 497005a4..99174271 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt @@ -464,27 +464,43 @@ public class StreamableHttpServerTransport(private val configuration: Configurat } } - if (STANDALONE_SSE_STREAM_ID in streamsMapping) { - call.reject( - HttpStatusCode.Conflict, - RPCError.ErrorCode.CONNECTION_CLOSED, - "Conflict: Only one SSE stream is allowed per session", - ) - return + streamsMapping[STANDALONE_SSE_STREAM_ID]?.let { existingContext -> + // Close the previous SSE session. If the stream is already dead this + // is a no-op. If it is still alive, closing it cancels the coroutine + // blocked in awaitCancellation(), which triggers the identity-guarded + // finally block to remove the mapping. + try { + existingContext.session?.close() + } catch (_: CancellationException) { + throw CancellationException("Cancelled while closing previous SSE stream") + } catch (_: Exception) { + // Ignore — the old stream may already be closed. + } + // After closing, give the old coroutine's finally block a chance to + // remove the mapping. If the entry is still present (race edge case), + // evict it — the old session is closed either way. + streamsMapping.remove(STANDALONE_SSE_STREAM_ID) } // SSE headers (Content-Type, Cache-Control, Connection) are already set by the framework's SSE handler flushSse(sseSession) - streamsMapping[STANDALONE_SSE_STREAM_ID] = SessionContext(sseSession, call) + val newContext = SessionContext(sseSession, call) + streamsMapping[STANDALONE_SSE_STREAM_ID] = newContext val clientProtocolVersion = call.request.header(MCP_PROTOCOL_VERSION_HEADER) ?: DEFAULT_NEGOTIATED_PROTOCOL_VERSION maybeSendPrimingEvent(STANDALONE_SSE_STREAM_ID, sseSession, clientProtocolVersion) - sseSession.coroutineContext.job.invokeOnCompletion { - streamsMapping.remove(STANDALONE_SSE_STREAM_ID) - } // Keep the SSE connection open until the client disconnects or the transport is closed. - // Without this, the Ktor sse{} handler returns immediately, closing the stream. - awaitCancellation() + // Cleanup uses try/finally (runs during cancellation propagation) instead of + // invokeOnCompletion (runs after job completion) to minimize the window between + // disconnect and mapping removal. Identity check ensures only this stream's entry + // is removed — not a replacement that arrived in the meantime. + try { + awaitCancellation() + } finally { + if (streamsMapping[STANDALONE_SSE_STREAM_ID] === newContext) { + streamsMapping.remove(STANDALONE_SSE_STREAM_ID) + } + } } /** Handles an HTTP DELETE request by closing the session and the transport. */ @@ -725,10 +741,17 @@ public class StreamableHttpServerTransport(private val configuration: Configurat try { session?.send(event = "message", id = eventId, data = McpJson.encodeToString(message)) } catch (e: CancellationException) { - streamsMapping.remove(streamId) + // Identity-based removal: only evict this stream's entry, not a replacement's. + val current = streamsMapping[streamId] + if (current?.session === session) { + streamsMapping.remove(streamId) + } throw e } catch (_: Exception) { - streamsMapping.remove(streamId) + val current = streamsMapping[streamId] + if (current?.session === session) { + streamsMapping.remove(streamId) + } } } diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt index d00e10ec..8b413d8a 100644 --- a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt @@ -460,6 +460,107 @@ class StreamableHttpServerTransportTest { } } + @Test + fun `second concurrent GET SSE closes old stream and takes over`() = testApplication { + val mcpPath = "/mcp" + + application { + mcpStreamableHttp(mcpPath) { + Server( + Implementation("test-server", "1.0.0"), + ServerOptions(capabilities = ServerCapabilities()), + ) + } + } + + val client = createTestClient() + + // Step 1: Initialize session via POST + val initResponse = client.post(mcpPath) { + addStreamableHeaders() + setBody(buildInitializeRequestPayload()) + } + initResponse.status shouldBe HttpStatusCode.OK + val sessionId = assertNotNull(initResponse.headers[MCP_SESSION_ID_HEADER]) + + // Step 2: Open first GET SSE stream + client.prepareGet(mcpPath) { + header(HttpHeaders.Accept, ContentType.Text.EventStream.toString()) + header(MCP_SESSION_ID_HEADER, sessionId) + header("mcp-protocol-version", LATEST_PROTOCOL_VERSION) + }.execute { firstResponse -> + firstResponse.status shouldBe HttpStatusCode.OK + firstResponse.bodyAsChannel().readUTF8Line() + + // Step 3: Open a second GET — the transport closes the old session + // and the new stream takes over. + client.prepareGet(mcpPath) { + header(HttpHeaders.Accept, ContentType.Text.EventStream.toString()) + header(MCP_SESSION_ID_HEADER, sessionId) + header("mcp-protocol-version", LATEST_PROTOCOL_VERSION) + }.execute { secondResponse -> + secondResponse.status shouldBe HttpStatusCode.OK + secondResponse.headers[MCP_SESSION_ID_HEADER] shouldBe sessionId + + // New stream is alive + val secondChannel = secondResponse.bodyAsChannel() + val firstLine = secondChannel.readUTF8Line() + firstLine.shouldNotBeNull() + secondChannel.isClosedForRead shouldBe false + } + } + } + + @Test + fun `GET SSE reconnect after previous stream disconnects should succeed`() = testApplication { + val mcpPath = "/mcp" + + application { + mcpStreamableHttp(mcpPath) { + Server( + Implementation("test-server", "1.0.0"), + ServerOptions(capabilities = ServerCapabilities()), + ) + } + } + + val client = createTestClient() + + // Step 1: Initialize session via POST + val initResponse = client.post(mcpPath) { + addStreamableHeaders() + setBody(buildInitializeRequestPayload()) + } + initResponse.status shouldBe HttpStatusCode.OK + val sessionId = assertNotNull(initResponse.headers[MCP_SESSION_ID_HEADER]) + + // Step 2: Open and then close a GET SSE stream + client.prepareGet(mcpPath) { + header(HttpHeaders.Accept, ContentType.Text.EventStream.toString()) + header(MCP_SESSION_ID_HEADER, sessionId) + header("mcp-protocol-version", LATEST_PROTOCOL_VERSION) + }.execute { response -> + response.status shouldBe HttpStatusCode.OK + response.bodyAsChannel().readUTF8Line() + } + + // Step 3: Immediately reconnect — the transport should close the stale + // stream and allow the new one. + client.prepareGet(mcpPath) { + header(HttpHeaders.Accept, ContentType.Text.EventStream.toString()) + header(MCP_SESSION_ID_HEADER, sessionId) + header("mcp-protocol-version", LATEST_PROTOCOL_VERSION) + }.execute { response -> + response.status shouldBe HttpStatusCode.OK + response.headers[MCP_SESSION_ID_HEADER] shouldBe sessionId + + val channel = response.bodyAsChannel() + val firstLine = channel.readUTF8Line() + firstLine.shouldNotBeNull() + channel.isClosedForRead shouldBe false + } + } + @Test fun `GET SSE stream includes Mcp-Session-Id header and stays open`() = testApplication { val mcpPath = "/mcp"