Skip to content

Commit a388180

Browse files
authored
fix!(server): #575 Rethrow CancellationException in StdioServerTransport, fix ReadBuffer (#571)
## Rethrow CancellationException in StdioServerTransport, fix ReadBuffer ### Changes 1. StdioServerTransport - Replace generic `Throwable` catches with specific `CancellationException` and `Exception` for clarity and correctness. **breaking change** `Throwable` is not handled any more! - Refactor launching jobs to separate methods - Introduce `READ_BUFFER_SIZE` constant to replace inline buffer size literals. - Add suppress annotations for clearer intent and constructor documentation for better usability. - Extend StdioServerTransportTest 2. ReadBuffer Previously, readMessage() returned null after consuming an unparseable line even when more complete lines existed in the buffer. This caused valid messages following a bad line in the same chunk to be silently dropped until the next chunk arrived (or forever, in tests). Fix the loop so null means only "no complete line available", not "encountered a parse failure". Blank/whitespace-only lines are now silently skipped via isBlank() rather than forwarded to the deserializer and logged as errors. Refactor the method into three focused helpers: - readMessage() — outer loop over lines - readLine() — consume the next newline-delimited line from the buffer - tryRecover() — attempt deserialization from the first '{' onward 3. Add utility method `runIntegrationTest` for integration testing (`runBlocking` + `withTimeout`). Use it in StdioServerTransportTest 4. Store processing Job reference and `calcelAndJoin()` it on close. Partially addressing #574 ## Motivation and Context See #575, #564, #242 ## How Has This Been Tested? Unit test added ## Breaking Changes Semantic change: `Throwable` and CancellationException are not handled any more! ## Types of changes - [x] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update ## Checklist <!-- Go over all the following points, and put an `x` in all the boxes that apply. --> - [x] I have read the [MCP Documentation](https://modelcontextprotocol.io) - [x] My code follows the repository's style guidelines - [x] New and existing tests pass locally - [ ] I have added appropriate error handling - [ ] I have added or updated documentation as needed ## Additional context <!-- Add any other context, implementation notes, or design decisions -->
1 parent 9eabb54 commit a388180

6 files changed

Lines changed: 411 additions & 72 deletions

File tree

kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ReadBuffer.kt

Lines changed: 56 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -20,45 +20,70 @@ public class ReadBuffer {
2020
buffer.write(chunk)
2121
}
2222

