Skip to content

Commit 848d6e7

Browse files
committed
fix: concurrent message processing for all transports
1 parent c136ad9 commit 848d6e7

35 files changed

Lines changed: 633 additions & 250 deletions

File tree

AGENTS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ Follow these rules to keep changes safe, comprehensible, and easy to maintain.
7777
- **Prioritize test readability**
7878
- Avoid creating too many test methods; use parametrized tests when testing multiple similar scenarios
7979
- When running tests on Kotlin Multiplatform projects, run JVM tests only unless asked for other platforms
80+
- **Concurrency in Tests**: Always use thread-safe collections (e.g., `Mutex`-protected lists or `Channel`) when collecting messages from transports that process messages concurrently in the background (like those inheriting from `AbstractTransport`). Using non-thread-safe `MutableList` will lead to flaky tests or missing messages.
8081

8182
### Test Framework Stack
8283

@@ -169,6 +170,7 @@ prop.shouldNotBeNull {
169170
- Use Kotlinx Serialization with explicit `@Serializable` annotations
170171
- JSON config is defined in `jsonUtils.kt` as `McpJson` — use it consistently
171172
- Register custom serializers in companion objects
173+
- **SSE Data Concatenation**: When parsing Server-Sent Events (SSE) data, always ensure that multiple `data:` lines are concatenated with a newline (`\n`) separator, as per the SSE specification.
172174

173175
### Error Handling
174176

integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class ClientTest {
6161
@Test
6262
fun `should initialize with matching protocol version`() = runTest {
6363
var initialised = false
64-
val clientTransport = object : AbstractTransport() {
64+
val clientTransport = object : AbstractTransport(backgroundScope.coroutineContext) {
6565
override suspend fun start() {}
6666

6767
override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) {
@@ -81,7 +81,7 @@ class ClientTest {
8181
result = result,
8282
)
8383

84-
_onMessage.invoke(response)
84+
handleMessage(response)
8585
}
8686

8787
override suspend fun close() {
@@ -107,7 +107,7 @@ class ClientTest {
107107
@Test
108108
fun `should initialize with supported older protocol version`() = runTest {
109109
val oldVersion = SUPPORTED_PROTOCOL_VERSIONS[1]
110-
val clientTransport = object : AbstractTransport() {
110+
val clientTransport = object : AbstractTransport(backgroundScope.coroutineContext) {
111111
override suspend fun start() {}
112112

113113
override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) {
@@ -127,7 +127,7 @@ class ClientTest {
127127
id = message.id,
128128
result = result,
129129
)
130-
_onMessage.invoke(response)
130+
handleMessage(response)
131131
}
132132

133133
override suspend fun close() {
@@ -156,7 +156,7 @@ class ClientTest {
156156
@Test
157157
fun `should reject unsupported protocol version`() = runTest {
158158
var closed = false
159-
val clientTransport = object : AbstractTransport() {
159+
val clientTransport = object : AbstractTransport(backgroundScope.coroutineContext) {
160160
override suspend fun start() {}
161161

162162
override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) {
@@ -177,7 +177,7 @@ class ClientTest {
177177
result = result,
178178
)
179179

180-
_onMessage.invoke(response)
180+
handleMessage(response)
181181
}
182182

183183
override suspend fun close() {
@@ -203,7 +203,7 @@ class ClientTest {
203203
@Test
204204
fun `should reject due to non cancellation exception`() = runTest {
205205
var closed = false
206-
val failingTransport = object : AbstractTransport() {
206+
val failingTransport = object : AbstractTransport(backgroundScope.coroutineContext) {
207207
override suspend fun start() {}
208208

209209
override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) {
@@ -237,7 +237,7 @@ class ClientTest {
237237
@Test
238238
fun `should rethrow McpException as is`() = runTest {
239239
var closed = false
240-
val failingTransport = object : AbstractTransport() {
240+
val failingTransport = object : AbstractTransport(backgroundScope.coroutineContext) {
241241
override suspend fun start() {}
242242

243243
override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) {
@@ -275,7 +275,7 @@ class ClientTest {
275275
@Test
276276
fun `should rethrow StreamableHttpError as is`() = runTest {
277277
var closed = false
278-
val failingTransport = object : AbstractTransport() {
278+
val failingTransport = object : AbstractTransport(backgroundScope.coroutineContext) {
279279
override suspend fun start() {}
280280

281281
override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) {

integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/InMemoryTransportTest.kt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import io.modelcontextprotocol.kotlin.sdk.types.ResourceUpdatedNotification
1010
import io.modelcontextprotocol.kotlin.sdk.types.ResourceUpdatedNotificationParams
1111
import io.modelcontextprotocol.kotlin.sdk.types.ToolListChangedNotification
1212
import io.modelcontextprotocol.kotlin.sdk.types.toJSON
13+
import kotlinx.coroutines.sync.Mutex
14+
import kotlinx.coroutines.sync.withLock
1315
import kotlinx.coroutines.test.runTest
1416
import kotlin.test.BeforeTest
1517
import kotlin.test.Test
@@ -190,8 +192,11 @@ class InMemoryTransportTest {
190192
)
191193

192194
val receivedMessages = mutableListOf<JSONRPCMessage>()
195+
val mutex = Mutex()
193196
clientTransport.onMessage { msg ->
194-
receivedMessages.add(msg)
197+
mutex.withLock {
198+
receivedMessages.add(msg)
199+
}
195200
}
196201

197202
notifications.forEach { notification ->

integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/BaseTransportTest.kt

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
package io.modelcontextprotocol.kotlin.sdk.shared
22

33
import io.kotest.assertions.nondeterministic.eventually
4+
import io.kotest.matchers.collections.shouldContainExactlyInAnyOrder
45
import io.kotest.matchers.shouldBe
56
import io.modelcontextprotocol.kotlin.sdk.types.InitializedNotification
67
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage
78
import io.modelcontextprotocol.kotlin.sdk.types.PingRequest
89
import io.modelcontextprotocol.kotlin.sdk.types.toJSON
910
import kotlinx.coroutines.CompletableDeferred
11+
import kotlinx.coroutines.sync.Mutex
12+
import kotlinx.coroutines.sync.withLock
1013
import kotlin.test.fail
1114
import kotlin.time.Duration.Companion.seconds
1215

@@ -47,12 +50,15 @@ abstract class BaseTransportTest {
4750
)
4851

4952
val readMessages = mutableListOf<JSONRPCMessage>()
53+
val mutex = Mutex()
5054
val finished = CompletableDeferred<Unit>()
5155

5256
transport.onMessage { message ->
53-
readMessages.add(message)
54-
if (message == messages.last()) {
55-
finished.complete(Unit)
57+
mutex.withLock {
58+
readMessages.add(message)
59+
if (readMessages.size == messages.size) {
60+
finished.complete(Unit)
61+
}
5662
}
5763
}
5864

@@ -64,7 +70,7 @@ abstract class BaseTransportTest {
6470

6571
finished.await()
6672

67-
messages shouldBe readMessages
73+
readMessages.shouldContainExactlyInAnyOrder(messages)
6874

6975
transport.close()
7076
}

integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/InMemoryTransport.kt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,21 +27,23 @@ class InMemoryTransport : AbstractTransport() {
2727
// Process any messages that were queued before start was called
2828
while (messageQueue.isNotEmpty()) {
2929
messageQueue.removeFirstOrNull()?.let { message ->
30-
_onMessage.invoke(message) // todo?
30+
handleMessageInline(message) // todo?
3131
}
3232
}
3333
}
3434

3535
override suspend fun close() {
3636
val other = otherTransport
3737
otherTransport = null
38+
shutdownHandlers()
3839
other?.close()
3940
invokeOnCloseCallback()
4041
}
4142

4243
override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) {
4344
val other = checkNotNull(otherTransport) { "Not connected" }
4445

45-
other._onMessage.invoke(message)
46+
// necessary to propagate the caller's context - sometimes test, sometimes not
47+
other.handleMessageInline(message)
4648
}
4749
}

integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/AbstractAuthenticationTest.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ abstract class AbstractAuthenticationTest {
113113
var mcpClient: Client? = null
114114
try {
115115
mcpClient = Client(Implementation(name = "test-client", version = "1.0.0"))
116-
withTimeout(5.seconds) {
116+
withTimeout(10.seconds) {
117117
mcpClient.connect(createClientTransport(baseUrl, VALID_USER, VALID_PASSWORD))
118118
}
119119

integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractToolIntegrationTest.kt

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,23 @@ import io.modelcontextprotocol.kotlin.sdk.types.ImageContent
99
import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities
1010
import io.modelcontextprotocol.kotlin.sdk.types.TextContent
1111
import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema
12+
import kotlinx.coroutines.CoroutineStart
1213
import kotlinx.coroutines.Dispatchers
14+
import kotlinx.coroutines.async
1315
import kotlinx.coroutines.delay
1416
import kotlinx.coroutines.launch
1517
import kotlinx.coroutines.runBlocking
1618
import kotlinx.coroutines.test.runTest
19+
import kotlinx.coroutines.withContext
1720
import kotlinx.serialization.json.JsonArray
1821
import kotlinx.serialization.json.JsonPrimitive
1922
import kotlinx.serialization.json.add
2023
import kotlinx.serialization.json.buildJsonArray
2124
import kotlinx.serialization.json.buildJsonObject
2225
import kotlinx.serialization.json.put
2326
import org.junit.jupiter.api.Test
27+
import org.junit.jupiter.params.ParameterizedTest
28+
import org.junit.jupiter.params.provider.ValueSource
2429
import java.text.DecimalFormat
2530
import java.text.DecimalFormatSymbols
2631
import java.util.Locale
@@ -158,14 +163,29 @@ abstract class AbstractToolIntegrationTest : KotlinTestBase() {
158163
put("description", "Delay in milliseconds")
159164
},
160165
)
166+
put(
167+
"blocking",
168+
buildJsonObject {
169+
put("type", "boolean")
170+
put("description", "Whether to block the thread while waiting")
171+
},
172+
)
161173
},
162174
),
163175
) { request ->
164176
val delay = (request.params.arguments?.get("delay") as? JsonPrimitive)?.content?.toIntOrNull() ?: 1000
177+
val blocking = (request.params.arguments?.get("blocking") as? JsonPrimitive)?.content?.toBoolean() ?: false
165178

166179
// simulate slow operation
167-
runBlocking {
168-
delay(delay.toLong())
180+
181+
if (blocking) {
182+
@Suppress("RunBlockingInSuspendFunction")
183+
runBlocking { delay(delay.toLong()) }
184+
} else {
185+
@Suppress("InjectDispatcher")
186+
withContext(Dispatchers.Default) {
187+
delay(delay.toLong())
188+
}
169189
}
170190

171191
CallToolResult(
@@ -691,6 +711,45 @@ abstract class AbstractToolIntegrationTest : KotlinTestBase() {
691711
actualContent shouldEqualJson expectedContent
692712
}
693713

714+
@ParameterizedTest
715+
@ValueSource(booleans = [true, false])
716+
@Suppress("InjectDispatcher")
717+
fun testToolConcurrentProcessing(blocking: Boolean): Unit = runBlocking(Dispatchers.Default) {
718+
val delayMs = 1000
719+
val arguments = mapOf("delay" to delayMs, "blocking" to blocking)
720+
721+
val startTime = System.currentTimeMillis()
722+
723+
// Start a slow tool call
724+
val deferredSlow = async {
725+
client.callTool(slowToolName, arguments)
726+
}
727+
728+
// Give it a tiny bit of time to reach the server and start processing
729+
delay(50)
730+
731+
// Start a fast tool call
732+
val deferredFast = async(start = CoroutineStart.UNDISPATCHED) {
733+
client.callTool(testToolName, mapOf("text" to "fast"))
734+
}
735+
736+
val fastResult = deferredFast.await()
737+
val fastEndTime = System.currentTimeMillis()
738+
739+
// The fast tool should finish MUCH sooner than the slow tool's delay if processed concurrently
740+
assertTrue(
741+
fastEndTime - startTime < delayMs,
742+
"Fast tool should finish before the slow tool's delay (took ${fastEndTime - startTime}ms)",
743+
)
744+
745+
deferredSlow.await()
746+
val slowEndTime = System.currentTimeMillis()
747+
748+
assertTrue(slowEndTime - startTime >= delayMs, "Slow tool should take at least the specified delay")
749+
750+
assertEquals("Echo: fast", (fastResult.content.first() as TextContent).text)
751+
}
752+
694753
@Test
695754
fun testSpecialCharacters() {
696755
runBlocking(Dispatchers.IO) {

integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/KotlinTestBase.kt

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import io.ktor.server.engine.EmbeddedServer
99
import io.ktor.server.engine.embeddedServer
1010
import io.ktor.server.plugins.contentnegotiation.ContentNegotiation
1111
import io.ktor.server.routing.routing
12+
import io.modelcontextprotocol.kotlin.sdk.ExperimentalMcpApi
1213
import io.modelcontextprotocol.kotlin.sdk.client.Client
1314
import io.modelcontextprotocol.kotlin.sdk.client.SseClientTransport
1415
import io.modelcontextprotocol.kotlin.sdk.client.StdioClientTransport
@@ -17,8 +18,10 @@ import io.modelcontextprotocol.kotlin.sdk.server.Server
1718
import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions
1819
import io.modelcontextprotocol.kotlin.sdk.server.StdioServerTransport
1920
import io.modelcontextprotocol.kotlin.sdk.server.StreamableHttpServerTransport
21+
import io.modelcontextprotocol.kotlin.sdk.server.StreamableHttpServerTransport.Configuration
2022
import io.modelcontextprotocol.kotlin.sdk.server.mcp
2123
import io.modelcontextprotocol.kotlin.sdk.server.mcpStreamableHttp
24+
import io.modelcontextprotocol.kotlin.sdk.testing.ChannelTransport
2225
import io.modelcontextprotocol.kotlin.sdk.types.Implementation
2326
import io.modelcontextprotocol.kotlin.sdk.types.McpJson
2427
import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities
@@ -40,6 +43,7 @@ import io.ktor.server.cio.CIO as ServerCIO
4043
import io.ktor.server.sse.SSE as ServerSSE
4144

4245
@Retry(times = 3)
46+
@OptIn(ExperimentalMcpApi::class)
4347
abstract class KotlinTestBase {
4448

4549
protected val host = "localhost"
@@ -48,9 +52,10 @@ abstract class KotlinTestBase {
4852
protected lateinit var server: Server
4953
protected lateinit var client: Client
5054
protected lateinit var serverEngine: EmbeddedServer<*, *>
55+
protected val channelTransports = lazy { ChannelTransport.createLinkedPair() }
5156

5257
// Transport selection
53-
protected enum class TransportKind { SSE, STDIO, STREAMABLE_HTTP }
58+
protected enum class TransportKind { SSE, STDIO, STREAMABLE_HTTP, CHANNEL }
5459

5560
protected open val transportKind: TransportKind = TransportKind.STDIO
5661

@@ -121,6 +126,13 @@ abstract class KotlinTestBase {
121126
)
122127
client.connect(transport)
123128
}
129+
130+
TransportKind.CHANNEL -> {
131+
client = Client(
132+
Implementation("test", "1.0"),
133+
)
134+
client.connect(channelTransports.value.clientTransport)
135+
}
124136
}
125137
}
126138

@@ -148,7 +160,7 @@ abstract class KotlinTestBase {
148160
// Create StreamableHTTP server transport
149161
// Using JSON response mode for simpler testing (no SSE session required)
150162
val transport = StreamableHttpServerTransport(
151-
StreamableHttpServerTransport.Configuration(
163+
Configuration(
152164
enableJsonResponse = true, // Use JSON response mode for testing
153165
),
154166
)
@@ -196,6 +208,10 @@ abstract class KotlinTestBase {
196208
server.createSession(serverTransport)
197209
}
198210
}
211+
212+
TransportKind.CHANNEL -> {
213+
runBlocking { server.createSession(channelTransports.value.serverTransport) }
214+
}
199215
}
200216
}
201217

@@ -250,6 +266,12 @@ abstract class KotlinTestBase {
250266
}
251267
}
252268
}
269+
270+
TransportKind.CHANNEL -> {
271+
if (channelTransports.isInitialized()) {
272+
runBlocking { channelTransports.value.close() }
273+
}
274+
}
253275
}
254276
}
255277
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package io.modelcontextprotocol.kotlin.sdk.integration.kotlin.channel
2+
3+
import io.modelcontextprotocol.kotlin.sdk.integration.kotlin.AbstractToolIntegrationTest
4+
5+
// while this isn't a "production" transport, we still want to ensure that it has the correct behavior
6+
class ToolIntegrationTestChannel : AbstractToolIntegrationTest() {
7+
override val transportKind: TransportKind = TransportKind.CHANNEL
8+
}

0 commit comments

Comments
 (0)