Skip to content

Commit 5b88d56

Browse files
committed
fix: concurrent message processing for all transports
1 parent c136ad9 commit 5b88d56

32 files changed

Lines changed: 491 additions & 123 deletions

File tree

AGENTS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ Follow these rules to keep changes safe, comprehensible, and easy to maintain.
7777
- **Prioritize test readability**
7878
- Avoid creating too many test methods; use parametrized tests when testing multiple similar scenarios
7979
- When running tests on Kotlin Multiplatform projects, run JVM tests only unless asked for other platforms
80+
- **Concurrency in Tests**: Always use thread-safe collections (e.g., `Mutex`-protected lists or `Channel`) when collecting messages from transports that process messages concurrently in the background (like those inheriting from `AbstractTransport`). Using non-thread-safe `MutableList` will lead to flaky tests or missing messages.
8081

8182
### Test Framework Stack
8283

@@ -169,6 +170,7 @@ prop.shouldNotBeNull {
169170
- Use Kotlinx Serialization with explicit `@Serializable` annotations
170171
- JSON config is defined in `jsonUtils.kt` as `McpJson` — use it consistently
171172
- Register custom serializers in companion objects
173+
- **SSE Data Concatenation**: When parsing Server-Sent Events (SSE) data, always ensure that multiple `data:` lines are concatenated with a newline (`\n`) separator, as per the SSE specification.
172174

173175
### Error Handling
174176

integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.TextContent
3838
import io.modelcontextprotocol.kotlin.sdk.types.Tool
3939
import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema
4040
import kotlinx.coroutines.CompletableDeferred
41+
import kotlinx.coroutines.CoroutineScope
4142
import kotlinx.coroutines.TimeoutCancellationException
4243
import kotlinx.coroutines.cancel
4344
import kotlinx.coroutines.delay
@@ -62,6 +63,7 @@ class ClientTest {
6263
fun `should initialize with matching protocol version`() = runTest {
6364
var initialised = false
6465
val clientTransport = object : AbstractTransport() {
66+
override val scope: CoroutineScope = this@runTest
6567
override suspend fun start() {}
6668

6769
override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) {
@@ -81,7 +83,7 @@ class ClientTest {
8183
result = result,
8284
)
8385

84-
_onMessage.invoke(response)
86+
handleMessage(response)
8587
}
8688

8789
override suspend fun close() {
@@ -108,6 +110,7 @@ class ClientTest {
108110
fun `should initialize with supported older protocol version`() = runTest {
109111
val oldVersion = SUPPORTED_PROTOCOL_VERSIONS[1]
110112
val clientTransport = object : AbstractTransport() {
113+
override val scope: CoroutineScope = this@runTest
111114
override suspend fun start() {}
112115

113116
override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) {
@@ -127,7 +130,7 @@ class ClientTest {
127130
id = message.id,
128131
result = result,
129132
)
130-
_onMessage.invoke(response)
133+
handleMessage(response)
131134
}
132135

133136
override suspend fun close() {
@@ -157,6 +160,7 @@ class ClientTest {
157160
fun `should reject unsupported protocol version`() = runTest {
158161
var closed = false
159162
val clientTransport = object : AbstractTransport() {
163+
override val scope: CoroutineScope = this@runTest
160164
override suspend fun start() {}
161165

162166
override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) {
@@ -177,7 +181,7 @@ class ClientTest {
177181
result = result,
178182
)
179183

180-
_onMessage.invoke(response)
184+
handleMessage(response)
181185
}
182186

183187
override suspend fun close() {
@@ -204,6 +208,7 @@ class ClientTest {
204208
fun `should reject due to non cancellation exception`() = runTest {
205209
var closed = false
206210
val failingTransport = object : AbstractTransport() {
211+
override val scope: CoroutineScope = this@runTest
207212
override suspend fun start() {}
208213

209214
override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) {
@@ -238,6 +243,7 @@ class ClientTest {
238243
fun `should rethrow McpException as is`() = runTest {
239244
var closed = false
240245
val failingTransport = object : AbstractTransport() {
246+
override val scope: CoroutineScope = this@runTest
241247
override suspend fun start() {}
242248

243249
override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) {
@@ -276,6 +282,7 @@ class ClientTest {
276282
fun `should rethrow StreamableHttpError as is`() = runTest {
277283
var closed = false
278284
val failingTransport = object : AbstractTransport() {
285+
override val scope: CoroutineScope = this@runTest
279286
override suspend fun start() {}
280287

281288
override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) {

integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/InMemoryTransportTest.kt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import io.modelcontextprotocol.kotlin.sdk.types.ResourceUpdatedNotification
1010
import io.modelcontextprotocol.kotlin.sdk.types.ResourceUpdatedNotificationParams
1111
import io.modelcontextprotocol.kotlin.sdk.types.ToolListChangedNotification
1212
import io.modelcontextprotocol.kotlin.sdk.types.toJSON
13+
import kotlinx.coroutines.sync.Mutex
14+
import kotlinx.coroutines.sync.withLock
1315
import kotlinx.coroutines.test.runTest
1416
import kotlin.test.BeforeTest
1517
import kotlin.test.Test
@@ -190,8 +192,11 @@ class InMemoryTransportTest {
190192
)
191193

192194
val receivedMessages = mutableListOf<JSONRPCMessage>()
195+
val mutex = Mutex()
193196
clientTransport.onMessage { msg ->
194-
receivedMessages.add(msg)
197+
mutex.withLock {
198+
receivedMessages.add(msg)
199+
}
195200
}
196201

197202
notifications.forEach { notification ->

integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage
77
import io.modelcontextprotocol.kotlin.sdk.types.PingRequest
88
import io.modelcontextprotocol.kotlin.sdk.types.toJSON
99
import kotlinx.coroutines.CompletableDeferred
10+
import kotlinx.coroutines.sync.Mutex
11+
import kotlinx.coroutines.sync.withLock
1012
import kotlin.test.fail
1113
import kotlin.time.Duration.Companion.seconds
1214

@@ -47,12 +49,15 @@ abstract class BaseTransportTest {
4749
)
4850

4951
val readMessages = mutableListOf<JSONRPCMessage>()
52+
val mutex = Mutex()
5053
val finished = CompletableDeferred<Unit>()
5154

5255
transport.onMessage { message ->
53-
readMessages.add(message)
54-
if (message == messages.last()) {
55-
finished.complete(Unit)
56+
mutex.withLock {
57+
readMessages.add(message)
58+
if (message == messages.last()) {
59+
finished.complete(Unit)
60+
}
5661
}
5762
}
5863

integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/InMemoryTransport.kt

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
package io.modelcontextprotocol.kotlin.sdk.shared
22

33
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage
4+
import kotlinx.coroutines.CoroutineScope
5+
import kotlinx.coroutines.Dispatchers
6+
import kotlinx.coroutines.SupervisorJob
7+
import kotlinx.coroutines.cancel
8+
import kotlinx.coroutines.launch
49

510
/**
611
* In-memory transport for creating clients and servers that talk to each other within the same process.
712
*/
813
class InMemoryTransport : AbstractTransport() {
14+
@Suppress("InjectDispatcher")
15+
override val scope: CoroutineScope = CoroutineScope(Dispatchers.Default + SupervisorJob())
916
private var otherTransport: InMemoryTransport? = null
1017
private val messageQueue: MutableList<JSONRPCMessage> = mutableListOf()
1118

@@ -27,21 +34,25 @@ class InMemoryTransport : AbstractTransport() {
2734
// Process any messages that were queued before start was called
2835
while (messageQueue.isNotEmpty()) {
2936
messageQueue.removeFirstOrNull()?.let { message ->
30-
_onMessage.invoke(message) // todo?
37+
handleMessageInline(message) // todo?
3138
}
3239
}
3340
}
3441

3542
override suspend fun close() {
3643
val other = otherTransport
3744
otherTransport = null
45+
val inProgress = scope.launch { joinInProgressHandlers() }
3846
other?.close()
47+
inProgress.join()
48+
scope.cancel()
3949
invokeOnCloseCallback()
4050
}
4151

4252
override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) {
4353
val other = checkNotNull(otherTransport) { "Not connected" }
4454

45-
other._onMessage.invoke(message)
55+
// necessary to propagate the caller's context - sometimes test, sometimes not
56+
other.handleMessageInline(message)
4657
}
4758
}

integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractToolIntegrationTest.kt

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,23 @@ import io.modelcontextprotocol.kotlin.sdk.types.ImageContent
99
import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities
1010
import io.modelcontextprotocol.kotlin.sdk.types.TextContent
1111
import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema
12+
import kotlinx.coroutines.CoroutineStart
1213
import kotlinx.coroutines.Dispatchers
14+
import kotlinx.coroutines.async
1315
import kotlinx.coroutines.delay
1416
import kotlinx.coroutines.launch
1517
import kotlinx.coroutines.runBlocking
1618
import kotlinx.coroutines.test.runTest
19+
import kotlinx.coroutines.withContext
1720
import kotlinx.serialization.json.JsonArray
1821
import kotlinx.serialization.json.JsonPrimitive
1922
import kotlinx.serialization.json.add
2023
import kotlinx.serialization.json.buildJsonArray
2124
import kotlinx.serialization.json.buildJsonObject
2225
import kotlinx.serialization.json.put
2326
import org.junit.jupiter.api.Test
27+
import org.junit.jupiter.params.ParameterizedTest
28+
import org.junit.jupiter.params.provider.ValueSource
2429
import java.text.DecimalFormat
2530
import java.text.DecimalFormatSymbols
2631
import java.util.Locale
@@ -158,14 +163,29 @@ abstract class AbstractToolIntegrationTest : KotlinTestBase() {
158163
put("description", "Delay in milliseconds")
159164
},
160165
)
166+
put(
167+
"blocking",
168+
buildJsonObject {
169+
put("type", "boolean")
170+
put("description", "Whether to block the thread while waiting")
171+
},
172+
)
161173
},
162174
),
163175
) { request ->
164176
val delay = (request.params.arguments?.get("delay") as? JsonPrimitive)?.content?.toIntOrNull() ?: 1000
177+
val blocking = (request.params.arguments?.get("blocking") as? JsonPrimitive)?.content?.toBoolean() ?: false
165178

166179
// simulate slow operation
167-
runBlocking {
168-
delay(delay.toLong())
180+
181+
if (blocking) {
182+
@Suppress("RunBlockingInSuspendFunction")
183+
runBlocking { delay(delay.toLong()) }
184+
} else {
185+
@Suppress("InjectDispatcher")
186+
withContext(Dispatchers.Default) {
187+
delay(delay.toLong())
188+
}
169189
}
170190

171191
CallToolResult(
@@ -691,6 +711,45 @@ abstract class AbstractToolIntegrationTest : KotlinTestBase() {
691711
actualContent shouldEqualJson expectedContent
692712
}
693713

714+
@ParameterizedTest
715+
@ValueSource(booleans = [true, false])
716+
@Suppress("InjectDispatcher")
717+
fun testToolConcurrentProcessing(blocking: Boolean): Unit = runBlocking(Dispatchers.Default) {
718+
val delayMs = 1000
719+
val arguments = mapOf("delay" to delayMs, "blocking" to blocking)
720+
721+
val startTime = System.currentTimeMillis()
722+
723+
// Start a slow tool call
724+
val deferredSlow = async {
725+
client.callTool(slowToolName, arguments)
726+
}
727+
728+
// Give it a tiny bit of time to reach the server and start processing
729+
delay(50)
730+
731+
// Start a fast tool call
732+
val deferredFast = async(start = CoroutineStart.UNDISPATCHED) {
733+
client.callTool(testToolName, mapOf("text" to "fast"))
734+
}
735+
736+
val fastResult = deferredFast.await()
737+
val fastEndTime = System.currentTimeMillis()
738+
739+
// The fast tool should finish MUCH sooner than the slow tool's delay if processed concurrently
740+
assertTrue(
741+
fastEndTime - startTime < delayMs,
742+
"Fast tool should finish before the slow tool's delay (took ${fastEndTime - startTime}ms)",
743+
)
744+
745+
deferredSlow.await()
746+
val slowEndTime = System.currentTimeMillis()
747+
748+
assertTrue(slowEndTime - startTime >= delayMs, "Slow tool should take at least the specified delay")
749+
750+
assertEquals("Echo: fast", (fastResult.content.first() as TextContent).text)
751+
}
752+
694753
@Test
695754
fun testSpecialCharacters() {
696755
runBlocking(Dispatchers.IO) {

integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import io.ktor.server.engine.EmbeddedServer
99
import io.ktor.server.engine.embeddedServer
1010
import io.ktor.server.plugins.contentnegotiation.ContentNegotiation
1111
import io.ktor.server.routing.routing
12+
import io.modelcontextprotocol.kotlin.sdk.ExperimentalMcpApi
1213
import io.modelcontextprotocol.kotlin.sdk.client.Client
1314
import io.modelcontextprotocol.kotlin.sdk.client.SseClientTransport
1415
import io.modelcontextprotocol.kotlin.sdk.client.StdioClientTransport
@@ -17,8 +18,10 @@ import io.modelcontextprotocol.kotlin.sdk.server.Server
1718
import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions
1819
import io.modelcontextprotocol.kotlin.sdk.server.StdioServerTransport
1920
import io.modelcontextprotocol.kotlin.sdk.server.StreamableHttpServerTransport
21+
import io.modelcontextprotocol.kotlin.sdk.server.StreamableHttpServerTransport.Configuration
2022
import io.modelcontextprotocol.kotlin.sdk.server.mcp
2123
import io.modelcontextprotocol.kotlin.sdk.server.mcpStreamableHttp
24+
import io.modelcontextprotocol.kotlin.sdk.testing.ChannelTransport
2225
import io.modelcontextprotocol.kotlin.sdk.types.Implementation
2326
import io.modelcontextprotocol.kotlin.sdk.types.McpJson
2427
import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities
@@ -40,6 +43,7 @@ import io.ktor.server.cio.CIO as ServerCIO
4043
import io.ktor.server.sse.SSE as ServerSSE
4144

4245
@Retry(times = 3)
46+
@OptIn(ExperimentalMcpApi::class)
4347
abstract class KotlinTestBase {
4448

4549
protected val host = "localhost"
@@ -48,9 +52,10 @@ abstract class KotlinTestBase {
4852
protected lateinit var server: Server
4953
protected lateinit var client: Client
5054
protected lateinit var serverEngine: EmbeddedServer<*, *>
55+
protected val channelTransports = lazy { ChannelTransport.createLinkedPair() }
5156

5257
// Transport selection
53-
protected enum class TransportKind { SSE, STDIO, STREAMABLE_HTTP }
58+
protected enum class TransportKind { SSE, STDIO, STREAMABLE_HTTP, CHANNEL }
5459

5560
protected open val transportKind: TransportKind = TransportKind.STDIO
5661

@@ -121,6 +126,13 @@ abstract class KotlinTestBase {
121126
)
122127
client.connect(transport)
123128
}
129+
130+
TransportKind.CHANNEL -> {
131+
client = Client(
132+
Implementation("test", "1.0"),
133+
)
134+
client.connect(channelTransports.value.clientTransport)
135+
}
124136
}
125137
}
126138

@@ -148,7 +160,7 @@ abstract class KotlinTestBase {
148160
// Create StreamableHTTP server transport
149161
// Using JSON response mode for simpler testing (no SSE session required)
150162
val transport = StreamableHttpServerTransport(
151-
StreamableHttpServerTransport.Configuration(
163+
Configuration(
152164
enableJsonResponse = true, // Use JSON response mode for testing
153165
),
154166
)
@@ -196,6 +208,10 @@ abstract class KotlinTestBase {
196208
server.createSession(serverTransport)
197209
}
198210
}
211+
212+
TransportKind.CHANNEL -> {
213+
runBlocking { server.createSession(channelTransports.value.serverTransport) }
214+
}
199215
}
200216
}
201217

@@ -250,6 +266,12 @@ abstract class KotlinTestBase {
250266
}
251267
}
252268
}
269+
270+
TransportKind.CHANNEL -> {
271+
if (channelTransports.isInitialized()) {
272+
runBlocking { channelTransports.value.close() }
273+
}
274+
}
253275
}
254276
}
255277
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package io.modelcontextprotocol.kotlin.sdk.integration.kotlin.channel
2+
3+
import io.modelcontextprotocol.kotlin.sdk.integration.kotlin.AbstractToolIntegrationTest
4+
5+
// while this isn't a "production" transport, we still want to ensure that it has the correct behavior
6+
class ToolIntegrationTestChannel : AbstractToolIntegrationTest() {
7+
override val transportKind: TransportKind = TransportKind.CHANNEL
8+
}

0 commit comments

Comments
 (0)