23+
/**
24+
* Reads and deserializes a JSON-RPC message from the input buffer.
25+
*
26+
* The method attempts to read lines from the buffer until a valid `JSONRPCMessage` is successfully
27+
* deserialized. Blank lines are ignored, and if a deserialization error occurs, the method attempts
28+
* to recover and process the message.
29+
*
30+
* Recovery involves attempting to parse from the first detected JSON object in the line when an
31+
* error is encountered during deserialization.
32+
*
33+
* @return A deserialized `JSONRPCMessage` if successfully processed; otherwise, `null` if the
34+
* input buffer is exhausted or no valid message is found.
35+
*/
2336
public fun readMessage(): JSONRPCMessage? {
24-
if (buffer.exhausted()) return null
25-
var lfIndex = buffer.indexOf('\n'.code.toByte())
26-
val line = when (lfIndex) {
27-
-1L -> return null
37+
while (true) {
38+
val line = readNextLine() ?: return null
39+
if (line.isBlank()) continue
2840

29-
0L -> {
30-
buffer.skip(1)
31-
return null
41+
@Suppress("TooGenericExceptionCaught")
42+
val message = try {
43+
deserializeMessage(line)
44+
} catch (e: Exception) {
45+
logger.error(e) { "Failed to deserialize message from line: $line\nAttempting to recover..." }
46+
tryRecover(line)
3247
}
33-
34-
else -> {
35-
var skipBytes = 1
36-
if (buffer[lfIndex - 1] == '\r'.code.toByte()) {
37-
lfIndex -= 1
38-
skipBytes += 1
39-
}
40-
val string = buffer.readString(lfIndex)
41-
buffer.skip(skipBytes.toLong())
42-
string
48+
if (message != null) {
49+
return message
4350
}
4451
}
45-
try {
46-
return deserializeMessage(line)
47-
} catch (e: Exception) {
48-
logger.error(e) { "Failed to deserialize message from line: $line\nAttempting to recover..." }
49-
// if there is a non-JSON object prefix, try to parse from the first '{' onward.
50-
val braceIndex = line.indexOf('{')
51-
if (braceIndex != -1) {
52-
val trimmed = line.substring(braceIndex)
53-
try {
54-
return deserializeMessage(trimmed)
55-
} catch (ignored: Exception) {
56-
logger.error(ignored) { "Deserialization failed for line: $line\nSkipping..." }
57-
}
52+
}
53+
54+
private fun readNextLine(): String? {
55+
val lfIndex = if (buffer.exhausted()) -1L else buffer.indexOf('\n'.code.toByte())
56+
if (lfIndex == -1L) return null
57+
58+
return if (lfIndex == 0L) {
59+
buffer.skip(1)
60+
""
61+
} else {
62+
var skipBytes = 1
63+
var messageLength = lfIndex
64+
if (buffer[lfIndex - 1] == '\r'.code.toByte()) {
65+
messageLength -= 1
66+
skipBytes += 1
5867
}
68+
val string = buffer.readString(messageLength)
69+
buffer.skip(skipBytes.toLong())
70+
string
5971
}
72+
}
73+
74+
private fun tryRecover(line: String): JSONRPCMessage? {
75+
// if there is a non-JSON object prefix, try to parse from the first '{' onward.
76+
val braceIndex = line.indexOf('{')
77+
if (braceIndex == -1) return null
6078

61-
return null
79+
val trimmed = line.substring(braceIndex)
80+
@Suppress("TooGenericExceptionCaught")
81+
return try {
82+
deserializeMessage(trimmed)
83+
} catch (ignored: Exception) {
84+
logger.error(ignored) { "Deserialization failed for line: $line\nSkipping..." }
85+
null
86+
}
6287
}
6388

6489
public fun clear() {

kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ReadBufferTest.kt

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ class ReadBufferTest {
3131
val messageBytes = json.encodeToString(testMessage).encodeToByteArray()
3232
readBuffer.append(messageBytes)
3333
assertNull(readBuffer.readMessage())
34+
readBuffer.append("\r".encodeToByteArray())
35+
assertNull(readBuffer.readMessage())
3436

3537
// Append a newline and verify message is now available
3638
readBuffer.append("\n".encodeToByteArray())
@@ -45,6 +47,20 @@ class ReadBufferTest {
4547
assertNull(readBuffer.readMessage())
4648
}
4749

50+
@Test
51+
fun `skip blank line`() {
52+
val readBuffer = ReadBuffer()
53+
readBuffer.append(" \n".toByteArray())
54+
assertNull(readBuffer.readMessage())
55+
}
56+
57+
@Test
58+
fun `skip invalid json line`() {
59+
val readBuffer = ReadBuffer()
60+
readBuffer.append(" {ah=oh\n".toByteArray())
61+
assertNull(readBuffer.readMessage())
62+
}
63+
4864
@Test
4965
fun `should be reusable after clearing`() {
5066
val readBuffer = ReadBuffer()

kotlin-sdk-server/build.gradle.kts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,13 @@ kotlin {
2323
implementation(libs.kotest.assertions.core)
2424
implementation(libs.kotest.assertions.json)
2525
implementation(libs.kotlinx.coroutines.test)
26+
implementation(project(":test-utils"))
2627
}
2728
}
2829

2930
jvmTest {
3031
dependencies {
32+
implementation(libs.junit.jupiter.params)
3133
implementation(libs.kotest.assertions.ktor)
3234
implementation(libs.ktor.client.content.negotiation)
3335
implementation(libs.ktor.client.logging)

kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt

Lines changed: 79 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import io.modelcontextprotocol.kotlin.sdk.shared.ReadBuffer
77
import io.modelcontextprotocol.kotlin.sdk.shared.TransportSendOptions
88
import io.modelcontextprotocol.kotlin.sdk.shared.serializeMessage
99
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage
10+
import kotlinx.coroutines.CancellationException
1011
import kotlinx.coroutines.CoroutineScope
1112
import kotlinx.coroutines.Job
1213
import kotlinx.coroutines.NonCancellable
@@ -26,37 +27,55 @@ import kotlin.concurrent.atomics.AtomicBoolean
2627
import kotlin.concurrent.atomics.ExperimentalAtomicApi
2728
import kotlin.coroutines.CoroutineContext
2829

30+
private const val READ_BUFFER_SIZE = 8192L
31+
2932
/**
3033
* A server transport that communicates with a client via standard I/O.
3134
*
32-
* Reads from System.in and writes to System.out.
35+
* Reads from input [Source] and writes to output [Sink].
36+
*
37+
* @constructor Creates a new instance of [StdioServerTransport].
38+
* @param inputStream The input [Source] used to receive data.
39+
* @param outputStream The output [Sink] used to send data.
3340
*/
3441
@OptIn(ExperimentalAtomicApi::class)
3542
public class StdioServerTransport(private val inputStream: Source, outputStream: Sink) : AbstractTransport() {
36-
private val logger = KotlinLogging.logger {}
3743

44+
private val logger = KotlinLogging.logger {}
3845
private val readBuffer = ReadBuffer()
3946
private val initialized: AtomicBoolean = AtomicBoolean(false)
4047
private var readingJob: Job? = null
4148
private var sendingJob: Job? = null
49+
private var processingJob: Job? = null
4250

4351
private val coroutineContext: CoroutineContext = IODispatcher + SupervisorJob()
4452
private val scope = CoroutineScope(coroutineContext)
4553
private val readChannel = Channel<ByteArray>(Channel.UNLIMITED)
4654
private val writeChannel = Channel<JSONRPCMessage>(Channel.UNLIMITED)
47-
private val outputWriter = outputStream.buffered()
55+
private val outputSink = outputStream.buffered()
4856

4957
override suspend fun start() {
5058
if (!initialized.compareAndSet(expectedValue = false, newValue = true)) {
5159
error("StdioServerTransport already started!")
5260
}
5361

5462
// Launch a coroutine to read from stdin
55-
readingJob = scope.launch {
63+
readingJob = launchReadingJob()
64+
65+
// Launch a coroutine to process messages from readChannel
66+
processingJob = launchProcessingJob()
67+
68+
// Launch a coroutine to handle message sending
69+
sendingJob = launchSendingJob()
70+
}
71+
72+
private fun launchReadingJob(): Job {
73+
val job = scope.launch {
5674
val buf = Buffer()
75+
@Suppress("TooGenericExceptionCaught")
5776
try {
5877
while (isActive) {
59-
val bytesRead = inputStream.readAtMostTo(buf, 8192)
78+
val bytesRead = inputStream.readAtMostTo(buf, READ_BUFFER_SIZE)
6079
if (bytesRead == -1L) {
6180
// EOF reached
6281
break
@@ -66,6 +85,8 @@ public class StdioServerTransport(private val inputStream: Source, outputStream:
6685
readChannel.send(chunk)
6786
}
6887
}
88+
} catch (e: CancellationException) {
89+
throw e
6990
} catch (e: Throwable) {
7091
logger.error(e) { "Error reading from stdin" }
7192
_onError.invoke(e)
@@ -74,35 +95,59 @@ public class StdioServerTransport(private val inputStream: Source, outputStream:
7495
close()
7596
}
7697
}
98+
job.invokeOnCompletion { cause ->
99+
logJobCompletion("Message reading", cause)
100+
}
101+
return job
102+
}
77103

78-
// Launch a coroutine to process messages from readChannel
79-
scope.launch {
104+
private fun launchProcessingJob(): Job {
105+
val job = scope.launch {
106+
@Suppress("TooGenericExceptionCaught")
80107
try {
81108
for (chunk in readChannel) {
82109
readBuffer.append(chunk)
83110
processReadBuffer()
84111
}
112+
} catch (e: CancellationException) {
113+
throw e
85114
} catch (e: Throwable) {
86115
_onError.invoke(e)
87116
}
88117
}
118+
job.invokeOnCompletion { cause ->
119+
logJobCompletion("Processing", cause)
120+
}
121+
return job
122+
}
89123

90-
// Launch a coroutine to handle message sending
91-
sendingJob = scope.launch {
124+
private fun launchSendingJob(): Job {
125+
val job = scope.launch {
126+
@Suppress("TooGenericExceptionCaught")
92127
try {
93128
for (message in writeChannel) {
94129
val json = serializeMessage(message)
95-
outputWriter.writeString(json)
96-
outputWriter.flush()
130+
outputSink.writeString(json)
131+
outputSink.flush()
97132
}
133+
} catch (e: CancellationException) {
134+
throw e
98135
} catch (e: Throwable) {
99136
logger.error(e) { "Error writing to stdout" }
100137
_onError.invoke(e)
101138
}
102139
}
140+
job.invokeOnCompletion { cause ->
141+
logJobCompletion("Message sending", cause)
142+
if (cause is CancellationException) {
143+
readingJob?.cancel(cause)
144+
}
145+
}
146+
return job
103147
}
104148

105149
private suspend fun processReadBuffer() {
150+
@Suppress("TooGenericExceptionCaught")
106151
while (true) {
107152
val message = try {
108153
readBuffer.readMessage()
@@ -115,12 +160,30 @@ public class StdioServerTransport(private val inputStream: Source, outputStream:
115160
// Async invocation broke delivery order
116161
try {
117162
_onMessage.invoke(message)
163+
} catch (e: CancellationException) {
164+
throw e
118165
} catch (e: Throwable) {
166+
logger.error(e) { "Error processing message" }
119167
_onError.invoke(e)
120168
}
121169
}
122170
}
123171

172+
private fun logJobCompletion(jobName: String, cause: Throwable?) {
173+
when (cause) {
174+
is CancellationException -> {
175+
}
176+
177+
null -> {
178+
logger.debug { "$jobName job completed" }
179+
}
180+
181+
else -> {
182+
logger.debug(cause) { "$jobName job completed exceptionally" }
183+
}
184+
}
185+
}
186+
124187
override suspend fun close() {
125188
if (!initialized.compareAndSet(expectedValue = true, newValue = false)) return
126189

@@ -133,13 +196,15 @@ public class StdioServerTransport(private val inputStream: Source, outputStream:
133196
}.onFailure { logger.warn(it) { "Failed to close stdin" } }
134197

135198
readingJob?.cancel()
136-
137199
readChannel.close()
200+
201+
processingJob?.cancelAndJoin()
202+
138203
readBuffer.clear()
139204

140205
runCatching {
141-
outputWriter.flush()
142-
outputWriter.close()
206+
outputSink.flush()
207+
outputSink.close()
143208
}.onFailure { logger.warn(it) { "Failed to close stdout" } }
144209

145210
invokeOnCloseCallback()

0 commit comments

Comments
 (0)