diff --git a/AGENTS.md b/AGENTS.md index 1646c55a4..e3a660e27 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -77,6 +77,9 @@ Follow these rules to keep changes safe, comprehensible, and easy to maintain. - **Prioritize test readability** - Avoid creating too many test methods; use parametrized tests when testing multiple similar scenarios - When running tests on Kotlin Multiplatform projects, run JVM tests only unless asked for other platforms +- **Concurrency in Tests**: Always use thread-safe collections (e.g., `Mutex`-protected lists or `Channel`) when + collecting messages from transports that process messages concurrently in the background (like those inheriting from + `AbstractTransport`). Using non-thread-safe `MutableList` will lead to flaky tests or missing messages. ### Test Framework Stack @@ -93,9 +96,12 @@ Follow these rules to keep changes safe, comprehensible, and easy to maintain. - **Ktor MockEngine**: For HTTP client mocking (`io.ktor:ktor-client-mock`) - **Java tests**: Use JUnit5, Mockito, AssertJ core - **Serialization test utilities** (`io.modelcontextprotocol.kotlin.test.utils`): - - `verifySerialization(value, json, expectedJson)` — serializes, asserts match, round-trips back; use for most serialization tests - - `verifyDeserialization(json, payload)` — deserializes from JSON, re-serializes, asserts match; returns the object for further assertions - - Always test both empty/null/omitted and non-null cases for nullable fields; `McpJson` has `explicitNulls = false` so null properties must be absent from JSON, not `null` + - `verifySerialization(value, json, expectedJson)` — serializes, asserts match, round-trips back; use for most + serialization tests + - `verifyDeserialization(json, payload)` — deserializes from JSON, re-serializes, asserts match; returns the object + for further assertions + - Always test both empty/null/omitted and non-null cases for nullable fields; `McpJson` has `explicitNulls = false` + so null properties must be absent from JSON, not `null` ### Kotest Patterns @@ -169,6 +175,8 @@ prop.shouldNotBeNull { - Use Kotlinx Serialization with explicit `@Serializable` annotations - JSON config is defined in `jsonUtils.kt` as `McpJson` — use it consistently - Register custom serializers in companion objects +- **SSE Data Concatenation**: When parsing Server-Sent Events (SSE) data, always ensure that multiple `data:` lines are + concatenated with a newline (`\n`) separator, as per the SSE specification. ### Error Handling diff --git a/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt b/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt index 0726fa829..5da892f87 100644 --- a/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt +++ b/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt @@ -61,32 +61,33 @@ class ClientTest { @Test fun `should initialize with matching protocol version`() = runTest { var initialised = false - val clientTransport = object : AbstractTransport() { - override suspend fun start() {} - - override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) { - if (message !is JSONRPCRequest) return - initialised = true - val result = InitializeResult( - protocolVersion = LATEST_PROTOCOL_VERSION, - capabilities = ServerCapabilities(), - serverInfo = Implementation( - name = "test", - version = "1.0", - ), - ) - - val response = JSONRPCResponse( - id = message.id, - result = result, - ) - - _onMessage.invoke(response) - } + val clientTransport = + object : AbstractTransport(backgroundScope.coroutineContext, backgroundScope.coroutineContext) { + override suspend fun start() {} + + override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) { + if (message !is JSONRPCRequest) return + initialised = true + val result = InitializeResult( + protocolVersion = LATEST_PROTOCOL_VERSION, + capabilities = ServerCapabilities(), + serverInfo = Implementation( + name = "test", + version = "1.0", + ), + ) + + val response = JSONRPCResponse( + id = message.id, + result = result, + ) + + handleMessage(response) + } - override suspend fun close() { + override suspend fun close() { + } } - } val client = Client( clientInfo = Implementation( @@ -107,32 +108,33 @@ class ClientTest { @Test fun `should initialize with supported older protocol version`() = runTest { val oldVersion = SUPPORTED_PROTOCOL_VERSIONS[1] - val clientTransport = object : AbstractTransport() { - override suspend fun start() {} - - override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) { - if (message !is JSONRPCRequest) return - check(message.method == Method.Defined.Initialize.value) - - val result = InitializeResult( - protocolVersion = oldVersion, - capabilities = ServerCapabilities(), - serverInfo = Implementation( - name = "test", - version = "1.0", - ), - ) - - val response = JSONRPCResponse( - id = message.id, - result = result, - ) - _onMessage.invoke(response) - } + val clientTransport = + object : AbstractTransport(backgroundScope.coroutineContext, backgroundScope.coroutineContext) { + override suspend fun start() {} + + override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) { + if (message !is JSONRPCRequest) return + check(message.method == Method.Defined.Initialize.value) + + val result = InitializeResult( + protocolVersion = oldVersion, + capabilities = ServerCapabilities(), + serverInfo = Implementation( + name = "test", + version = "1.0", + ), + ) + + val response = JSONRPCResponse( + id = message.id, + result = result, + ) + handleMessage(response) + } - override suspend fun close() { + override suspend fun close() { + } } - } val client = Client( clientInfo = Implementation( @@ -156,34 +158,35 @@ class ClientTest { @Test fun `should reject unsupported protocol version`() = runTest { var closed = false - val clientTransport = object : AbstractTransport() { - override suspend fun start() {} - - override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) { - if (message !is JSONRPCRequest) return - check(message.method == Method.Defined.Initialize.value) - - val result = InitializeResult( - protocolVersion = "invalid-version", - capabilities = ServerCapabilities(), - serverInfo = Implementation( - name = "test", - version = "1.0", - ), - ) - - val response = JSONRPCResponse( - id = message.id, - result = result, - ) - - _onMessage.invoke(response) - } + val clientTransport = + object : AbstractTransport(backgroundScope.coroutineContext, backgroundScope.coroutineContext) { + override suspend fun start() {} + + override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) { + if (message !is JSONRPCRequest) return + check(message.method == Method.Defined.Initialize.value) + + val result = InitializeResult( + protocolVersion = "invalid-version", + capabilities = ServerCapabilities(), + serverInfo = Implementation( + name = "test", + version = "1.0", + ), + ) + + val response = JSONRPCResponse( + id = message.id, + result = result, + ) + + handleMessage(response) + } - override suspend fun close() { - closed = true + override suspend fun close() { + closed = true + } } - } val client = Client( clientInfo = Implementation( @@ -203,19 +206,20 @@ class ClientTest { @Test fun `should reject due to non cancellation exception`() = runTest { var closed = false - val failingTransport = object : AbstractTransport() { - override suspend fun start() {} - - override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) { - if (message !is JSONRPCRequest) return - check(message.method == Method.Defined.Initialize.value) - throw IllegalStateException("Test error") - } + val failingTransport = + object : AbstractTransport(backgroundScope.coroutineContext, backgroundScope.coroutineContext) { + override suspend fun start() {} + + override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) { + if (message !is JSONRPCRequest) return + check(message.method == Method.Defined.Initialize.value) + throw IllegalStateException("Test error") + } - override suspend fun close() { - closed = true + override suspend fun close() { + closed = true + } } - } val client = Client( clientInfo = Implementation( @@ -237,22 +241,23 @@ class ClientTest { @Test fun `should rethrow McpException as is`() = runTest { var closed = false - val failingTransport = object : AbstractTransport() { - override suspend fun start() {} - - override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) { - if (message !is JSONRPCRequest) return - check(message.method == Method.Defined.Initialize.value) - throw McpException( - code = -32600, - message = "Invalid Request", - ) - } + val failingTransport = + object : AbstractTransport(backgroundScope.coroutineContext, backgroundScope.coroutineContext) { + override suspend fun start() {} + + override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) { + if (message !is JSONRPCRequest) return + check(message.method == Method.Defined.Initialize.value) + throw McpException( + code = -32600, + message = "Invalid Request", + ) + } - override suspend fun close() { - closed = true + override suspend fun close() { + closed = true + } } - } val client = Client( clientInfo = Implementation( @@ -275,22 +280,23 @@ class ClientTest { @Test fun `should rethrow StreamableHttpError as is`() = runTest { var closed = false - val failingTransport = object : AbstractTransport() { - override suspend fun start() {} - - override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) { - if (message !is JSONRPCRequest) return - check(message.method == Method.Defined.Initialize.value) - throw StreamableHttpError( - code = 500, - message = "Internal Server Error", - ) - } + val failingTransport = + object : AbstractTransport(backgroundScope.coroutineContext, backgroundScope.coroutineContext) { + override suspend fun start() {} + + override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) { + if (message !is JSONRPCRequest) return + check(message.method == Method.Defined.Initialize.value) + throw StreamableHttpError( + code = 500, + message = "Internal Server Error", + ) + } - override suspend fun close() { - closed = true + override suspend fun close() { + closed = true + } } - } val client = Client( clientInfo = Implementation( diff --git a/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/InMemoryTransportTest.kt b/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/InMemoryTransportTest.kt index 83e561844..9778961ab 100644 --- a/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/InMemoryTransportTest.kt +++ b/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/InMemoryTransportTest.kt @@ -10,6 +10,8 @@ import io.modelcontextprotocol.kotlin.sdk.types.ResourceUpdatedNotification import io.modelcontextprotocol.kotlin.sdk.types.ResourceUpdatedNotificationParams import io.modelcontextprotocol.kotlin.sdk.types.ToolListChangedNotification import io.modelcontextprotocol.kotlin.sdk.types.toJSON +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock import kotlinx.coroutines.test.runTest import kotlin.test.BeforeTest import kotlin.test.Test @@ -44,15 +46,16 @@ class InMemoryTransportTest { @Test fun `should send message from client to server`() = runTest { + val (client, server) = InMemoryTransport.createLinkedPair(backgroundScope.coroutineContext) val message = InitializedNotification() var receivedMessage: JSONRPCMessage? = null - serverTransport.onMessage { msg -> + server.onMessage { msg -> receivedMessage = msg } val rpcNotification = message.toJSON() - clientTransport.send(rpcNotification) + client.send(rpcNotification) assertEquals(rpcNotification, receivedMessage) } @@ -190,8 +193,11 @@ class InMemoryTransportTest { ) val receivedMessages = mutableListOf() + val mutex = Mutex() clientTransport.onMessage { msg -> - receivedMessages.add(msg) + mutex.withLock { + receivedMessages.add(msg) + } } notifications.forEach { notification -> diff --git a/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt b/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt index a6c6e3df2..d4efd1334 100644 --- a/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt +++ b/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt @@ -1,12 +1,15 @@ package io.modelcontextprotocol.kotlin.sdk.shared import io.kotest.assertions.nondeterministic.eventually +import io.kotest.matchers.collections.shouldContainExactlyInAnyOrder import io.kotest.matchers.shouldBe import io.modelcontextprotocol.kotlin.sdk.types.InitializedNotification import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage import io.modelcontextprotocol.kotlin.sdk.types.PingRequest import io.modelcontextprotocol.kotlin.sdk.types.toJSON import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock import kotlin.test.fail import kotlin.time.Duration.Companion.seconds @@ -47,12 +50,15 @@ abstract class BaseTransportTest { ) val readMessages = mutableListOf() + val mutex = Mutex() val finished = CompletableDeferred() transport.onMessage { message -> - readMessages.add(message) - if (message == messages.last()) { - finished.complete(Unit) + mutex.withLock { + readMessages.add(message) + if (readMessages.size == messages.size) { + finished.complete(Unit) + } } } @@ -64,7 +70,7 @@ abstract class BaseTransportTest { finished.await() - messages shouldBe readMessages + readMessages.shouldContainExactlyInAnyOrder(messages) transport.close() } diff --git a/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/InMemoryTransport.kt b/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/InMemoryTransport.kt index 2c1563537..566e0603a 100644 --- a/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/InMemoryTransport.kt +++ b/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/InMemoryTransport.kt @@ -1,11 +1,17 @@ package io.modelcontextprotocol.kotlin.sdk.shared import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage +import kotlinx.coroutines.Dispatchers +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.EmptyCoroutineContext /** * In-memory transport for creating clients and servers that talk to each other within the same process. */ -class InMemoryTransport : AbstractTransport() { +class InMemoryTransport( + context: CoroutineContext = EmptyCoroutineContext, + handlerContext: CoroutineContext = Dispatchers.Default, +) : AbstractTransport(context, handlerContext) { private var otherTransport: InMemoryTransport? = null private val messageQueue: MutableList = mutableListOf() @@ -14,9 +20,12 @@ class InMemoryTransport : AbstractTransport() { * One should be passed to a Client and one to a Server. */ companion object { - fun createLinkedPair(): Pair { - val clientTransport = InMemoryTransport() - val serverTransport = InMemoryTransport() + fun createLinkedPair( + context: CoroutineContext = EmptyCoroutineContext, + handlerContext: CoroutineContext = Dispatchers.Default, + ): Pair { + val clientTransport = InMemoryTransport(context, handlerContext) + val serverTransport = InMemoryTransport(context, handlerContext) clientTransport.otherTransport = serverTransport serverTransport.otherTransport = clientTransport return Pair(clientTransport, serverTransport) @@ -27,7 +36,7 @@ class InMemoryTransport : AbstractTransport() { // Process any messages that were queued before start was called while (messageQueue.isNotEmpty()) { messageQueue.removeFirstOrNull()?.let { message -> - _onMessage.invoke(message) // todo? + handleMessageInline(message) // todo? } } } @@ -35,6 +44,7 @@ class InMemoryTransport : AbstractTransport() { override suspend fun close() { val other = otherTransport otherTransport = null + shutdownHandlers() other?.close() invokeOnCloseCallback() } @@ -42,6 +52,7 @@ class InMemoryTransport : AbstractTransport() { override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) { val other = checkNotNull(otherTransport) { "Not connected" } - other._onMessage.invoke(message) + // necessary to propagate the caller's context - sometimes test, sometimes not + other.handleMessageInline(message) } } diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/AbstractAuthenticationTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/AbstractAuthenticationTest.kt index 1e88558fa..592d85606 100644 --- a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/AbstractAuthenticationTest.kt +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/AbstractAuthenticationTest.kt @@ -113,7 +113,7 @@ abstract class AbstractAuthenticationTest { var mcpClient: Client? = null try { mcpClient = Client(Implementation(name = "test-client", version = "1.0.0")) - withTimeout(5.seconds) { + withTimeout(10.seconds) { mcpClient.connect(createClientTransport(baseUrl, VALID_USER, VALID_PASSWORD)) } diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractToolIntegrationTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractToolIntegrationTest.kt index 7da82cc3a..968fbe020 100644 --- a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractToolIntegrationTest.kt +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractToolIntegrationTest.kt @@ -9,11 +9,14 @@ import io.modelcontextprotocol.kotlin.sdk.types.ImageContent import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities import io.modelcontextprotocol.kotlin.sdk.types.TextContent import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema +import kotlinx.coroutines.CoroutineStart import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.async import kotlinx.coroutines.delay import kotlinx.coroutines.launch import kotlinx.coroutines.runBlocking import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withContext import kotlinx.serialization.json.JsonArray import kotlinx.serialization.json.JsonPrimitive import kotlinx.serialization.json.add @@ -21,6 +24,8 @@ import kotlinx.serialization.json.buildJsonArray import kotlinx.serialization.json.buildJsonObject import kotlinx.serialization.json.put import org.junit.jupiter.api.Test +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ValueSource import java.text.DecimalFormat import java.text.DecimalFormatSymbols import java.util.Locale @@ -158,14 +163,29 @@ abstract class AbstractToolIntegrationTest : KotlinTestBase() { put("description", "Delay in milliseconds") }, ) + put( + "blocking", + buildJsonObject { + put("type", "boolean") + put("description", "Whether to block the thread while waiting") + }, + ) }, ), ) { request -> val delay = (request.params.arguments?.get("delay") as? JsonPrimitive)?.content?.toIntOrNull() ?: 1000 + val blocking = (request.params.arguments?.get("blocking") as? JsonPrimitive)?.content?.toBoolean() ?: false // simulate slow operation - runBlocking { - delay(delay.toLong()) + + if (blocking) { + @Suppress("RunBlockingInSuspendFunction") + runBlocking { delay(delay.toLong()) } + } else { + @Suppress("InjectDispatcher") + withContext(Dispatchers.Default) { + delay(delay.toLong()) + } } CallToolResult( @@ -691,6 +711,45 @@ abstract class AbstractToolIntegrationTest : KotlinTestBase() { actualContent shouldEqualJson expectedContent } + @ParameterizedTest + @ValueSource(booleans = [true, false]) + @Suppress("InjectDispatcher") + fun testToolConcurrentProcessing(blocking: Boolean): Unit = runBlocking(Dispatchers.Default) { + val delayMs = 1000 + val arguments = mapOf("delay" to delayMs, "blocking" to blocking) + + val startTime = System.currentTimeMillis() + + // Start a slow tool call + val deferredSlow = async { + client.callTool(slowToolName, arguments) + } + + // Give it a tiny bit of time to reach the server and start processing + delay(50) + + // Start a fast tool call + val deferredFast = async(start = CoroutineStart.UNDISPATCHED) { + client.callTool(testToolName, mapOf("text" to "fast")) + } + + val fastResult = deferredFast.await() + val fastEndTime = System.currentTimeMillis() + + // The fast tool should finish MUCH sooner than the slow tool's delay if processed concurrently + assertTrue( + fastEndTime - startTime < delayMs, + "Fast tool should finish before the slow tool's delay (took ${fastEndTime - startTime}ms)", + ) + + deferredSlow.await() + val slowEndTime = System.currentTimeMillis() + + assertTrue(slowEndTime - startTime >= delayMs, "Slow tool should take at least the specified delay") + + assertEquals("Echo: fast", (fastResult.content.first() as TextContent).text) + } + @Test fun testSpecialCharacters() { runBlocking(Dispatchers.IO) { diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt index 2f8b3872f..91643b43a 100644 --- a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt @@ -9,6 +9,7 @@ import io.ktor.server.engine.EmbeddedServer import io.ktor.server.engine.embeddedServer import io.ktor.server.plugins.contentnegotiation.ContentNegotiation import io.ktor.server.routing.routing +import io.modelcontextprotocol.kotlin.sdk.ExperimentalMcpApi import io.modelcontextprotocol.kotlin.sdk.client.Client import io.modelcontextprotocol.kotlin.sdk.client.SseClientTransport import io.modelcontextprotocol.kotlin.sdk.client.StdioClientTransport @@ -17,8 +18,10 @@ import io.modelcontextprotocol.kotlin.sdk.server.Server import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions import io.modelcontextprotocol.kotlin.sdk.server.StdioServerTransport import io.modelcontextprotocol.kotlin.sdk.server.StreamableHttpServerTransport +import io.modelcontextprotocol.kotlin.sdk.server.StreamableHttpServerTransport.Configuration import io.modelcontextprotocol.kotlin.sdk.server.mcp import io.modelcontextprotocol.kotlin.sdk.server.mcpStreamableHttp +import io.modelcontextprotocol.kotlin.sdk.testing.ChannelTransport import io.modelcontextprotocol.kotlin.sdk.types.Implementation import io.modelcontextprotocol.kotlin.sdk.types.McpJson import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities @@ -40,6 +43,7 @@ import io.ktor.server.cio.CIO as ServerCIO import io.ktor.server.sse.SSE as ServerSSE @Retry(times = 3) +@OptIn(ExperimentalMcpApi::class) abstract class KotlinTestBase { protected val host = "localhost" @@ -48,9 +52,10 @@ abstract class KotlinTestBase { protected lateinit var server: Server protected lateinit var client: Client protected lateinit var serverEngine: EmbeddedServer<*, *> + protected val channelTransports = lazy { ChannelTransport.createLinkedPair() } // Transport selection - protected enum class TransportKind { SSE, STDIO, STREAMABLE_HTTP } + protected enum class TransportKind { SSE, STDIO, STREAMABLE_HTTP, CHANNEL } protected open val transportKind: TransportKind = TransportKind.STDIO @@ -121,6 +126,13 @@ abstract class KotlinTestBase { ) client.connect(transport) } + + TransportKind.CHANNEL -> { + client = Client( + Implementation("test", "1.0"), + ) + client.connect(channelTransports.value.clientTransport) + } } } @@ -148,7 +160,7 @@ abstract class KotlinTestBase { // Create StreamableHTTP server transport // Using JSON response mode for simpler testing (no SSE session required) val transport = StreamableHttpServerTransport( - StreamableHttpServerTransport.Configuration( + Configuration( enableJsonResponse = true, // Use JSON response mode for testing ), ) @@ -196,6 +208,10 @@ abstract class KotlinTestBase { server.createSession(serverTransport) } } + + TransportKind.CHANNEL -> { + runBlocking { server.createSession(channelTransports.value.serverTransport) } + } } } @@ -250,6 +266,12 @@ abstract class KotlinTestBase { } } } + + TransportKind.CHANNEL -> { + if (channelTransports.isInitialized()) { + runBlocking { channelTransports.value.close() } + } + } } } } diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/channel/ToolIntegrationTestChannel.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/channel/ToolIntegrationTestChannel.kt new file mode 100644 index 000000000..f0902eaf6 --- /dev/null +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/channel/ToolIntegrationTestChannel.kt @@ -0,0 +1,8 @@ +package io.modelcontextprotocol.kotlin.sdk.integration.kotlin.channel + +import io.modelcontextprotocol.kotlin.sdk.integration.kotlin.AbstractToolIntegrationTest + +// while this isn't a "production" transport, we still want to ensure that it has the correct behavior +class ToolIntegrationTestChannel : AbstractToolIntegrationTest() { + override val transportKind: TransportKind = TransportKind.CHANNEL +} diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/KotlinServerForTsClientSse.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/KotlinServerForTsClientSse.kt index 97acc29f6..bc9efbabc 100644 --- a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/KotlinServerForTsClientSse.kt +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/KotlinServerForTsClientSse.kt @@ -42,6 +42,8 @@ import io.modelcontextprotocol.kotlin.sdk.types.TextResourceContents import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.runBlocking import kotlinx.coroutines.withTimeoutOrNull @@ -309,7 +311,9 @@ class KotlinServerForTsClient { } } -class HttpServerTransport(private val sessionId: String) : AbstractTransport() { +@Suppress("InjectDispatcher") +class HttpServerTransport(private val sessionId: String) : + AbstractTransport(Dispatchers.Default + SupervisorJob()) { private val logger = KotlinLogging.logger {} private val pendingResponses = ConcurrentHashMap>() private val messageQueue = Channel(Channel.UNLIMITED) @@ -352,7 +356,7 @@ class HttpServerTransport(private val sessionId: String) : AbstractTransport() { logger.info { "Created deferred response for ID: $id" } logger.info { "Invoking onMessage handler" } - _onMessage.invoke(message) + handleMessage(message) logger.info { "onMessage handler completed" } try { diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTestJvm.kt similarity index 99% rename from integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt rename to integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTestJvm.kt index ea246db05..2645ae5ec 100644 --- a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTestJvm.kt @@ -23,7 +23,7 @@ import kotlinx.coroutines.runBlocking import kotlin.test.Test @OptIn(ExperimentalMcpApi::class) -class ChannelTransportTest { +class ChannelTransportTestJvm { @Test fun `should connect and list resources`(): Unit = runBlocking { diff --git a/integration-test/src/jvmTest/typescript/package-lock.json b/integration-test/src/jvmTest/typescript/package-lock.json index fddbba438..a7e92aa6c 100644 --- a/integration-test/src/jvmTest/typescript/package-lock.json +++ b/integration-test/src/jvmTest/typescript/package-lock.json @@ -1112,6 +1112,7 @@ "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", "integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==", "dev": true, + "hasInstallScript": true, "license": "MIT", "optional": true, "os": [ diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseClientTransport.kt index 3c9d88b86..fc1d6127b 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseClientTransport.kt @@ -24,9 +24,7 @@ import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.Job -import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.cancel -import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.ensureActive import kotlinx.coroutines.launch import kotlinx.serialization.SerializationException @@ -50,7 +48,6 @@ public class SseClientTransport( private val endpoint = CompletableDeferred() private lateinit var session: ClientSSESession - private lateinit var scope: CoroutineScope private var job: Job? = null private val origin: String by lazy { @@ -79,7 +76,6 @@ public class SseClientTransport( reconnectionTime = reconnectionTime, block = requestBuilder, ) - scope = CoroutineScope(session.coroutineContext + SupervisorJob()) job = scope.launch(CoroutineName("SseMcpClientTransport.connect#${hashCode()}")) { collectMessages() @@ -115,7 +111,7 @@ public class SseClientTransport( when (event.event) { "error" -> { val error = IllegalStateException("SSE error: ${event.data}") - _onError(error) + handleError(error) throw error } @@ -131,10 +127,10 @@ public class SseClientTransport( } catch (e: CancellationException) { throw e } catch (e: Throwable) { - _onError(e) + handleError(e) throw e } finally { - closeResources() + close() } } @@ -155,29 +151,28 @@ public class SseClientTransport( endpoint.complete(endpointUrl) logger.debug { "Client connected to endpoint: $endpointUrl" } } catch (e: Throwable) { - _onError(e) + handleError(e) endpoint.completeExceptionally(e) throw e } } - private suspend fun handleMessage(data: String) { + private fun handleMessage(data: String) { try { val message = McpJson.decodeFromString(data) - _onMessage(message) + handleMessage(message) } catch (e: SerializationException) { - _onError(e) + handleError(e) } } override suspend fun closeResources() { - job?.cancelAndJoin() try { + shutdownHandlers() if (::session.isInitialized) session.cancel() - if (::scope.isInitialized) scope.cancel() endpoint.cancel() } catch (e: Throwable) { - _onError(e) + handleError(e) } } } diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt index 12529cc1d..72b3bc98a 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt @@ -19,9 +19,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.RPCError.ErrorCode.INTERNAL_ERRO import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job -import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.cancel import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.channels.Channel @@ -112,7 +110,6 @@ public class StdioClientTransport @JvmOverloads public constructor( public enum class StderrSeverity { FATAL, WARNING, INFO, DEBUG, IGNORE } private val ioCoroutineContext: CoroutineContext = IODispatcher - private val scope = CoroutineScope(SupervisorJob() + Dispatchers.Default) override suspend fun initialize() { logger.debug { "Starting StdioClientTransport..." } @@ -170,7 +167,7 @@ public class StdioClientTransport @JvmOverloads public constructor( .collect { event -> when (event) { is Event.JsonRpc -> { - handleJSONRPCMessage(event.message) + handleMessage(event.message) } is Event.StderrEvent -> { @@ -178,11 +175,11 @@ public class StdioClientTransport @JvmOverloads public constructor( when (errorSeverity) { FATAL -> { runCatching { - _onError( + handleError( McpException(INTERNAL_ERROR, "Message in StdErr: ${event.message}"), ) } - stopProcessing("Fatal STDERR message received") + cancelProcessing("Fatal STDERR message received") } WARNING -> { @@ -205,13 +202,13 @@ public class StdioClientTransport @JvmOverloads public constructor( is Event.EOFEvent -> { if (event.stream == ProcessStream.Stdin) { - stopProcessing("EOF in ${event.stream}") + cancelProcessing("EOF in ${event.stream}") } } is Event.IOErrorEvent -> { - runCatching { _onError(event.cause) } - stopProcessing("IO Error", event.cause) + handleError(event.cause) + cancelProcessing("IO Error", event.cause) } } } @@ -241,8 +238,7 @@ public class StdioClientTransport @JvmOverloads public constructor( } override suspend fun closeResources() { - scope.stopProcessing("Closed") - scope.coroutineContext[Job]?.join() // Wait for all coroutines to complete + finishProcessing() } private fun sendOutboundMessage(message: JSONRPCMessage, sink: Sink, mainScope: CoroutineScope) { @@ -252,26 +248,24 @@ public class StdioClientTransport @JvmOverloads public constructor( sink.flush() } catch (e: SerializationException) { logger.warn(e) { "Can't serialize message" } - runCatching { _onError(McpException(INTERNAL_ERROR, "Serialization error")) } - mainScope.stopProcessing("Can't serialize message", e) + handleError(McpException(INTERNAL_ERROR, "Serialization error")) + mainScope.cancelProcessing("Can't serialize message", e) } catch (e: IOException) { logger.warn(e) { "Can't send message" } - runCatching { _onError(McpException(CONNECTION_CLOSED, "Can't send message. Connection closed")) } - mainScope.stopProcessing("Write I/O failed", e) + handleError(McpException(CONNECTION_CLOSED, "Can't send message. Connection closed")) + mainScope.cancelProcessing("Write I/O failed", e) } } - private suspend fun handleJSONRPCMessage(msg: JSONRPCMessage) { - try { - _onMessage.invoke(msg) - } catch (e: Throwable) { - logger.error(e) { "Error processing message." } - runCatching { _onError.invoke(e) } - } + private suspend fun finishProcessing() { + sendChannel.close() // Stop accepting new messages + shutdownHandlers() + invokeOnCloseCallback() } - private fun CoroutineScope.stopProcessing(reason: String, cause: Throwable? = null) { + private fun CoroutineScope.cancelProcessing(reason: String, cause: Throwable? = null) { sendChannel.close() // Stop accepting new messages + cancelInProgressHandlers(reason, cause) invokeOnCloseCallback() cancel(reason, cause) // cancel current coroutine context } diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt index 80f37aedc..18d030e4f 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StreamableHttpClientTransport.kt @@ -32,11 +32,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.McpJson import io.modelcontextprotocol.kotlin.sdk.types.RequestId import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CoroutineName -import kotlinx.coroutines.CoroutineScope -import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job -import kotlinx.coroutines.SupervisorJob -import kotlinx.coroutines.cancel import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.delay import kotlinx.coroutines.isActive @@ -99,8 +95,6 @@ public class StreamableHttpClientTransport( private var sseJob: Job? = null - private val scope by lazy { CoroutineScope(SupervisorJob() + Dispatchers.Default) } - /** Result of an SSE stream collection. Reconnect when [hasPrimingEvent] is true and [receivedResponse] is false. */ private data class SseStreamResult( val hasPrimingEvent: Boolean, @@ -150,16 +144,18 @@ public class StreamableHttpClientTransport( if (!response.status.isSuccess()) { val error = StreamableHttpError(response.status.value, response.bodyAsText()) - _onError(error) + handleError(error) throw error } when (response.contentType()?.withoutParameters()) { ContentType.Application.Json -> response.bodyAsText().takeIf { it.isNotEmpty() }?.let { json -> runCatching { McpJson.decodeFromString(json) } - .onSuccess { _onMessage(it) } + .onSuccess { + handleMessage(it) + } .onFailure { - _onError(it) + handleError(it) throw it } } @@ -183,7 +179,7 @@ public class StreamableHttpClientTransport( val ct = response.contentType()?.toString() ?: "" val error = StreamableHttpError(-1, "Unexpected content type: $ct") - _onError(error) + handleError(error) throw error } } @@ -207,8 +203,8 @@ public class StreamableHttpClientTransport( override suspend fun closeResources() { logger.debug { "Client transport closing." } + shutdownHandlers() sseJob?.cancelAndJoin() - scope.cancel() } /** @@ -229,7 +225,7 @@ public class StreamableHttpClientTransport( "Failed to terminate session: ${response.status.description}", ) logger.error(error) { "Failed to terminate session" } - _onError(error) + handleError(error) throw error } @@ -275,7 +271,7 @@ public class StreamableHttpClientTransport( ConnectResult.Failed -> { // Give up after maxRetries consecutive failed connection attempts if (++attempt >= reconnectionOptions.maxRetries) { - _onError(StreamableHttpError(null, "Maximum reconnection attempts exceeded")) + handleError(StreamableHttpError(null, "Maximum reconnection attempts exceeded")) return@launch } continue @@ -384,21 +380,21 @@ public class StreamableHttpClientTransport( .onSuccess { msg -> if (msg is JSONRPCResponse) receivedResponse = true if (replayMessageId != null && msg is JSONRPCResponse) { - _onMessage(msg.copy(id = replayMessageId)) + handleMessage(msg.copy(id = replayMessageId)) } else { - _onMessage(msg) + handleMessage(msg) } } - .onFailure(_onError) + .onFailure(::handleError) } - "error" -> _onError(StreamableHttpError(null, event.data)) + "error" -> handleError(StreamableHttpError(null, event.data)) } } } catch (_: CancellationException) { // ignore } catch (t: Throwable) { - _onError(t) + handleError(t) } return SseStreamResult(hasPrimingEvent, receivedResponse, localLastEventId, localServerRetryDelay) } @@ -420,7 +416,7 @@ public class StreamableHttpClientTransport( var id: String? = null var eventName: String? = null - suspend fun dispatch(id: String?, eventName: String?, data: String) { + fun dispatch(id: String?, eventName: String?, data: String) { id?.let { localLastEventId = it hasPrimingEvent = true @@ -434,18 +430,18 @@ public class StreamableHttpClientTransport( .onSuccess { msg -> if (msg is JSONRPCResponse) receivedResponse = true if (replayMessageId != null && msg is JSONRPCResponse) { - _onMessage(msg.copy(id = replayMessageId)) + handleMessage(msg.copy(id = replayMessageId)) } else { - _onMessage(msg) + handleMessage(msg) } } .onFailure { - _onError(it) + handleError(it) throw it } } if (eventName == "error") { - _onError(StreamableHttpError(null, data)) + handleError(StreamableHttpError(null, data)) return } } @@ -466,13 +462,17 @@ public class StreamableHttpClientTransport( line.startsWith("event:") -> eventName = line.substringAfter("event:").trim() - line.startsWith("data:") -> sb.append(line.substringAfter("data:").trim()) + line.startsWith("data:") -> { + if (sb.isNotEmpty()) sb.append("\n") + sb.append(line.substringAfter("data:").removePrefix(" ")) + } line.startsWith("retry:") -> line.substringAfter("retry:").trim().toLongOrNull()?.let { localServerRetryDelay = it.milliseconds } } } + dispatch(id = id, eventName = eventName, data = sb.toString()) return SseStreamResult(hasPrimingEvent, receivedResponse, localLastEventId, localServerRetryDelay) } } diff --git a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/streamable/http/StreamableHttpClientTransportTest.kt b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/streamable/http/StreamableHttpClientTransportTest.kt index 55df3e06e..9ec04f2f4 100644 --- a/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/streamable/http/StreamableHttpClientTransportTest.kt +++ b/kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/streamable/http/StreamableHttpClientTransportTest.kt @@ -31,6 +31,8 @@ import io.modelcontextprotocol.kotlin.sdk.types.RequestId import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.delay +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock import kotlinx.coroutines.test.runTest import kotlinx.coroutines.withContext import kotlinx.coroutines.withTimeout @@ -243,6 +245,7 @@ class StreamableHttpClientTransportTest { @Test fun testNotificationSchemaE2E() = runTest { val receivedMessages = mutableListOf() + val mutex = Mutex() var sseStarted = false val transport = createTransport { request -> @@ -300,7 +303,7 @@ class StreamableHttpClientTransportTest { } transport.onMessage { message -> - receivedMessages.add(message) + mutex.withLock { receivedMessages.add(message) } } transport.start() @@ -514,12 +517,15 @@ class StreamableHttpClientTransportTest { } val receivedMessages = mutableListOf() + val mutex = Mutex() val twoMessagesReceived = CompletableDeferred() transport.onMessage { message -> - receivedMessages.add(message) - if (receivedMessages.size >= 2 && !twoMessagesReceived.isCompleted) { - twoMessagesReceived.complete(Unit) + mutex.withLock { + receivedMessages.add(message) + if (receivedMessages.size >= 2 && !twoMessagesReceived.isCompleted) { + twoMessagesReceived.complete(Unit) + } } } @@ -685,12 +691,15 @@ class StreamableHttpClientTransportTest { } val receivedMessages = mutableListOf() + val mutex = Mutex() val responseReceived = CompletableDeferred() transport.onMessage { message -> - receivedMessages.add(message) - if (message is JSONRPCResponse && !responseReceived.isCompleted) { - responseReceived.complete(Unit) + mutex.withLock { + receivedMessages.add(message) + if (message is JSONRPCResponse && !responseReceived.isCompleted) { + responseReceived.complete(Unit) + } } } @@ -748,12 +757,15 @@ class StreamableHttpClientTransportTest { } val receivedMessages = mutableListOf() + val mutex = Mutex() val twoMessagesReceived = CompletableDeferred() transport.onMessage { message -> - receivedMessages.add(message) - if (receivedMessages.size >= 2 && !twoMessagesReceived.isCompleted) { - twoMessagesReceived.complete(Unit) + mutex.withLock { + receivedMessages.add(message) + if (receivedMessages.size >= 2 && !twoMessagesReceived.isCompleted) { + twoMessagesReceived.complete(Unit) + } } } @@ -804,12 +816,15 @@ class StreamableHttpClientTransportTest { } val receivedMessages = mutableListOf() + val mutex = Mutex() val responseReceived = CompletableDeferred() transport.onMessage { message -> - receivedMessages.add(message) - if (message is JSONRPCResponse && !responseReceived.isCompleted) { - responseReceived.complete(Unit) + mutex.withLock { + receivedMessages.add(message) + if (message is JSONRPCResponse && !responseReceived.isCompleted) { + responseReceived.complete(Unit) + } } } @@ -860,10 +875,11 @@ class StreamableHttpClientTransportTest { transport: StreamableHttpClientTransport, ): Pair, MutableList> { val receivedMessages = mutableListOf() + val mutex = Mutex() val receivedErrors = mutableListOf() transport.onMessage { message -> - receivedMessages.add(message) + mutex.withLock { receivedMessages.add(message) } } transport.onError { error -> diff --git a/kotlin-sdk-core/api/kotlin-sdk-core.api b/kotlin-sdk-core/api/kotlin-sdk-core.api index af22323e2..ff9ee99d2 100644 --- a/kotlin-sdk-core/api/kotlin-sdk-core.api +++ b/kotlin-sdk-core/api/kotlin-sdk-core.api @@ -413,6 +413,8 @@ public final class io/modelcontextprotocol/kotlin/sdk/internal/Utils_jvmKt { public abstract class io/modelcontextprotocol/kotlin/sdk/shared/AbstractClientTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport { public fun ()V + public fun (Lkotlin/coroutines/CoroutineContext;)V + public synthetic fun (Lkotlin/coroutines/CoroutineContext;ILkotlin/jvm/internal/DefaultConstructorMarker;)V protected final fun checkState (Lio/modelcontextprotocol/kotlin/sdk/shared/ClientTransportState;Lkotlin/jvm/functions/Function1;)V public static synthetic fun checkState$default (Lio/modelcontextprotocol/kotlin/sdk/shared/AbstractClientTransport;Lio/modelcontextprotocol/kotlin/sdk/shared/ClientTransportState;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V public fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; @@ -430,12 +432,19 @@ public abstract class io/modelcontextprotocol/kotlin/sdk/shared/AbstractClientTr public abstract class io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport : io/modelcontextprotocol/kotlin/sdk/shared/Transport { public fun ()V - protected final fun get_onError ()Lkotlin/jvm/functions/Function1; - protected final fun get_onMessage ()Lkotlin/jvm/functions/Function2; + public fun (Lkotlin/coroutines/CoroutineContext;)V + public synthetic fun (Lkotlin/coroutines/CoroutineContext;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + protected final fun cancelInProgressHandlers (Ljava/lang/String;Ljava/lang/Throwable;)V + protected final fun getScope ()Lkotlinx/coroutines/CoroutineScope; + protected final fun handleError (Ljava/lang/Throwable;)V + protected final fun handleMessage (Lio/modelcontextprotocol/kotlin/sdk/types/JSONRPCMessage;)Lkotlinx/coroutines/Job; + protected final fun handleMessageInline (Lio/modelcontextprotocol/kotlin/sdk/types/JSONRPCMessage;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; protected final fun invokeOnCloseCallback ()V + protected final fun joinInProgressHandlers (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public fun onClose (Lkotlin/jvm/functions/Function0;)V public fun onError (Lkotlin/jvm/functions/Function1;)V public fun onMessage (Lkotlin/jvm/functions/Function2;)V + protected final fun shutdownHandlers (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } public final class io/modelcontextprotocol/kotlin/sdk/shared/ClientTransportState : java/lang/Enum { @@ -557,6 +566,8 @@ public class io/modelcontextprotocol/kotlin/sdk/shared/TransportSendOptions { public abstract class io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport { public fun ()V + public fun (Lkotlin/coroutines/CoroutineContext;)V + public synthetic fun (Lkotlin/coroutines/CoroutineContext;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; protected abstract fun getSession ()Lio/ktor/websocket/WebSocketSession; protected abstract fun initializeSession (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; @@ -2245,6 +2256,7 @@ public final class io/modelcontextprotocol/kotlin/sdk/types/JsonRpcKt { public static final field JSONRPC_VERSION Ljava/lang/String; public static final fun RequestId (J)Lio/modelcontextprotocol/kotlin/sdk/types/RequestId; public static final fun RequestId (Ljava/lang/String;)Lio/modelcontextprotocol/kotlin/sdk/types/RequestId; + public static final fun asString (Lio/modelcontextprotocol/kotlin/sdk/types/RequestId;)Ljava/lang/String; public static final fun fromJSON (Lio/modelcontextprotocol/kotlin/sdk/types/JSONRPCRequest;)Lio/modelcontextprotocol/kotlin/sdk/types/Request; public static final fun toJSON (Lio/modelcontextprotocol/kotlin/sdk/types/Notification;)Lio/modelcontextprotocol/kotlin/sdk/types/JSONRPCNotification; public static final fun toJSON (Lio/modelcontextprotocol/kotlin/sdk/types/Request;)Lio/modelcontextprotocol/kotlin/sdk/types/JSONRPCRequest; diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/AbstractClientTransport.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/AbstractClientTransport.kt index 5f679992e..5b0d04f97 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/AbstractClientTransport.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/AbstractClientTransport.kt @@ -6,8 +6,11 @@ import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage import io.modelcontextprotocol.kotlin.sdk.types.McpException import io.modelcontextprotocol.kotlin.sdk.types.RPCError.ErrorCode.CONNECTION_CLOSED import io.modelcontextprotocol.kotlin.sdk.types.RPCError.ErrorCode.INTERNAL_ERROR +import kotlinx.coroutines.Dispatchers import kotlin.concurrent.atomics.AtomicReference import kotlin.concurrent.atomics.ExperimentalAtomicApi +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.EmptyCoroutineContext import kotlin.coroutines.cancellation.CancellationException /** @@ -22,7 +25,10 @@ import kotlin.coroutines.cancellation.CancellationException * async workflows, ensuring proper resource cleanup in the case of failure or cancellation. */ @OptIn(ExperimentalAtomicApi::class, InternalMcpApi::class) -public abstract class AbstractClientTransport : AbstractTransport() { +public abstract class AbstractClientTransport( + context: CoroutineContext = EmptyCoroutineContext, + handlerContext: CoroutineContext = Dispatchers.Default, +) : AbstractTransport(context, handlerContext) { protected abstract val logger: KLogger @@ -176,7 +182,7 @@ public abstract class AbstractClientTransport : AbstractTransport() { } catch (e: CancellationException) { throw e // Always propagate cancellation } catch (e: Exception) { - _onError(e) + handleError(e) if (e is McpException) { throw e } else { diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport.kt index 7b881004e..1fd2ded62 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport.kt @@ -1,29 +1,116 @@ package io.modelcontextprotocol.kotlin.sdk.shared +import io.github.oshai.kotlinlogging.KLogger +import io.github.oshai.kotlinlogging.KotlinLogging import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.types.coroutineName +import kotlinx.atomicfu.atomic +import kotlinx.atomicfu.update +import kotlinx.collections.immutable.persistentSetOf +import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.CoroutineExceptionHandler +import kotlinx.coroutines.CoroutineName +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.Job +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancel +import kotlinx.coroutines.launch +import kotlinx.coroutines.withContext import kotlin.concurrent.atomics.AtomicBoolean import kotlin.concurrent.atomics.ExperimentalAtomicApi +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.EmptyCoroutineContext /** * Implements [onClose], [onError] and [onMessage] functions of [Transport] providing * corresponding [_onClose], [_onError] and [_onMessage] properties to use for an implementation. */ @OptIn(ExperimentalAtomicApi::class) -@Suppress("PropertyName") -public abstract class AbstractTransport : Transport { +public abstract class AbstractTransport( + context: CoroutineContext = EmptyCoroutineContext, + protected val handlerContext: CoroutineContext = Dispatchers.Default, +) : Transport { + private val logger: KLogger = KotlinLogging.logger {} + private val onCloseCalled = AtomicBoolean(false) private var _onClose: (() -> Unit) = {} - protected var _onError: ((Throwable) -> Unit) = {} - private set + private var _onError: ((Throwable) -> Unit) = {} // to not skip messages private val _onMessageInitialized = CompletableDeferred() - protected var _onMessage: (suspend ((JSONRPCMessage) -> Unit)) = { + private var _onMessage: (suspend ((JSONRPCMessage) -> Unit)) = { _onMessageInitialized.await() _onMessage.invoke(it) } - private set + + private val inProgressRequests = atomic(persistentSetOf()) + + @Suppress("InjectDispatcher") + protected val scope: CoroutineScope = + CoroutineScope( + Dispatchers.Default + context + SupervisorJob(context[Job]) + CoroutineExceptionHandler { ctx, e -> + logger.error(e) { + "Uncaught error in transport scope from ${ctx[CoroutineName] ?: "unknown coroutine"}" + } + handleError(e) + }, + ) + + @Suppress("TooGenericExceptionCaught") + protected fun handleError(error: Throwable) { + if (error is CancellationException) { + throw error + } + + try { + _onError.invoke(error) + } catch (e: CancellationException) { + throw e + } catch (e: Throwable) { + error.addSuppressed(e) + logger.error(e) { "Failed to invoke error handler for $error" } + } + } + + protected fun handleMessage(message: JSONRPCMessage): Job { + val name = message.coroutineName + return scope.launch(handlerContext + name) { + try { + doHandle(message, name) + } catch (e: CancellationException) { + throw e + } catch (_: Throwable) { + // Already handled in doHandle via _onError + } + }.also { job -> + inProgressRequests.update { it.add(job) } + job.invokeOnCompletion { _ -> + inProgressRequests.update { it.remove(job) } + } + } + } + + protected suspend fun handleMessageInline(message: JSONRPCMessage) { + val name = message.coroutineName + withContext(name) { + doHandle(message, name) + } + } + + private suspend fun doHandle(message: JSONRPCMessage, name: CoroutineName) { + @Suppress("TooGenericExceptionCaught") + try { + _onMessage.invoke(message) + } catch (e: CancellationException) { + throw e + } catch (e: Throwable) { + logger.error(e) { "Error processing message ${name.name}." } + handleError(e) + throw e + } + } override fun onClose(block: () -> Unit) { val old = _onClose @@ -42,17 +129,31 @@ public abstract class AbstractTransport : Transport { } override fun onMessage(block: suspend (JSONRPCMessage) -> Unit) { - val old: suspend (JSONRPCMessage) -> Unit = when (_onMessageInitialized.isCompleted) { - true -> _onMessage - false -> { _ -> } + if (_onMessageInitialized.isCompleted) { + val old = _onMessage + _onMessage = { message -> + old(message) + block(message) + } + } else { + _onMessage = block + _onMessageInitialized.complete(Unit) } + } - _onMessage = { message -> - old(message) - block(message) - } + protected suspend fun joinInProgressHandlers(): Unit = + inProgressRequests.getAndSet(persistentSetOf()).forEach { runCatching { it.join() } } + + protected fun cancelInProgressHandlers(message: String, error: Throwable?) { + inProgressRequests.getAndSet(persistentSetOf()).forEach { it.cancel(message, error) } + } - _onMessageInitialized.complete(Unit) + /** + * Helper to safely cancel and join all in-progress handlers. + */ + protected suspend fun shutdownHandlers() { + cancelInProgressHandlers("Closing", null) + joinInProgressHandlers() } /** @@ -62,9 +163,12 @@ public abstract class AbstractTransport : Transport { * an atomic flag (`onCloseCalled`). If the callback has already been executed, * the method does nothing. Any exceptions thrown during the execution of the * `_onClose` callback are caught and suppressed. + * + * Note: This method automatically cancels the transport's internal [scope]. */ protected fun invokeOnCloseCallback() { if (onCloseCalled.compareAndSet(expectedValue = false, newValue = true)) { + scope.cancel() runCatching { _onClose() } } } diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt index e9d833a2e..082c41623 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt @@ -355,11 +355,13 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio val handler = _progressHandlers.value[progressToken] if (handler == null) { - val error = Error( - "Received a progress notification for an unknown token: ${McpJson.encodeToString(notification)}", - ) - logger.error { error.message } - onError(error) + logger.warn { + "Received a progress notification for an unknown or missing token: ${ + McpJson.encodeToString( + notification, + ) + }. It may have arrived after the response was processed." + } return } diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt index de75cc460..3bfe8a9c2 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt @@ -8,14 +8,14 @@ import io.ktor.websocket.readText import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage import io.modelcontextprotocol.kotlin.sdk.types.McpJson import kotlinx.coroutines.CoroutineName -import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.InternalCoroutinesApi -import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.channels.ClosedReceiveChannelException import kotlinx.coroutines.job import kotlinx.coroutines.launch import kotlin.concurrent.atomics.AtomicBoolean import kotlin.concurrent.atomics.ExperimentalAtomicApi +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.EmptyCoroutineContext public const val MCP_SUBPROTOCOL: String = "mcp" @@ -26,10 +26,8 @@ private val logger = KotlinLogging.logger {} * Handles communication over a WebSocket session. */ @OptIn(ExperimentalAtomicApi::class) -public abstract class WebSocketMcpTransport : AbstractTransport() { - private val scope by lazy { - CoroutineScope(session.coroutineContext + SupervisorJob()) - } +public abstract class WebSocketMcpTransport(context: CoroutineContext = EmptyCoroutineContext) : + AbstractTransport(context) { private val initialized: AtomicBoolean = AtomicBoolean(false) @@ -66,15 +64,15 @@ public abstract class WebSocketMcpTransport : AbstractTransport() { if (message !is Frame.Text) { val e = IllegalArgumentException("Expected text frame, got ${message::class.simpleName}: $message") - _onError.invoke(e) + handleError(e) throw e } try { val message = McpJson.decodeFromString(message.readText()) - _onMessage.invoke(message) + handleMessage(message) } catch (e: Exception) { - _onError.invoke(e) + handleError(e) throw e } } @@ -83,7 +81,7 @@ public abstract class WebSocketMcpTransport : AbstractTransport() { @OptIn(InternalCoroutinesApi::class) session.coroutineContext.job.invokeOnCompletion { if (it != null) { - _onError.invoke(it) + handleError(it) } else { invokeOnCloseCallback() } @@ -105,7 +103,9 @@ public abstract class WebSocketMcpTransport : AbstractTransport() { } logger.debug { "Closing websocket session" } + shutdownHandlers() session.close() session.coroutineContext.job.join() + invokeOnCloseCallback() } } diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/jsonRpc.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/jsonRpc.kt index 68b2f0274..56aef82ac 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/jsonRpc.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/jsonRpc.kt @@ -2,10 +2,13 @@ package io.modelcontextprotocol.kotlin.sdk.types +import kotlinx.coroutines.CoroutineName import kotlinx.serialization.EncodeDefault import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.Serializable import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonObject +import kotlinx.serialization.json.JsonPrimitive import kotlinx.serialization.json.decodeFromJsonElement import kotlinx.serialization.json.encodeToJsonElement import kotlin.concurrent.atomics.ExperimentalAtomicApi @@ -16,6 +19,21 @@ import kotlin.uuid.Uuid public const val JSONRPC_VERSION: String = "2.0" +private val COROUTINE_NAME_CACHE: Map = + Method.Defined.entries.associate { it.value to CoroutineName("mcp-${it.value}") } + +/** + * Returns a [CoroutineName] for this message. + */ +internal val JSONRPCMessage.coroutineName: CoroutineName + get() = when (this) { + is JSONRPCError -> CoroutineName("mcp-error-${id?.asString() ?: "unknown"}") + is JSONRPCNotification -> COROUTINE_NAME_CACHE[method] ?: CoroutineName("mcp-notification-$method") + is JSONRPCRequest -> CoroutineName("mcp-request-${id.asString()}") + is JSONRPCResponse -> CoroutineName("mcp-response-${id.asString()}") + JSONRPCEmptyMessage -> CoroutineName("mcp-empty") + } + /** * Creates a `RequestId` instance using the provided string value. * @@ -49,6 +67,11 @@ public sealed interface RequestId { public value class NumberId(public val value: Long) : RequestId } +public fun RequestId.asString(): String = when (this) { + is RequestId.StringId -> value + is RequestId.NumberId -> value.toString() +} + /** * Converts the request to a JSON-RPC request. * @@ -66,8 +89,15 @@ public fun Request.toJSON(): JSONRPCRequest = JSONRPCRequest( * * @return The decoded [Request] */ -public fun JSONRPCRequest.fromJSON(): Request = - McpJson.decodeFromJsonElement(McpJson.encodeToJsonElement(this)) +public fun JSONRPCRequest.fromJSON(): Request { + val map = buildMap(2) { + put("method", JsonPrimitive(method)) + if (params != null) { + put("params", params) + } + } + return McpJson.decodeFromJsonElement(JsonObject(map)) +} /** * Converts the notification to a JSON-RPC notification. @@ -86,8 +116,15 @@ public fun Notification.toJSON(): JSONRPCNotification = JSONRPCNotification( * * @return The decoded [Notification]. */ -internal fun JSONRPCNotification.fromJSON(): Notification = - McpJson.decodeFromJsonElement(McpJson.encodeToJsonElement(this)) +internal fun JSONRPCNotification.fromJSON(): Notification { + val map = buildMap(2) { + put("method", JsonPrimitive(method)) + if (params != null) { + put("params", params) + } + } + return McpJson.decodeFromJsonElement(JsonObject(map)) +} /** * Base interface for all JSON-RPC 2.0 messages. diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/serializers.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/serializers.kt index 4d5974d92..a34b28a49 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/serializers.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/serializers.kt @@ -11,6 +11,7 @@ import kotlinx.serialization.encoding.Decoder import kotlinx.serialization.encoding.Encoder import kotlinx.serialization.json.JsonContentPolymorphicSerializer import kotlinx.serialization.json.JsonElement +import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.JsonPrimitive import kotlinx.serialization.json.jsonObject import kotlinx.serialization.json.jsonPrimitive @@ -379,6 +380,9 @@ internal object ServerResultPolymorphicSerializer : internal object JSONRPCMessagePolymorphicSerializer : JsonContentPolymorphicSerializer(JSONRPCMessage::class) { override fun selectDeserializer(element: JsonElement): DeserializationStrategy { + if (element !is JsonObject) { + throw SerializationException("JSONRPCMessage must be a JSON object, but was ${element::class.simpleName}") + } val jsonObj = element.jsonObject return when { "error" in jsonObj -> JSONRPCError.serializer() diff --git a/kotlin-sdk-core/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/AbstractClientTransportTest.kt b/kotlin-sdk-core/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/AbstractClientTransportTest.kt index 1d144d861..551f28daa 100644 --- a/kotlin-sdk-core/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/AbstractClientTransportTest.kt +++ b/kotlin-sdk-core/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/AbstractClientTransportTest.kt @@ -448,6 +448,7 @@ class AbstractClientTransportTest { @OptIn(InternalMcpApi::class) private class TestClientTransport : AbstractClientTransport() { override val logger: KLogger = KotlinLogging.logger {} + val sentMessages = mutableListOf() var lastSendOptions: TransportSendOptions? = null var initializeCalled = false diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SSEServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SSEServerTransport.kt index 619266aa7..736408a6a 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SSEServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SSEServerTransport.kt @@ -29,7 +29,7 @@ internal const val SESSION_ID_PARAM = "sessionId" */ @OptIn(ExperimentalAtomicApi::class) public class SseServerTransport(private val endpoint: String, private val session: ServerSSESession) : - AbstractTransport() { + AbstractTransport(session.call.coroutineContext) { private val initialized: AtomicBoolean = AtomicBoolean(false) @OptIn(ExperimentalUuidApi::class) @@ -56,7 +56,7 @@ public class SseServerTransport(private val endpoint: String, private val sessio @OptIn(InternalCoroutinesApi::class) session.coroutineContext.job.invokeOnCompletion { if (it != null && it !is CancellationException) { - _onError.invoke(it) + handleError(it) } else { invokeOnCloseCallback() } @@ -72,7 +72,7 @@ public class SseServerTransport(private val endpoint: String, private val sessio if (!initialized.load()) { val message = "SSE connection not established" call.respondText(message, status = HttpStatusCode.InternalServerError) - _onError.invoke(IllegalStateException(message)) + handleError(IllegalStateException(message)) } val body = try { @@ -84,7 +84,7 @@ public class SseServerTransport(private val endpoint: String, private val sessio call.receiveText() } catch (e: Exception) { call.respondText("Invalid message: ${e.message}", status = HttpStatusCode.BadRequest) - _onError.invoke(e) + handleError(e) return } @@ -103,16 +103,17 @@ public class SseServerTransport(private val endpoint: String, private val sessio * This can be used to inform the server of messages that arrive via a means different from HTTP POST. */ public suspend fun handleMessage(message: String) { - try { - val parsedMessage = McpJson.decodeFromString(message) - _onMessage.invoke(parsedMessage) + val parsedMessage = try { + McpJson.decodeFromString(message) } catch (e: Exception) { - _onError.invoke(e) + handleError(e) throw e } + handleMessageInline(parsedMessage) } override suspend fun close() { + shutdownHandlers() session.close() invokeOnCloseCallback() } diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt index 8fd46dd21..38c3b4f50 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt @@ -8,10 +8,8 @@ import io.modelcontextprotocol.kotlin.sdk.shared.TransportSendOptions import io.modelcontextprotocol.kotlin.sdk.shared.serializeMessage import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage import kotlinx.coroutines.CancellationException -import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Job import kotlinx.coroutines.NonCancellable -import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.isActive @@ -25,7 +23,6 @@ import kotlinx.io.readByteArray import kotlinx.io.writeString import kotlin.concurrent.atomics.AtomicBoolean import kotlin.concurrent.atomics.ExperimentalAtomicApi -import kotlin.coroutines.CoroutineContext private const val READ_BUFFER_SIZE = 8192L @@ -48,8 +45,6 @@ public class StdioServerTransport(private val inputStream: Source, outputStream: private var sendingJob: Job? = null private var processingJob: Job? = null - private val coroutineContext: CoroutineContext = IODispatcher + SupervisorJob() - private val scope = CoroutineScope(coroutineContext) private val readChannel = Channel(Channel.UNLIMITED) private val writeChannel = Channel(Channel.UNLIMITED) private val outputSink = outputStream.buffered() @@ -70,7 +65,7 @@ public class StdioServerTransport(private val inputStream: Source, outputStream: } private fun launchReadingJob(): Job { - val job = scope.launch { + val job = scope.launch(IODispatcher) { val buf = Buffer() @Suppress("TooGenericExceptionCaught") try { @@ -89,7 +84,7 @@ public class StdioServerTransport(private val inputStream: Source, outputStream: throw e } catch (e: Throwable) { logger.error(e) { "Error reading from stdin" } - _onError.invoke(e) + handleError(e) } finally { // Reached EOF or error, close connection close() @@ -112,7 +107,7 @@ public class StdioServerTransport(private val inputStream: Source, outputStream: } catch (e: CancellationException) { throw e } catch (e: Throwable) { - _onError.invoke(e) + handleError(e) } } job.invokeOnCompletion { cause -> @@ -122,7 +117,7 @@ public class StdioServerTransport(private val inputStream: Source, outputStream: } private fun launchSendingJob(): Job { - val job = scope.launch { + val job = scope.launch(IODispatcher) { @Suppress("TooGenericExceptionCaught") try { for (message in writeChannel) { @@ -134,7 +129,7 @@ public class StdioServerTransport(private val inputStream: Source, outputStream: throw e } catch (e: Throwable) { logger.error(e) { "Error writing to stdout" } - _onError.invoke(e) + handleError(e) } } job.invokeOnCompletion { cause -> @@ -146,25 +141,24 @@ public class StdioServerTransport(private val inputStream: Source, outputStream: return job } - private suspend fun processReadBuffer() { + private fun processReadBuffer() { @Suppress("TooGenericExceptionCaught") while (true) { val message = try { readBuffer.readMessage() } catch (e: Throwable) { - _onError.invoke(e) + handleError(e) null } if (message == null) break - // Async invocation broke delivery order try { - _onMessage.invoke(message) + handleMessage(message) } catch (e: CancellationException) { throw e } catch (e: Throwable) { logger.error(e) { "Error processing message" } - _onError.invoke(e) + handleError(e) } } } @@ -188,6 +182,7 @@ public class StdioServerTransport(private val inputStream: Source, outputStream: if (!initialized.compareAndSet(expectedValue = true, newValue = false)) return withContext(NonCancellable) { + shutdownHandlers() writeChannel.close() sendingJob?.cancelAndJoin() diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt index d0bbcfbbf..46986cc2c 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt @@ -30,6 +30,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.RequestId import io.modelcontextprotocol.kotlin.sdk.types.SUPPORTED_PROTOCOL_VERSIONS import kotlinx.coroutines.awaitCancellation import kotlinx.coroutines.job +import kotlinx.coroutines.joinAll import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock import kotlinx.serialization.json.JsonArray @@ -291,6 +292,7 @@ public class StreamableHttpServerTransport(private val configuration: Configurat } catch (_: Exception) { } } + shutdownHandlers() streamsMapping.clear() requestToStreamMapping.clear() requestToResponseMapping.clear() @@ -304,7 +306,7 @@ public class StreamableHttpServerTransport(private val configuration: Configurat public suspend fun handleRequest(session: ServerSSESession?, call: ApplicationCall) { validateHeaders(call)?.let { reason -> call.reject(HttpStatusCode.Forbidden, RPCError.ErrorCode.CONNECTION_CLOSED, reason) - _onError(Error(reason)) + handleError(Error(reason)) return } @@ -390,7 +392,9 @@ public class StreamableHttpServerTransport(private val configuration: Configurat val hasRequest = messages.any { it is JSONRPCRequest } if (!hasRequest) { call.respondNullable(status = HttpStatusCode.Accepted, message = null) - messages.forEach { message -> _onMessage(message) } + messages.map { message -> + handleMessage(message) + }.joinAll() return } @@ -407,14 +411,16 @@ public class StreamableHttpServerTransport(private val configuration: Configurat } call.coroutineContext.job.invokeOnCompletion { streamsMapping.remove(streamId) } - messages.forEach { message -> _onMessage(message) } + messages.map { message -> + handleMessage(message) + }.joinAll() } catch (e: Exception) { call.reject( HttpStatusCode.BadRequest, RPCError.ErrorCode.PARSE_ERROR, "Parse error: ${e.message}", ) - _onError(e) + handleError(e) } } @@ -485,7 +491,7 @@ public class StreamableHttpServerTransport(private val configuration: Configurat try { sessionContext.session?.close() } catch (e: Exception) { - _onError(e) + handleError(e) } finally { streamsMapping.remove(streamId) } @@ -539,7 +545,7 @@ public class StreamableHttpServerTransport(private val configuration: Configurat data = McpJson.encodeToString(message), ) } catch (e: Exception) { - _onError(IllegalStateException("Failed to replay event: ${e.message}", e)) + handleError(IllegalStateException("Failed to replay event: ${e.message}", e)) } } @@ -547,10 +553,10 @@ public class StreamableHttpServerTransport(private val configuration: Configurat session.coroutineContext.job.invokeOnCompletion { throwable -> streamsMapping.remove(streamId) - throwable?.let { _onError(it) } + throwable?.let { handleError(it) } } } catch (e: Exception) { - _onError(e) + handleError(e) } } @@ -655,7 +661,7 @@ public class StreamableHttpServerTransport(private val configuration: Configurat try { session?.send(data = "") } catch (e: Exception) { - _onError(e) + handleError(e) } } @@ -729,7 +735,7 @@ public class StreamableHttpServerTransport(private val configuration: Configurat data = "", ) } catch (e: Exception) { - _onError(e) + handleError(e) } } diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpServerTransport.kt index 0c0cc78c7..f654c9301 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpServerTransport.kt @@ -13,7 +13,8 @@ private val logger = KotlinLogging.logger {} * * @property session The WebSocket server session used for communication. */ -public class WebSocketMcpServerTransport(override val session: WebSocketServerSession) : WebSocketMcpTransport() { +public class WebSocketMcpServerTransport(override val session: WebSocketServerSession) : + WebSocketMcpTransport(session.coroutineContext) { override suspend fun initializeSession() { logger.debug { "Checking session headers" } val subprotocol = session.call.request.headers[HttpHeaders.SecWebSocketProtocol] diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/AbstractKtorExtensionsTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/AbstractKtorExtensionsTest.kt index 28e21b9c5..a61afee3d 100644 --- a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/AbstractKtorExtensionsTest.kt +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/AbstractKtorExtensionsTest.kt @@ -13,6 +13,8 @@ import io.ktor.http.contentType import io.ktor.utils.io.readUTF8Line import io.modelcontextprotocol.kotlin.sdk.types.Implementation import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities +import kotlinx.coroutines.withTimeout +import kotlin.time.Duration.Companion.seconds @Suppress("AbstractClassCanBeConcreteClass") abstract class AbstractKtorExtensionsTest { @@ -31,35 +33,37 @@ abstract class AbstractKtorExtensionsTest { * - POST without a sessionId returns 400 Bad Request */ protected suspend fun HttpClient.assertMcpEndpointsAt(path: String) { - prepareGet(path).execute { response -> - response.shouldHaveStatus(HttpStatusCode.OK) - response.shouldHaveContentType(sseContentType) + withTimeout(30.seconds) { + prepareGet(path).execute { response -> + response.shouldHaveStatus(HttpStatusCode.OK) + response.shouldHaveContentType(sseContentType) - // Extract sessionId from the SSE "endpoint" event - val channel = response.bodyAsChannel() - var eventName: String? = null - var sessionId: String? = null + // Extract sessionId from the SSE "endpoint" event + val channel = response.bodyAsChannel() + var eventName: String? = null + var sessionId: String? = null - while (sessionId == null && !channel.isClosedForRead) { - val line = channel.readUTF8Line() ?: break - when { - line.startsWith("event:") -> eventName = line.substringAfter("event:").trim() + while (sessionId == null && !channel.isClosedForRead) { + val line = channel.readUTF8Line() ?: break + when { + line.startsWith("event:") -> eventName = line.substringAfter("event:").trim() - line.startsWith("data:") && eventName == "endpoint" -> { - val data = line.substringAfter("data:").trim() - sessionId = data.substringAfter("sessionId=").ifEmpty { null } + line.startsWith("data:") && eventName == "endpoint" -> { + val data = line.substringAfter("data:").trim() + sessionId = data.substringAfter("sessionId=").ifEmpty { null } + } } } - } - requireNotNull(sessionId) { "sessionId not found in SSE endpoint event" } + requireNotNull(sessionId) { "sessionId not found in SSE endpoint event" } - // POST a valid JSON-RPC ping while the SSE connection is alive - val postResponse = post("$path?sessionId=$sessionId") { - contentType(ContentType.Application.Json) - setBody("""{"jsonrpc":"2.0","id":1,"method":"ping"}""") + // POST a valid JSON-RPC ping while the SSE connection is alive + val postResponse = post("$path?sessionId=$sessionId") { + contentType(ContentType.Application.Json) + setBody("""{"jsonrpc":"2.0","id":1,"method":"ping"}""") + } + postResponse.shouldHaveStatus(HttpStatusCode.Accepted) } - postResponse.shouldHaveStatus(HttpStatusCode.Accepted) } // POST without sessionId returns 400 Bad Request diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt index dcbb7a4b8..932ae6dd1 100644 --- a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt @@ -4,17 +4,24 @@ import io.kotest.assertions.nondeterministic.eventually import io.kotest.assertions.throwables.shouldThrow import io.kotest.assertions.withClue import io.kotest.matchers.collections.shouldContain +import io.kotest.matchers.collections.shouldContainExactlyInAnyOrder import io.kotest.matchers.shouldBe import io.modelcontextprotocol.kotlin.sdk.shared.ReadBuffer import io.modelcontextprotocol.kotlin.sdk.shared.serializeMessage import io.modelcontextprotocol.kotlin.sdk.types.InitializedNotification import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest import io.modelcontextprotocol.kotlin.sdk.types.PingRequest +import io.modelcontextprotocol.kotlin.sdk.types.RequestId import io.modelcontextprotocol.kotlin.sdk.types.toJSON import io.modelcontextprotocol.kotlin.test.utils.runIntegrationTest import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.TimeoutCancellationException +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.launch +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock import kotlinx.coroutines.withTimeout import kotlinx.io.Buffer import kotlinx.io.RawSink @@ -136,12 +143,15 @@ class StdioServerTransportTest { ) val readMessages = mutableListOf() + val mutex = Mutex() val finished = CompletableDeferred() server.onMessage { message -> - readMessages.add(message) - if (message == messages[1]) { - finished.complete(Unit) + mutex.withLock { + readMessages.add(message) + if (readMessages.size == messages.size) { + finished.complete(Unit) + } } } @@ -154,7 +164,7 @@ class StdioServerTransportTest { server.start() finished.await() - readMessages shouldBe messages + readMessages.shouldContainExactlyInAnyOrder(messages) } // region: Exception handling @@ -221,20 +231,25 @@ class StdioServerTransportTest { @MethodSource("handlerErrors") fun `should continue processing messages after handler throws`(throwable: Throwable) = runIntegrationTest { val server = StdioServerTransport(bufferedInput, printOutput) - val capturedErrors = mutableListOf() + val capturedErrors = Channel(Channel.UNLIMITED) val receivedMessages = mutableListOf() + val mutex = Mutex() val secondMessageProcessed = CompletableDeferred() val message1 = PingRequest().toJSON() val message2 = InitializedNotification().toJSON() - server.onError { capturedErrors.add(it) } + server.onError { error -> + capturedErrors.trySend(error) + } server.onMessage { message -> - if (message == message1) { - throw throwable - } else { - receivedMessages.add(message) - secondMessageProcessed.complete(Unit) + mutex.withLock { + if (message == message1) { + throw throwable + } else { + receivedMessages.add(message) + secondMessageProcessed.complete(Unit) + } } } @@ -246,9 +261,17 @@ class StdioServerTransportTest { secondMessageProcessed.await() - capturedErrors shouldContain throwable - receivedMessages shouldBe listOf(message2) + val errors = mutableListOf() + eventually(2.seconds) { + while (true) { + val error = capturedErrors.tryReceive().getOrNull() ?: break + errors.add(error) + } + errors shouldContain throwable + } + server.close() + receivedMessages shouldBe listOf(message2) } @Test @@ -299,6 +322,43 @@ class StdioServerTransportTest { server.close() } + @Test + fun `should process messages concurrently with async handler`() = runIntegrationTest { + val server = StdioServerTransport(bufferedInput, printOutput) + val firstMessageReceived = CompletableDeferred() + val firstMessageProceed = CompletableDeferred() + val secondMessageHandled = CompletableDeferred() + + server.onMessage { message -> + // Simulate Protocol's new behavior: launch a coroutine for each message + launch { + if ((message as? JSONRPCRequest)?.method == "slow") { + firstMessageReceived.complete(Unit) + firstMessageProceed.await() + } else { + secondMessageHandled.complete(Unit) + } + } + } + server.start() + + val message1 = JSONRPCRequest(id = RequestId(1), method = "slow") + val message2 = JSONRPCRequest(id = RequestId(2), method = "fast") + + inputWriter.write(serializeMessage(message1)) + inputWriter.write(serializeMessage(message2)) + inputWriter.flush() + + firstMessageReceived.await() + // Wait for the second message to be processed while the first is still "working" + withTimeout(2.seconds) { + secondMessageHandled.await() + } + + firstMessageProceed.complete(Unit) + server.close() + } + @Suppress("unused") private fun inputErrors() = listOf( IOException("simulated read failure"), diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt index c6e539adb..99af46329 100644 --- a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransportTest.kt @@ -39,6 +39,8 @@ import io.modelcontextprotocol.kotlin.sdk.types.RequestId import io.modelcontextprotocol.kotlin.sdk.types.Tool import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema import io.modelcontextprotocol.kotlin.sdk.types.toJSON +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock import kotlinx.serialization.builtins.ListSerializer import kotlinx.serialization.json.buildJsonObject import kotlinx.serialization.json.put @@ -214,11 +216,14 @@ class StreamableHttpServerTransportTest { val transport = StreamableHttpServerTransport(enableJsonResponse = true) val receivedMessages = mutableListOf() + val mutex = Mutex() transport.onMessage { message -> - if (message is JSONRPCRequest) { - transport.send(JSONRPCResponse(message.id, EmptyResult())) + mutex.withLock { + if (message is JSONRPCRequest) { + transport.send(JSONRPCResponse(message.id, EmptyResult())) + } + receivedMessages.add(message) } - receivedMessages.add(message) } configureTransportEndpoint(transport) diff --git a/kotlin-sdk-testing/api/kotlin-sdk-testing.api b/kotlin-sdk-testing/api/kotlin-sdk-testing.api index 4a942a4f0..6949739f9 100644 --- a/kotlin-sdk-testing/api/kotlin-sdk-testing.api +++ b/kotlin-sdk-testing/api/kotlin-sdk-testing.api @@ -13,6 +13,7 @@ public final class io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport$C public final class io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport$LinkedTransports { public fun (Lio/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport;Lio/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport;)V + public final fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public final fun component1 ()Lio/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport; public final fun component2 ()Lio/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport; public final fun copy (Lio/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport;Lio/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport;)Lio/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport$LinkedTransports; diff --git a/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt b/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt index e6d19f7b3..02f94af23 100644 --- a/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt +++ b/kotlin-sdk-testing/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransport.kt @@ -10,17 +10,14 @@ import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.CoroutineName -import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers -import kotlinx.coroutines.Job -import kotlinx.coroutines.SupervisorJob -import kotlinx.coroutines.cancel import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.channels.SendChannel import kotlinx.coroutines.launch import kotlinx.coroutines.yield +import kotlin.coroutines.CoroutineContext /** * A transport implementation that uses Kotlin Coroutines Channels for asynchronous @@ -37,12 +34,11 @@ public class ChannelTransport( private val sendChannel: SendChannel, private val receiveChannel: ReceiveChannel, dispatcher: CoroutineDispatcher = Dispatchers.Default, -) : AbstractClientTransport() { + handlerContext: CoroutineContext = dispatcher, +) : AbstractClientTransport(dispatcher, handlerContext) { override val logger: KLogger = KotlinLogging.logger {} - private val scope = CoroutineScope(SupervisorJob() + dispatcher) - /** * Creates a `ChannelTransport` instance using a single channel for both sending and receiving messages. * @@ -57,7 +53,8 @@ public class ChannelTransport( public constructor( channel: Channel = Channel(UNLIMITED), dispatcher: CoroutineDispatcher = Dispatchers.Default, - ) : this(channel, channel, dispatcher) + handlerContext: CoroutineContext = dispatcher, + ) : this(channel, channel, dispatcher, handlerContext) /** * Represents a pair of interconnected [ChannelTransport]s for bidirectional communication. @@ -68,7 +65,12 @@ public class ChannelTransport( * @property clientTransport The transport intended for use on the client-side. * @property serverTransport The transport intended for use on the server-side. */ - public data class LinkedTransports(val clientTransport: ChannelTransport, val serverTransport: ChannelTransport) + public data class LinkedTransports(val clientTransport: ChannelTransport, val serverTransport: ChannelTransport) { + public suspend fun close() { + clientTransport.close() + serverTransport.close() + } + } public companion object { @@ -85,11 +87,12 @@ public class ChannelTransport( public fun createLinkedPair( capacity: Int = 256, dispatcher: CoroutineDispatcher = Dispatchers.Default, + handlerContext: CoroutineContext = dispatcher, ): LinkedTransports { val sendChannel = Channel(capacity) val receiveChannel = Channel(capacity) - val clientTransport = ChannelTransport(sendChannel, receiveChannel, dispatcher) - val serverTransport = ChannelTransport(receiveChannel, sendChannel, dispatcher) + val clientTransport = ChannelTransport(sendChannel, receiveChannel, dispatcher, handlerContext) + val serverTransport = ChannelTransport(receiveChannel, sendChannel, dispatcher, handlerContext) return LinkedTransports(clientTransport = clientTransport, serverTransport = serverTransport) } } @@ -115,20 +118,24 @@ public class ChannelTransport( for (message in receiveChannel) { logger.debug { "Received message: ${message::class.simpleName}" } + @Suppress("InjectDispatcher") + handleMessage(message) + .invokeOnCompletion { + when (it) { + null -> logger.trace { "Message processed successfully: ${message::class.simpleName}" } + + is CancellationException -> logger.debug { + "Cancellation requested during message processing" + } - try { - _onMessage.invoke(message) - logger.trace { "Message processed successfully: ${message::class.simpleName}" } - } catch (e: CancellationException) { - // Let cancellation propagate immediately - logger.debug { "Cancellation requested during message processing" } - throw e - } catch (e: Exception) { - // Report other errors but continue processing - logger.warn(e) { "Error processing message: ${message::class.simpleName}" } - _onError.invoke(e) - } + else -> logger.warn(it) { + "Error processing message: ${message::class.simpleName}" + } + } + } } + logger.info { "ChannelTransport stopping: receive channel closed" } + joinInProgressHandlers() logger.info { "ChannelTransport stopped: receive channel closed" } } catch (e: Exception) { // Only complete exceptionally if not already completed @@ -169,8 +176,7 @@ public class ChannelTransport( logger.debug { "Cancelling separate receive channel" } receiveChannel.cancel() } - scope.cancel() - scope.coroutineContext[Job]?.join() // Wait for cleanup + shutdownHandlers() logger.info { "ChannelTransport closed" } } } diff --git a/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt b/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt index 7fc41bb0e..8ffc15552 100644 --- a/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt +++ b/kotlin-sdk-testing/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/testing/ChannelTransportTest.kt @@ -1,7 +1,7 @@ package io.modelcontextprotocol.kotlin.sdk.testing import io.kotest.assertions.nondeterministic.eventually -import io.kotest.matchers.collections.shouldContainExactly +import io.kotest.matchers.collections.shouldContainExactlyInAnyOrder import io.kotest.matchers.shouldBe import io.kotest.matchers.types.shouldBeInstanceOf import io.modelcontextprotocol.kotlin.sdk.ExperimentalMcpApi @@ -12,7 +12,10 @@ import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.cancelAndJoin import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.channels.ClosedSendChannelException +import kotlinx.coroutines.delay import kotlinx.coroutines.launch +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock import kotlinx.coroutines.test.runTest import kotlin.test.Test import kotlin.time.Duration.Companion.seconds @@ -40,10 +43,13 @@ class ChannelTransportTest { val messagesProcessed = CompletableDeferred() val received = mutableListOf() + val mutex = Mutex() transport.onMessage { msg -> - received.add(msg) - if (received.size == 2) { - messagesProcessed.complete(Unit) + mutex.withLock { + received.add(msg) + if (received.size == 2) { + messagesProcessed.complete(Unit) + } } } @@ -57,7 +63,7 @@ class ChannelTransportTest { // Wait for messages to be processed messagesProcessed.await() - received.shouldContainExactly(msg1, msg2) + received.shouldContainExactlyInAnyOrder(msg1, msg2) } @Test @@ -67,9 +73,12 @@ class ChannelTransportTest { val messageProcessed = CompletableDeferred() val received = mutableListOf() + val mutex = Mutex() transport.onMessage { - received.add(it) - messageProcessed.complete(Unit) + mutex.withLock { + received.add(it) + messageProcessed.complete(Unit) + } } transport.start() @@ -81,7 +90,7 @@ class ChannelTransportTest { // Wait for a message to be processed messageProcessed.await() - received.shouldContainExactly(msg) + received.shouldContainExactlyInAnyOrder(msg) } @Test @@ -106,12 +115,17 @@ class ChannelTransportTest { val messageProcessed = CompletableDeferred() val calls = mutableListOf() + val mutex = Mutex() transport.onMessage { - calls.add("first") + mutex.withLock { + calls.add("first") + } } transport.onMessage { - calls.add("second") - messageProcessed.complete(Unit) + mutex.withLock { + calls.add("second") + messageProcessed.complete(Unit) + } } val startJob = backgroundScope.launch { transport.start() } @@ -119,7 +133,7 @@ class ChannelTransportTest { receiveChannel.send(JSONRPCRequest(RequestId.NumberId(1), "test")) messageProcessed.await() - calls.shouldContainExactly("first", "second") + calls.shouldContainExactlyInAnyOrder("first", "second") startJob.cancelAndJoin() } @@ -130,9 +144,12 @@ class ChannelTransportTest { val messageProcessed = CompletableDeferred() val received = mutableListOf() + val mutex = Mutex() transport.onMessage { msg -> - received.add(msg) - messageProcessed.complete(Unit) + mutex.withLock { + received.add(msg) + messageProcessed.complete(Unit) + } } transport.start() @@ -142,7 +159,7 @@ class ChannelTransportTest { messageProcessed.await() eventually(2.seconds) { - received.shouldContainExactly(message) + received.shouldContainExactlyInAnyOrder(message) } } @@ -179,18 +196,23 @@ class ChannelTransportTest { val allProcessed = CompletableDeferred() val received = mutableListOf() - val errors = mutableListOf() + val errors = Channel(Channel.UNLIMITED) + val mutex = Mutex() - transport.onError { errors.add(it) } + transport.onError { error -> + errors.trySend(error) + } transport.onMessage { msg -> val id = ((msg as JSONRPCRequest).id as RequestId.NumberId).value.toInt() - received.add(id) - if (id == 2) { - @Suppress("TooGenericExceptionThrown") - throw RuntimeException("Error processing message 2") - } - if (received.size == 4) { - allProcessed.complete(Unit) + mutex.withLock { + received.add(id) + if (id == 2) { + @Suppress("TooGenericExceptionThrown") + throw RuntimeException("Error processing message 2") + } + if (received.size == 4) { + allProcessed.complete(Unit) + } } } @@ -199,14 +221,23 @@ class ChannelTransportTest { // Send 4 messages, second one will throw receiveChannel.send(JSONRPCRequest(RequestId.NumberId(1), "m1")) receiveChannel.send(JSONRPCRequest(RequestId.NumberId(2), "m2")) + delay(100) receiveChannel.send(JSONRPCRequest(RequestId.NumberId(3), "m3")) receiveChannel.send(JSONRPCRequest(RequestId.NumberId(4), "m4")) allProcessed.await() + transport.close() // All messages should be processed despite error in message 2 - received.shouldContainExactly(1, 2, 3, 4) - errors.size shouldBe 1 - errors[0].message shouldBe "Error processing message 2" + received.shouldContainExactlyInAnyOrder(1, 2, 3, 4) + val capturedErrors = mutableListOf() + eventually(2.seconds) { + while (true) { + val error = errors.tryReceive().getOrNull() ?: break + capturedErrors.add(error) + } + capturedErrors.size shouldBe 1 + capturedErrors[0].message shouldBe "Error processing message 2" + } } } diff --git a/test-utils/src/jvmMain/kotlin/io/modelcontextprotocol/kotlin/test/utils/TypeScriptRunner.kt b/test-utils/src/jvmMain/kotlin/io/modelcontextprotocol/kotlin/test/utils/TypeScriptRunner.kt index e7f5798f4..f5202398c 100644 --- a/test-utils/src/jvmMain/kotlin/io/modelcontextprotocol/kotlin/test/utils/TypeScriptRunner.kt +++ b/test-utils/src/jvmMain/kotlin/io/modelcontextprotocol/kotlin/test/utils/TypeScriptRunner.kt @@ -28,7 +28,11 @@ public object TypeScriptRunner { logPrefix: String = "TS-RUNNER", log: Boolean = true, ): Process { - val command = mutableListOf("npx", "tsx", scriptPath) + val command = if (isWindows) { + mutableListOf("cmd.exe", "/c", NPX, "tsx", scriptPath) + } else { + mutableListOf(NPX, "tsx", scriptPath) + } command.addAll(arguments) val pb = ProcessBuilder(command) pb.directory(typescriptDir)