Skip to content

Commit a2dbd83

Browse files
committed
fix: concurrent message processing for all transports
1 parent c136ad9 commit a2dbd83

File tree

21 files changed

+362
-72
lines changed

21 files changed

+362
-72
lines changed

integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.TextContent
3838
import io.modelcontextprotocol.kotlin.sdk.types.Tool
3939
import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema
4040
import kotlinx.coroutines.CompletableDeferred
41+
import kotlinx.coroutines.CoroutineScope
4142
import kotlinx.coroutines.TimeoutCancellationException
4243
import kotlinx.coroutines.cancel
4344
import kotlinx.coroutines.delay
@@ -62,6 +63,7 @@ class ClientTest {
6263
fun `should initialize with matching protocol version`() = runTest {
6364
var initialised = false
6465
val clientTransport = object : AbstractTransport() {
66+
override val scope: CoroutineScope = this@runTest
6567
override suspend fun start() {}
6668

6769
override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) {
@@ -81,7 +83,7 @@ class ClientTest {
8183
result = result,
8284
)
8385

84-
_onMessage.invoke(response)
86+
handleMessage(response)
8587
}
8688

8789
override suspend fun close() {
@@ -108,6 +110,7 @@ class ClientTest {
108110
fun `should initialize with supported older protocol version`() = runTest {
109111
val oldVersion = SUPPORTED_PROTOCOL_VERSIONS[1]
110112
val clientTransport = object : AbstractTransport() {
113+
override val scope: CoroutineScope = this@runTest
111114
override suspend fun start() {}
112115

113116
override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) {
@@ -127,7 +130,7 @@ class ClientTest {
127130
id = message.id,
128131
result = result,
129132
)
130-
_onMessage.invoke(response)
133+
handleMessage(response)
131134
}
132135

133136
override suspend fun close() {
@@ -157,6 +160,7 @@ class ClientTest {
157160
fun `should reject unsupported protocol version`() = runTest {
158161
var closed = false
159162
val clientTransport = object : AbstractTransport() {
163+
override val scope: CoroutineScope = this@runTest
160164
override suspend fun start() {}
161165

162166
override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) {
@@ -177,7 +181,7 @@ class ClientTest {
177181
result = result,
178182
)
179183

180-
_onMessage.invoke(response)
184+
handleMessage(response)
181185
}
182186

183187
override suspend fun close() {
@@ -204,6 +208,7 @@ class ClientTest {
204208
fun `should reject due to non cancellation exception`() = runTest {
205209
var closed = false
206210
val failingTransport = object : AbstractTransport() {
211+
override val scope: CoroutineScope = this@runTest
207212
override suspend fun start() {}
208213

209214
override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) {
@@ -238,6 +243,7 @@ class ClientTest {
238243
fun `should rethrow McpException as is`() = runTest {
239244
var closed = false
240245
val failingTransport = object : AbstractTransport() {
246+
override val scope: CoroutineScope = this@runTest
241247
override suspend fun start() {}
242248

243249
override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) {
@@ -276,6 +282,7 @@ class ClientTest {
276282
fun `should rethrow StreamableHttpError as is`() = runTest {
277283
var closed = false
278284
val failingTransport = object : AbstractTransport() {
285+
override val scope: CoroutineScope = this@runTest
279286
override suspend fun start() {}
280287

281288
override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) {

integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/InMemoryTransport.kt

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
package io.modelcontextprotocol.kotlin.sdk.shared
22

33
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage
4+
import kotlinx.coroutines.CoroutineScope
5+
import kotlinx.coroutines.Dispatchers
6+
import kotlinx.coroutines.SupervisorJob
7+
import kotlinx.coroutines.cancel
8+
import kotlinx.coroutines.launch
49

510
/**
611
* In-memory transport for creating clients and servers that talk to each other within the same process.
712
*/
813
class InMemoryTransport : AbstractTransport() {
14+
@Suppress("InjectDispatcher")
15+
override val scope: CoroutineScope = CoroutineScope(Dispatchers.Default + SupervisorJob())
916
private var otherTransport: InMemoryTransport? = null
1017
private val messageQueue: MutableList<JSONRPCMessage> = mutableListOf()
1118

@@ -27,21 +34,25 @@ class InMemoryTransport : AbstractTransport() {
2734
// Process any messages that were queued before start was called
2835
while (messageQueue.isNotEmpty()) {
2936
messageQueue.removeFirstOrNull()?.let { message ->
30-
_onMessage.invoke(message) // todo?
37+
handleMessageInline(message) // todo?
3138
}
3239
}
3340
}
3441

3542
override suspend fun close() {
3643
val other = otherTransport
3744
otherTransport = null
45+
val inProgress = scope.launch { joinInProgressHandlers() }
3846
other?.close()
47+
inProgress.join()
48+
scope.cancel()
3949
invokeOnCloseCallback()
4050
}
4151

4252
override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) {
4353
val other = checkNotNull(otherTransport) { "Not connected" }
4454

45-
other._onMessage.invoke(message)
55+
// necessary to propagate the caller's context - sometimes test, sometimes not
56+
other.handleMessageInline(message)
4657
}
4758
}

integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractToolIntegrationTest.kt

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,23 @@ import io.modelcontextprotocol.kotlin.sdk.types.ImageContent
99
import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities
1010
import io.modelcontextprotocol.kotlin.sdk.types.TextContent
1111
import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema
12+
import kotlinx.coroutines.CoroutineStart
1213
import kotlinx.coroutines.Dispatchers
14+
import kotlinx.coroutines.async
1315
import kotlinx.coroutines.delay
1416
import kotlinx.coroutines.launch
1517
import kotlinx.coroutines.runBlocking
1618
import kotlinx.coroutines.test.runTest
19+
import kotlinx.coroutines.withContext
1720
import kotlinx.serialization.json.JsonArray
1821
import kotlinx.serialization.json.JsonPrimitive
1922
import kotlinx.serialization.json.add
2023
import kotlinx.serialization.json.buildJsonArray
2124
import kotlinx.serialization.json.buildJsonObject
2225
import kotlinx.serialization.json.put
2326
import org.junit.jupiter.api.Test
27+
import org.junit.jupiter.params.ParameterizedTest
28+
import org.junit.jupiter.params.provider.ValueSource
2429
import java.text.DecimalFormat
2530
import java.text.DecimalFormatSymbols
2631
import java.util.Locale
@@ -158,14 +163,29 @@ abstract class AbstractToolIntegrationTest : KotlinTestBase() {
158163
put("description", "Delay in milliseconds")
159164
},
160165
)
166+
put(
167+
"blocking",
168+
buildJsonObject {
169+
put("type", "boolean")
170+
put("description", "Whether to block the thread while waiting")
171+
},
172+
)
161173
},
162174
),
163175
) { request ->
164176
val delay = (request.params.arguments?.get("delay") as? JsonPrimitive)?.content?.toIntOrNull() ?: 1000
177+
val blocking = (request.params.arguments?.get("blocking") as? JsonPrimitive)?.content?.toBoolean() ?: false
165178

166179
// simulate slow operation
167-
runBlocking {
168-
delay(delay.toLong())
180+
181+
if (blocking) {
182+
@Suppress("RunBlockingInSuspendFunction")
183+
runBlocking { delay(delay.toLong()) }
184+
} else {
185+
@Suppress("InjectDispatcher")
186+
withContext(Dispatchers.Default) {
187+
delay(delay.toLong())
188+
}
169189
}
170190

171191
CallToolResult(
@@ -691,6 +711,45 @@ abstract class AbstractToolIntegrationTest : KotlinTestBase() {
691711
actualContent shouldEqualJson expectedContent
692712
}
693713

714+
@ParameterizedTest
715+
@ValueSource(booleans = [true, false])
716+
@Suppress("InjectDispatcher")
717+
fun testToolConcurrentProcessing(blocking: Boolean): Unit = runBlocking(Dispatchers.Default) {
718+
val delayMs = 1000
719+
val arguments = mapOf("delay" to delayMs, "blocking" to blocking)
720+
721+
val startTime = System.currentTimeMillis()
722+
723+
// Start a slow tool call
724+
val deferredSlow = async {
725+
client.callTool(slowToolName, arguments)
726+
}
727+
728+
// Give it a tiny bit of time to reach the server and start processing
729+
delay(50)
730+
731+
// Start a fast tool call
732+
val deferredFast = async(start = CoroutineStart.UNDISPATCHED) {
733+
client.callTool(testToolName, mapOf("text" to "fast"))
734+
}
735+
736+
val fastResult = deferredFast.await()
737+
val fastEndTime = System.currentTimeMillis()
738+
739+
// The fast tool should finish MUCH sooner than the slow tool's delay if processed concurrently
740+
assertTrue(
741+
fastEndTime - startTime < delayMs,
742+
"Fast tool should finish before the slow tool's delay (took ${fastEndTime - startTime}ms)",
743+
)
744+
745+
deferredSlow.await()
746+
val slowEndTime = System.currentTimeMillis()
747+
748+
assertTrue(slowEndTime - startTime >= delayMs, "Slow tool should take at least the specified delay")
749+
750+
assertEquals("Echo: fast", (fastResult.content.first() as TextContent).text)
751+
}
752+
694753
@Test
695754
fun testSpecialCharacters() {
696755
runBlocking(Dispatchers.IO) {

integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import io.ktor.server.engine.EmbeddedServer
99
import io.ktor.server.engine.embeddedServer
1010
import io.ktor.server.plugins.contentnegotiation.ContentNegotiation
1111
import io.ktor.server.routing.routing
12+
import io.modelcontextprotocol.kotlin.sdk.ExperimentalMcpApi
1213
import io.modelcontextprotocol.kotlin.sdk.client.Client
1314
import io.modelcontextprotocol.kotlin.sdk.client.SseClientTransport
1415
import io.modelcontextprotocol.kotlin.sdk.client.StdioClientTransport
@@ -17,8 +18,10 @@ import io.modelcontextprotocol.kotlin.sdk.server.Server
1718
import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions
1819
import io.modelcontextprotocol.kotlin.sdk.server.StdioServerTransport
1920
import io.modelcontextprotocol.kotlin.sdk.server.StreamableHttpServerTransport
21+
import io.modelcontextprotocol.kotlin.sdk.server.StreamableHttpServerTransport.Configuration
2022
import io.modelcontextprotocol.kotlin.sdk.server.mcp
2123
import io.modelcontextprotocol.kotlin.sdk.server.mcpStreamableHttp
24+
import io.modelcontextprotocol.kotlin.sdk.testing.ChannelTransport
2225
import io.modelcontextprotocol.kotlin.sdk.types.Implementation
2326
import io.modelcontextprotocol.kotlin.sdk.types.McpJson
2427
import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities
@@ -40,6 +43,7 @@ import io.ktor.server.cio.CIO as ServerCIO
4043
import io.ktor.server.sse.SSE as ServerSSE
4144

4245
@Retry(times = 3)
46+
@OptIn(ExperimentalMcpApi::class)
4347
abstract class KotlinTestBase {
4448

4549
protected val host = "localhost"
@@ -48,9 +52,10 @@ abstract class KotlinTestBase {
4852
protected lateinit var server: Server
4953
protected lateinit var client: Client
5054
protected lateinit var serverEngine: EmbeddedServer<*, *>
55+
protected val channelTransports = lazy { ChannelTransport.createLinkedPair() }
5156

5257
// Transport selection
53-
protected enum class TransportKind { SSE, STDIO, STREAMABLE_HTTP }
58+
protected enum class TransportKind { SSE, STDIO, STREAMABLE_HTTP, CHANNEL }
5459

5560
protected open val transportKind: TransportKind = TransportKind.STDIO
5661

@@ -121,6 +126,13 @@ abstract class KotlinTestBase {
121126
)
122127
client.connect(transport)
123128
}
129+
130+
TransportKind.CHANNEL -> {
131+
client = Client(
132+
Implementation("test", "1.0"),
133+
)
134+
client.connect(channelTransports.value.clientTransport)
135+
}
124136
}
125137
}
126138

@@ -148,7 +160,7 @@ abstract class KotlinTestBase {
148160
// Create StreamableHTTP server transport
149161
// Using JSON response mode for simpler testing (no SSE session required)
150162
val transport = StreamableHttpServerTransport(
151-
StreamableHttpServerTransport.Configuration(
163+
Configuration(
152164
enableJsonResponse = true, // Use JSON response mode for testing
153165
),
154166
)
@@ -196,6 +208,10 @@ abstract class KotlinTestBase {
196208
server.createSession(serverTransport)
197209
}
198210
}
211+
212+
TransportKind.CHANNEL -> {
213+
runBlocking { server.createSession(channelTransports.value.serverTransport) }
214+
}
199215
}
200216
}
201217

@@ -250,6 +266,12 @@ abstract class KotlinTestBase {
250266
}
251267
}
252268
}
269+
270+
TransportKind.CHANNEL -> {
271+
if (channelTransports.isInitialized()) {
272+
runBlocking { channelTransports.value.close() }
273+
}
274+
}
253275
}
254276
}
255277
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package io.modelcontextprotocol.kotlin.sdk.integration.kotlin.channel
2+
3+
import io.modelcontextprotocol.kotlin.sdk.integration.kotlin.AbstractToolIntegrationTest
4+
5+
// while this isn't a "production" transport, we still want to ensure that it has the correct behavior
6+
class ToolIntegrationTestChannel : AbstractToolIntegrationTest() {
7+
override val transportKind: TransportKind = TransportKind.CHANNEL
8+
}

integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/typescript/sse/KotlinServerForTsClientSse.kt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ import io.modelcontextprotocol.kotlin.sdk.types.TextResourceContents
4242
import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema
4343
import kotlinx.coroutines.CancellationException
4444
import kotlinx.coroutines.CompletableDeferred
45+
import kotlinx.coroutines.CoroutineScope
46+
import kotlinx.coroutines.Dispatchers
47+
import kotlinx.coroutines.SupervisorJob
4548
import kotlinx.coroutines.channels.Channel
4649
import kotlinx.coroutines.runBlocking
4750
import kotlinx.coroutines.withTimeoutOrNull
@@ -310,6 +313,8 @@ class KotlinServerForTsClient {
310313
}
311314

312315
class HttpServerTransport(private val sessionId: String) : AbstractTransport() {
316+
@Suppress("InjectDispatcher")
317+
override val scope: CoroutineScope = CoroutineScope(Dispatchers.Default + SupervisorJob())
313318
private val logger = KotlinLogging.logger {}
314319
private val pendingResponses = ConcurrentHashMap<String, CompletableDeferred<JSONRPCMessage>>()
315320
private val messageQueue = Channel<JSONRPCMessage>(Channel.UNLIMITED)
@@ -352,7 +357,7 @@ class HttpServerTransport(private val sessionId: String) : AbstractTransport() {
352357
logger.info { "Created deferred response for ID: $id" }
353358

354359
logger.info { "Invoking onMessage handler" }
355-
_onMessage.invoke(message)
360+
handleMessage(message)
356361
logger.info { "onMessage handler completed" }
357362

358363
try {

integration-test/src/jvmTest/typescript/package-lock.json

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseClientTransport.kt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ public class SseClientTransport(
5050
private val endpoint = CompletableDeferred<String>()
5151

5252
private lateinit var session: ClientSSESession
53-
private lateinit var scope: CoroutineScope
53+
override lateinit var scope: CoroutineScope
5454
private var job: Job? = null
5555

5656
private val origin: String by lazy {
@@ -161,10 +161,10 @@ public class SseClientTransport(
161161
}
162162
}
163163

164-
private suspend fun handleMessage(data: String) {
164+
private fun handleMessage(data: String) {
165165
try {
166166
val message = McpJson.decodeFromString<JSONRPCMessage>(data)
167-
_onMessage(message)
167+
handleMessage(message)
168168
} catch (e: SerializationException) {
169169
_onError(e)
170170
}
@@ -173,6 +173,7 @@ public class SseClientTransport(
173173
override suspend fun closeResources() {
174174
job?.cancelAndJoin()
175175
try {
176+
joinInProgressHandlers()
176177
if (::session.isInitialized) session.cancel()
177178
if (::scope.isInitialized) scope.cancel()
178179
endpoint.cancel()

0 commit comments

Comments
 (0)