Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions kotlin-sdk-client/api/kotlin-sdk-client.api
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ public final class io/modelcontextprotocol/kotlin/sdk/client/StdioClientTranspor
}

public final class io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractClientTransport {
public fun <init> (Lio/ktor/client/HttpClient;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/client/ReconnectionOptions;Lkotlin/jvm/functions/Function1;)V
public synthetic fun <init> (Lio/ktor/client/HttpClient;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/client/ReconnectionOptions;Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public fun <init> (Lio/ktor/client/HttpClient;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/client/ReconnectionOptions;ILkotlin/jvm/functions/Function1;)V
public synthetic fun <init> (Lio/ktor/client/HttpClient;Ljava/lang/String;Lio/modelcontextprotocol/kotlin/sdk/client/ReconnectionOptions;ILkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
Comment on lines +104 to +105
public synthetic fun <init> (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public synthetic fun <init> (Lio/ktor/client/HttpClient;Ljava/lang/String;Lkotlin/time/Duration;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/internal/DefaultConstructorMarker;)V
public final fun getProtocolVersion ()Ljava/lang/String;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ import io.ktor.http.HttpMethod
import io.ktor.http.HttpStatusCode
import io.ktor.http.contentType
import io.ktor.http.isSuccess
import io.ktor.utils.io.charsets.TooLongLineException
import io.ktor.utils.io.readUTF8Line
import io.modelcontextprotocol.kotlin.sdk.shared.AbstractClientTransport
import io.modelcontextprotocol.kotlin.sdk.shared.TooLongFrameException
import io.modelcontextprotocol.kotlin.sdk.shared.TransportSendOptions
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCNotification
Expand Down Expand Up @@ -50,6 +52,14 @@ private const val MCP_SESSION_ID_HEADER = "mcp-session-id"
private const val MCP_PROTOCOL_VERSION_HEADER = "mcp-protocol-version"
private const val MCP_RESUMPTION_TOKEN_HEADER = "Last-Event-ID"

/**
* Default maximum size, in characters, of a single inline SSE event assembled from a POST response.
*
* Mirrors the stdio transport's 16 MiB frame cap: a server that streams `data:` lines without ever
* terminating the event cannot grow the client's buffer without bound.
*/
private const val DEFAULT_MAX_INLINE_SSE_EVENT_SIZE: Int = 16 * 1024 * 1024

/**
* Represents an error from the Streamable HTTP transport.
*
Expand All @@ -75,15 +85,23 @@ private sealed interface ConnectResult {
* @param client Ktor HTTP client used for all requests
* @param url MCP endpoint URL
* @param reconnectionOptions reconnection backoff and retry-limit settings for the SSE stream
* @param maxInlineSseEventSize maximum size, in characters, of a single inline SSE event parsed from a
* POST response; a server that exceeds it (including by never terminating an event) fails the send
* with [io.modelcontextprotocol.kotlin.sdk.shared.TooLongFrameException]. Defaults to 16 MiB.
* @param requestBuilder builder applied to every outgoing HTTP request, e.g. for adding auth headers
*/
Comment on lines +88 to 92
public class StreamableHttpClientTransport(
private val client: HttpClient,
private val url: String,
private val reconnectionOptions: ReconnectionOptions = ReconnectionOptions(),
private val maxInlineSseEventSize: Int = DEFAULT_MAX_INLINE_SSE_EVENT_SIZE,
private val requestBuilder: HttpRequestBuilder.() -> Unit = {},
) : AbstractClientTransport() {

init {
require(maxInlineSseEventSize > 0) { "maxInlineSseEventSize must be greater than 0" }
}

@Deprecated(
"Use constructor with ReconnectionOptions",
replaceWith = ReplaceWith(
Expand All @@ -98,7 +116,12 @@ public class StreamableHttpClientTransport(
url: String,
reconnectionTime: Duration?,
requestBuilder: HttpRequestBuilder.() -> Unit = {},
) : this(client, url, ReconnectionOptions(initialReconnectionDelay = reconnectionTime ?: 1.seconds), requestBuilder)
) : this(
client,
url,
ReconnectionOptions(initialReconnectionDelay = reconnectionTime ?: 1.seconds),
requestBuilder = requestBuilder,
)

override val logger: KLogger = KotlinLogging.logger {}

Expand Down Expand Up @@ -458,7 +481,14 @@ public class StreamableHttpClientTransport(
}

while (!channel.isClosedForRead) {
val line = channel.readUTF8Line() ?: break
// Bound each line so a server that streams a line without ever terminating it cannot
// exhaust client memory; readUTF8Line returns null at the end of the stream.
val line = try {
channel.readUTF8Line(maxInlineSseEventSize)
} catch (_: TooLongLineException) {
throw TooLongFrameException(maxInlineSseEventSize.toLong() + 1, maxInlineSseEventSize)
}
Comment on lines +486 to +490
if (line == null) break
if (line.isEmpty()) {
dispatch(id = id, eventName = eventName, data = sb.toString())
// reset
Expand All @@ -472,7 +502,13 @@ public class StreamableHttpClientTransport(

line.startsWith("event:") -> eventName = line.substringAfter("event:").trim()

line.startsWith("data:") -> sb.append(line.substringAfter("data:").trim())
line.startsWith("data:") -> {
sb.append(line.substringAfter("data:").trim())
// Cap an event assembled from many data: lines that never sees a terminating blank line.
if (sb.length > maxInlineSseEventSize) {
throw TooLongFrameException(sb.length.toLong(), maxInlineSseEventSize)
}
}

line.startsWith("retry:") -> line.substringAfter("retry:").trim().toLongOrNull()?.let {
localServerRetryDelay = it.milliseconds
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import io.ktor.http.headersOf
import io.ktor.utils.io.ByteReadChannel
import io.modelcontextprotocol.kotlin.sdk.client.Client
import io.modelcontextprotocol.kotlin.sdk.client.StreamableHttpClientTransport
import io.modelcontextprotocol.kotlin.sdk.shared.TooLongFrameException
import io.modelcontextprotocol.kotlin.sdk.types.Implementation
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCNotification
Expand All @@ -43,19 +44,27 @@ import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertNull
import kotlin.test.assertTrue
import kotlin.time.Duration.Companion.milliseconds
import kotlin.time.Duration.Companion.seconds

class StreamableHttpClientTransportTest {

private fun createTransport(handler: MockRequestHandler): StreamableHttpClientTransport {
private fun createTransport(
maxInlineSseEventSize: Int = 16 * 1024 * 1024,
handler: MockRequestHandler,
): StreamableHttpClientTransport {
val mockEngine = MockEngine(handler)
val httpClient = HttpClient(mockEngine) {
install(SSE) {
reconnectionTime = 1.seconds
}
}

return StreamableHttpClientTransport(httpClient, url = "http://localhost:8080/mcp")
return StreamableHttpClientTransport(
httpClient,
url = "http://localhost:8080/mcp",
maxInlineSseEventSize = maxInlineSseEventSize,
)
}

private fun buildSseMessage(id: String, method: String, params: String): String = buildString {
Expand Down Expand Up @@ -551,6 +560,97 @@ class StreamableHttpClientTransportTest {
transport.close()
}

@Test
fun testInlineSseRejectsEventExceedingMaxSize() = runTest {
// A malicious server streams an endless single event: many `data:` lines that accumulate
// past the cap and never send the blank-line terminator that would flush the buffer.
val maxInlineSseEventSize = 64
val transport = createTransport(maxInlineSseEventSize) { request ->
if (request.method == HttpMethod.Post) {
val sseContent = buildString {
appendLine("event: message")
// 20 lines × 16 chars = 320 chars of accumulated data, no terminating blank line.
repeat(20) { appendLine("data: ${"A".repeat(16)}") }
}
respond(
content = ByteReadChannel(sseContent),
status = HttpStatusCode.OK,
headers = headersOf(HttpHeaders.ContentType, ContentType.Text.EventStream.toString()),
)
} else {
respond("", HttpStatusCode.OK)
}
}

val receivedMessages = mutableListOf<JSONRPCMessage>()
val receivedErrors = mutableListOf<Throwable>()
transport.onMessage { receivedMessages.add(it) }
transport.onError { receivedErrors.add(it) }
transport.start()

val error = assertFailsWith<McpException> {
transport.send(JSONRPCRequest(id = "req-1", method = "test", params = buildJsonObject { }))
}

error.cause.shouldBeInstanceOf<TooLongFrameException>()
receivedErrors.filterIsInstance<TooLongFrameException>() shouldHaveSize 1
receivedMessages shouldHaveSize 0
transport.close()
}

@Test
fun testInlineSseEventExactlyAtMaxSizeIsAccepted() = runTest {
// An event whose assembled data length equals the cap must still be accepted and dispatched:
// the guard rejects only sizes strictly greater than the cap (parity with ReadBuffer).
val part1 = """{"jsonrpc":"2.0","""
val part2 = """"method":"notifications/tools/list_changed"}"""
val maxInlineSseEventSize = (part1 + part2).length

val transport = createTransport(maxInlineSseEventSize) { request ->
if (request.method == HttpMethod.Post) {
val sseContent = buildString {
appendLine("event: message")
appendLine("data: $part1")
appendLine("data: $part2")
appendLine()
}
respond(
content = ByteReadChannel(sseContent),
status = HttpStatusCode.OK,
headers = headersOf(HttpHeaders.ContentType, ContentType.Text.EventStream.toString()),
)
} else {
respond("", HttpStatusCode.OK)
}
}

val receivedMessages = mutableListOf<JSONRPCMessage>()
val receivedErrors = mutableListOf<Throwable>()
val messageReceived = CompletableDeferred<Unit>()
transport.onMessage {
receivedMessages.add(it)
if (!messageReceived.isCompleted) messageReceived.complete(Unit)
}
transport.onError { receivedErrors.add(it) }
transport.start()

transport.send(JSONRPCRequest(id = "req-1", method = "test", params = buildJsonObject { }))

eventually { messageReceived.await() }

receivedMessages shouldHaveSize 1
(receivedMessages[0] as JSONRPCNotification).method shouldBe "notifications/tools/list_changed"
receivedErrors shouldHaveSize 0
transport.close()
}

@Test
fun testNonPositiveMaxInlineSseEventSizeThrows() {
assertFailsWith<IllegalArgumentException> {
createTransport(maxInlineSseEventSize = 0) { respond("", HttpStatusCode.OK) }
}
}

@Test
fun testInlineSSEInResponse() = runTest {
val transport = createTransport { request ->
Expand Down Expand Up @@ -716,7 +816,7 @@ class StreamableHttpClientTransportTest {

eventually {
while (receivedMessages.isEmpty()) {
delay(10)
delay(10.milliseconds)
}
}

Expand Down
Loading