diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt index 35dcc7ec..806a4791 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt @@ -3,13 +3,17 @@ package io.modelcontextprotocol.kotlin.sdk.server import io.github.oshai.kotlinlogging.KotlinLogging +import io.ktor.http.HttpMethod import io.ktor.http.HttpStatusCode import io.ktor.server.application.Application import io.ktor.server.application.ApplicationCall +import io.ktor.server.application.ApplicationCallPipeline import io.ktor.server.application.MissingApplicationPluginException import io.ktor.server.application.install import io.ktor.server.request.ApplicationRequest import io.ktor.server.request.header +import io.ktor.server.request.httpMethod +import io.ktor.server.response.header import io.ktor.server.response.respond import io.ktor.server.routing.Route import io.ktor.server.routing.RoutingContext @@ -114,6 +118,16 @@ private fun Application.mcpStreamableHttp( routing { route(path) { + // Set Mcp-Session-Id on GET responses before Ktor's sse {} commits headers. + intercept(ApplicationCallPipeline.Plugins) { + if (context.request.httpMethod == HttpMethod.Get) { + val sessionId = context.request.header(MCP_SESSION_ID_HEADER) + if (sessionId != null && transportManager.getTransport(sessionId) != null) { + context.response.header(MCP_SESSION_ID_HEADER, sessionId) + } + } + } + sse { val transport = existingStreamableTransport(call, transportManager) ?: return@sse transport.handleRequest(this, call) diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt index 70948d9c..17e6a282 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt @@ -453,8 +453,8 @@ public class StreamableHttpServerTransport(private val configuration: Configurat return } - call.appendSseHeaders() - flushSse(sseSession) // flush headers immediately + // SSE headers (Content-Type, Cache-Control, Connection) are already set by the framework's SSE handler + flushSse(sseSession) streamsMapping[STANDALONE_SSE_STREAM_ID] = SessionContext(sseSession, call) maybeSendPrimingEvent(STANDALONE_SSE_STREAM_ID, sseSession, call.request.header(MCP_PROTOCOL_VERSION_HEADER)) sseSession.coroutineContext.job.invokeOnCompletion { @@ -529,8 +529,8 @@ public class StreamableHttpServerTransport(private val configuration: Configurat } } - call.appendSseHeaders() - flushSse(session) // flush headers immediately + // SSE headers are already set by the framework's SSE handler. + flushSse(session) val streamId = store.replayEventsAfter(lastEventId) { eventId, message -> try { diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt index 9cf916af..8b4f6fae 100644 --- a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt @@ -2,6 +2,7 @@ package io.modelcontextprotocol.kotlin.sdk.server import io.kotest.matchers.collections.shouldContainAll import io.kotest.matchers.equals.shouldBeEqual +import io.kotest.matchers.nulls.shouldNotBeNull import io.kotest.matchers.shouldBe import io.ktor.client.HttpClient import io.ktor.client.call.body @@ -10,7 +11,9 @@ import io.ktor.client.plugins.logging.Logging import io.ktor.client.request.HttpRequestBuilder import io.ktor.client.request.header import io.ktor.client.request.post +import io.ktor.client.request.prepareGet import io.ktor.client.request.setBody +import io.ktor.client.statement.bodyAsChannel import io.ktor.http.ContentType import io.ktor.http.HttpHeaders import io.ktor.http.HttpStatusCode @@ -21,6 +24,7 @@ import io.ktor.server.routing.post import io.ktor.server.routing.routing import io.ktor.server.testing.ApplicationTestBuilder import io.ktor.server.testing.testApplication +import io.ktor.utils.io.readUTF8Line import io.modelcontextprotocol.kotlin.sdk.types.ClientCapabilities import io.modelcontextprotocol.kotlin.sdk.types.EmptyResult import io.modelcontextprotocol.kotlin.sdk.types.Implementation @@ -36,6 +40,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.ListToolsResult import io.modelcontextprotocol.kotlin.sdk.types.McpJson import io.modelcontextprotocol.kotlin.sdk.types.Method import io.modelcontextprotocol.kotlin.sdk.types.RequestId +import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.types.Tool import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema import io.modelcontextprotocol.kotlin.sdk.types.toJSON @@ -433,6 +438,47 @@ class StreamableHttpServerTransportTest { } } + @Test + fun `GET SSE stream includes Mcp-Session-Id header and stays open`() = testApplication { + val mcpPath = "/mcp" + + application { + mcpStreamableHttp(mcpPath) { + Server( + Implementation("test-server", "1.0.0"), + ServerOptions(capabilities = ServerCapabilities()), + ) + } + } + + val client = createTestClient() + + // Step 1: Initialize session via POST + val initResponse = client.post(mcpPath) { + addStreamableHeaders() + setBody(buildInitializeRequestPayload()) + } + initResponse.status shouldBe HttpStatusCode.OK + val sessionId = assertNotNull(initResponse.headers[MCP_SESSION_ID_HEADER]) + + // Step 2: Open GET SSE stream with session ID + client.prepareGet(mcpPath) { + header(HttpHeaders.Accept, ContentType.Text.EventStream.toString()) + header(MCP_SESSION_ID_HEADER, sessionId) + header("mcp-protocol-version", LATEST_PROTOCOL_VERSION) + }.execute { response -> + // Verify Mcp-Session-Id is present on the SSE response + response.status shouldBe HttpStatusCode.OK + response.headers[MCP_SESSION_ID_HEADER] shouldBe sessionId + + // Verify the stream is alive by reading at least one line (flush event) + val channel = response.bodyAsChannel() + val firstLine = channel.readUTF8Line() + firstLine.shouldNotBeNull() + channel.isClosedForRead shouldBe false + } + } + private fun ApplicationTestBuilder.configureTransportEndpoint(transport: StreamableHttpServerTransport) { application { routing {