Skip to content

Commit 9eabb54

Browse files
authored
refactor: refactor Streamable HTTP ktor extensions (#562) (#568)
1 parent 66d31d3 commit 9eabb54

2 files changed

Lines changed: 142 additions & 82 deletions

File tree

kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt

Lines changed: 120 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
@file:Suppress("TooManyFunctions")
2+
13
package io.modelcontextprotocol.kotlin.sdk.server
24

35
import io.github.oshai.kotlinlogging.KotlinLogging
@@ -20,38 +22,22 @@ import io.ktor.server.sse.SSE
2022
import io.ktor.server.sse.ServerSSESession
2123
import io.ktor.server.sse.sse
2224
import io.ktor.utils.io.KtorDsl
23-
import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
2425
import io.modelcontextprotocol.kotlin.sdk.types.RPCError
25-
import kotlinx.atomicfu.AtomicRef
26-
import kotlinx.atomicfu.atomic
27-
import kotlinx.atomicfu.update
28-
import kotlinx.collections.immutable.PersistentMap
29-
import kotlinx.collections.immutable.toPersistentMap
3026
import kotlinx.coroutines.awaitCancellation
3127

3228
private val logger = KotlinLogging.logger {}
3329

34-
internal class TransportManager(transports: Map<String, AbstractTransport> = emptyMap()) {
35-
private val transports: AtomicRef<PersistentMap<String, AbstractTransport>> = atomic(transports.toPersistentMap())
36-
37-
fun hasTransport(sessionId: String): Boolean = transports.value.containsKey(sessionId)
38-
39-
fun getTransport(sessionId: String): AbstractTransport? = transports.value[sessionId]
40-
41-
fun addTransport(sessionId: String, transport: AbstractTransport) {
42-
transports.update { it.put(sessionId, transport) }
43-
}
44-
45-
fun removeTransport(sessionId: String) {
46-
transports.update { it.remove(sessionId) }
47-
}
48-
}
49-
5030
/**
51-
* Registers a server-sent events (SSE) route at the specified path.
31+
* Registers MCP over [Server-Sent Events (SSE) Transport](https://modelcontextprotocol.io/specification/2024-11-05/basic/transports#http-with-sse)
32+
* at the specified [path] on this [Route].
33+
*
34+
* **Precondition:** the [SSE] plugin must be installed on the application before calling this function.
35+
* Use [Application.mcp] if you want SSE to be installed automatically.
5236
*
53-
* @param path the URL path to register the route for SSE.
54-
* @param block the block of code that defines the server's behavior for the SSE session.
37+
* @param path the URL path to register the SSE endpoint.
38+
* @param block factory block with access to the [ServerSSESession]
39+
* that creates and returns the [Server] to handle the connection.
40+
* @throws IllegalStateException if the [SSE] plugin is not installed.
5541
*/
5642
@KtorDsl
5743
public fun Route.mcp(path: String, block: ServerSSESession.() -> Server) {
@@ -61,11 +47,14 @@ public fun Route.mcp(path: String, block: ServerSSESession.() -> Server) {
6147
}
6248

6349
/**
64-
* Configures the Ktor Application to handle Model Context Protocol (MCP) over Server-Sent Events (SSE).
50+
* Registers MCP over [Server-Sent Events (SSE) Transport](https://modelcontextprotocol.io/specification/2024-11-05/basic/transports#http-with-sse)
51+
* endpoints on this [Route].
6552
*
6653
* **Precondition:** the [SSE] plugin must be installed on the application before calling this function.
6754
* Use [Application.mcp] if you want SSE to be installed automatically.
6855
*
56+
* @param block factory block with access to the [ServerSSESession]
57+
* that creates and returns the [Server] to handle the connection.
6958
* @throws IllegalStateException if the [SSE] plugin is not installed.
7059
*/
7160
@KtorDsl
@@ -81,7 +70,7 @@ public fun Route.mcp(block: ServerSSESession.() -> Server) {
8170
)
8271
}
8372

84-
val transportManager = TransportManager()
73+
val transportManager = TransportManager<SseServerTransport>()
8574

8675
sse {
8776
mcpSseEndpoint("", transportManager, block)
@@ -92,6 +81,14 @@ public fun Route.mcp(block: ServerSSESession.() -> Server) {
9281
}
9382
}
9483

84+
/**
85+
* Configures the Ktor Application to handle Model Context Protocol (MCP)
86+
* over [Server-Sent Events (SSE) Transport](https://modelcontextprotocol.io/specification/2024-11-05/basic/transports#http-with-sse)
87+
* and sets up routing with the provided configuration block.
88+
*
89+
* @param block factory block with access to the [ServerSSESession]
90+
* that creates and returns the [Server] to handle the connection.
91+
*/
9592
@KtorDsl
9693
public fun Application.mcp(block: ServerSSESession.() -> Server) {
9794
install(SSE)
@@ -101,19 +98,14 @@ public fun Application.mcp(block: ServerSSESession.() -> Server) {
10198
}
10299
}
103100

104-
@KtorDsl
105-
@Suppress("LongParameterList")
106-
public fun Application.mcpStreamableHttp(
101+
private fun Application.mcpStreamableHttp(
107102
path: String = "/mcp",
108-
enableDnsRebindingProtection: Boolean = false,
109-
allowedHosts: List<String>? = null,
110-
allowedOrigins: List<String>? = null,
111-
eventStore: EventStore? = null,
103+
configuration: StreamableHttpServerTransport.Configuration,
112104
block: RoutingContext.() -> Server,
113105
) {
114106
install(SSE)
115107

116-
val transportManager = TransportManager()
108+
val transportManager = TransportManager<StreamableHttpServerTransport>()
117109

118110
routing {
119111
route(path) {
@@ -125,13 +117,9 @@ public fun Application.mcpStreamableHttp(
125117
post {
126118
val transport = streamableTransport(
127119
transportManager = transportManager,
128-
enableDnsRebindingProtection = enableDnsRebindingProtection,
129-
allowedHosts = allowedHosts,
130-
allowedOrigins = allowedOrigins,
131-
eventStore = eventStore,
120+
configuration = configuration,
132121
block = block,
133-
)
134-
?: return@post
122+
) ?: return@post
135123

136124
transport.handleRequest(null, call)
137125
}
@@ -144,26 +132,58 @@ public fun Application.mcpStreamableHttp(
144132
}
145133
}
146134

135+
/**
136+
* Configures the Ktor Application to handle Model Context Protocol (MCP)
137+
* over [Streamable HTTP Transport](https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#streamable-http)
138+
*
139+
* Sets up SSE, HTTP POST, and DELETE endpoints at the specified [path].
140+
* Simple request/response pairs are returned as JSON (not SSE streams).
141+
*
142+
* @param path The base path for the MCP Streamable HTTP endpoint. Defaults to "/mcp".
143+
* @param enableDnsRebindingProtection Enables DNS rebinding attack protection for the endpoint. Defaults to false.
144+
* @param allowedHosts A list of hostnames allowed to access the endpoint. If `null`, no restrictions are applied.
145+
* @param allowedOrigins A list of origins allowed to perform cross-origin requests (CORS).
146+
* If `null`, no restrictions are applied.
147+
* @param eventStore An optional [EventStore] instance to enable resumable event stream functionality.
148+
* Allows storing and replaying events.
149+
* @param block factory block with access to the [RoutingContext] (for reading request headers)
150+
* that creates and returns the [Server] to handle the connection.
151+
*/
147152
@KtorDsl
148153
@Suppress("LongParameterList")
149-
public fun Application.mcpStatelessStreamableHttp(
154+
public fun Application.mcpStreamableHttp(
150155
path: String = "/mcp",
151156
enableDnsRebindingProtection: Boolean = false,
152157
allowedHosts: List<String>? = null,
153158
allowedOrigins: List<String>? = null,
154159
eventStore: EventStore? = null,
155160
block: RoutingContext.() -> Server,
161+
) {
162+
mcpStreamableHttp(
163+
path = path,
164+
configuration = StreamableHttpServerTransport.Configuration(
165+
enableDnsRebindingProtection = enableDnsRebindingProtection,
166+
allowedHosts = allowedHosts,
167+
allowedOrigins = allowedOrigins,
168+
eventStore = eventStore,
169+
enableJsonResponse = true,
170+
),
171+
block = block,
172+
)
173+
}
174+
175+
private fun Application.mcpStatelessStreamableHttp(
176+
path: String = "/mcp",
177+
configuration: StreamableHttpServerTransport.Configuration,
178+
block: RoutingContext.() -> Server,
156179
) {
157180
install(SSE)
158181

159182
routing {
160183
route(path) {
161184
post {
162185
mcpStatelessStreamableHttpEndpoint(
163-
enableDnsRebindingProtection = enableDnsRebindingProtection,
164-
allowedHosts = allowedHosts,
165-
allowedOrigins = allowedOrigins,
166-
eventStore = eventStore,
186+
configuration = configuration,
167187
block = block,
168188
)
169189
}
@@ -185,9 +205,47 @@ public fun Application.mcpStatelessStreamableHttp(
185205
}
186206
}
187207

208+
/**
209+
* Configures the Ktor Application to handle Model Context Protocol (MCP)
210+
* over _stateless_ [Streamable HTTP Transport](https://modelcontextprotocol.io/specification/2025-11-25/basic/transports#streamable-http)
211+
*
212+
* Sets up an HTTP POST endpoint at [path]. GET and DELETE requests return 405 Method Not Allowed.
213+
* Simple request/response pairs are returned as JSON (not SSE streams).
214+
*
215+
* @param path The URL path where the server listens for incoming JSON-RPC requests. Defaults to "/mcp".
216+
* @param enableDnsRebindingProtection Determines whether DNS rebinding protection is enabled. Defaults to `false`.
217+
* @param allowedHosts A list of allowed hostnames. If null, host filtering is disabled.
218+
* @param allowedOrigins A list of allowed origins for CORS. If null, origin filtering is disabled.
219+
* @param eventStore An optional [EventStore] implementation to provide resumability and event replay support.
220+
* @param block factory block with access to the [RoutingContext] (for reading request headers)
221+
* that creates and returns the [Server] to handle the connection.
222+
*/
223+
@KtorDsl
224+
@Suppress("LongParameterList")
225+
public fun Application.mcpStatelessStreamableHttp(
226+
path: String = "/mcp",
227+
enableDnsRebindingProtection: Boolean = false,
228+
allowedHosts: List<String>? = null,
229+
allowedOrigins: List<String>? = null,
230+
eventStore: EventStore? = null,
231+
block: RoutingContext.() -> Server,
232+
) {
233+
mcpStatelessStreamableHttp(
234+
path = path,
235+
configuration = StreamableHttpServerTransport.Configuration(
236+
enableDnsRebindingProtection = enableDnsRebindingProtection,
237+
allowedHosts = allowedHosts,
238+
allowedOrigins = allowedOrigins,
239+
eventStore = eventStore,
240+
enableJsonResponse = true,
241+
),
242+
block = block,
243+
)
244+
}
245+
188246
private suspend fun ServerSSESession.mcpSseEndpoint(
189247
postEndpoint: String,
190-
transportManager: TransportManager,
248+
transportManager: TransportManager<SseServerTransport>,
191249
block: ServerSSESession.() -> Server,
192250
) {
193251
val transport = mcpSseTransport(postEndpoint, transportManager)
@@ -208,7 +266,7 @@ private suspend fun ServerSSESession.mcpSseEndpoint(
208266

209267
private fun ServerSSESession.mcpSseTransport(
210268
postEndpoint: String,
211-
transportManager: TransportManager,
269+
transportManager: TransportManager<SseServerTransport>,
212270
): SseServerTransport {
213271
val transport = SseServerTransport(postEndpoint, this)
214272
transportManager.addTransport(transport.sessionId, transport)
@@ -218,20 +276,11 @@ private fun ServerSSESession.mcpSseTransport(
218276
}
219277

220278
private suspend fun RoutingContext.mcpStatelessStreamableHttpEndpoint(
221-
enableDnsRebindingProtection: Boolean = false,
222-
allowedHosts: List<String>? = null,
223-
allowedOrigins: List<String>? = null,
224-
eventStore: EventStore? = null,
279+
configuration: StreamableHttpServerTransport.Configuration,
225280
block: RoutingContext.() -> Server,
226281
) {
227282
val transport = StreamableHttpServerTransport(
228-
StreamableHttpServerTransport.Configuration(
229-
enableDnsRebindingProtection = enableDnsRebindingProtection,
230-
allowedHosts = allowedHosts,
231-
allowedOrigins = allowedOrigins,
232-
eventStore = eventStore,
233-
enableJsonResponse = true,
234-
),
283+
configuration,
235284
).also { it.setSessionIdGenerator(null) }
236285

237286
logger.info { "New stateless StreamableHttp connection established without sessionId" }
@@ -244,15 +293,15 @@ private suspend fun RoutingContext.mcpStatelessStreamableHttpEndpoint(
244293
logger.debug { "Server connected to transport without sessionId" }
245294
}
246295

247-
private suspend fun RoutingContext.mcpPostEndpoint(transportManager: TransportManager) {
296+
private suspend fun RoutingContext.mcpPostEndpoint(transportManager: TransportManager<SseServerTransport>) {
248297
val sessionId: String = call.request.queryParameters["sessionId"] ?: run {
249298
call.respond(HttpStatusCode.BadRequest, "sessionId query parameter is not provided")
250299
return
251300
}
252301

253302
logger.debug { "Received message for sessionId: $sessionId" }
254303

255-
val transport = transportManager.getTransport(sessionId) as SseServerTransport?
304+
val transport = transportManager.getTransport(sessionId)
256305
if (transport == null) {
257306
logger.warn { "Session not found for sessionId: $sessionId" }
258307
call.respond(HttpStatusCode.NotFound, "Session not found")
@@ -267,7 +316,7 @@ private fun ApplicationRequest.sessionId(): String? = header(MCP_SESSION_ID_HEAD
267316

268317
private suspend fun existingStreamableTransport(
269318
call: ApplicationCall,
270-
transportManager: TransportManager,
319+
transportManager: TransportManager<StreamableHttpServerTransport>,
271320
): StreamableHttpServerTransport? {
272321
val sessionId = call.request.sessionId()
273322
if (sessionId.isNullOrEmpty()) {
@@ -279,42 +328,31 @@ private suspend fun existingStreamableTransport(
279328
return null
280329
}
281330

282-
val transport = transportManager.getTransport(sessionId) as? StreamableHttpServerTransport
283-
if (transport == null) {
331+
val transport = transportManager.getTransport(sessionId)
332+
return if (transport == null) {
284333
call.reject(
285334
HttpStatusCode.NotFound,
286335
RPCError.ErrorCode.CONNECTION_CLOSED,
287336
"Session not found",
288337
)
289-
return null
338+
null
339+
} else {
340+
transport
290341
}
291-
292-
return transport
293342
}
294343

295344
private suspend fun RoutingContext.streamableTransport(
296-
transportManager: TransportManager,
297-
enableDnsRebindingProtection: Boolean,
298-
allowedHosts: List<String>?,
299-
allowedOrigins: List<String>?,
300-
eventStore: EventStore?,
345+
transportManager: TransportManager<StreamableHttpServerTransport>,
346+
configuration: StreamableHttpServerTransport.Configuration,
301347
block: RoutingContext.() -> Server,
302348
): StreamableHttpServerTransport? {
303349
val sessionId = call.request.sessionId()
304350
if (sessionId != null) {
305-
val transport = transportManager.getTransport(sessionId) as? StreamableHttpServerTransport
351+
val transport = transportManager.getTransport(sessionId)
306352
return transport ?: existingStreamableTransport(call, transportManager)
307353
}
308354

309-
val transport = StreamableHttpServerTransport(
310-
StreamableHttpServerTransport.Configuration(
311-
enableDnsRebindingProtection = enableDnsRebindingProtection,
312-
allowedHosts = allowedHosts,
313-
allowedOrigins = allowedOrigins,
314-
eventStore = eventStore,
315-
enableJsonResponse = true,
316-
),
317-
)
355+
val transport = StreamableHttpServerTransport(configuration)
318356

319357
transport.setOnSessionInitialized { initializedSessionId ->
320358
transportManager.addTransport(initializedSessionId, transport)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package io.modelcontextprotocol.kotlin.sdk.server
2+
3+
import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
4+
import kotlinx.atomicfu.AtomicRef
5+
import kotlinx.atomicfu.atomic
6+
import kotlinx.atomicfu.update
7+
import kotlinx.collections.immutable.PersistentMap
8+
import kotlinx.collections.immutable.toPersistentMap
9+
10+
internal class TransportManager<T : AbstractTransport> {
11+
private val transports: AtomicRef<PersistentMap<String, T>> = atomic(emptyMap<String, T>().toPersistentMap())
12+
13+
fun getTransport(sessionId: String): T? = transports.value[sessionId]
14+
15+
fun addTransport(sessionId: String, transport: T) {
16+
transports.update { it.put(sessionId, transport) }
17+
}
18+
19+
fun removeTransport(sessionId: String) {
20+
transports.update { it.remove(sessionId) }
21+
}
22+
}

0 commit comments

Comments
 (0)