Skip to content

Commit e7b09d8

Browse files
authored
fix: propagate CancellationException instead of logging as error (#706)
fixes #242 ## How Has This Been Tested? new and existing tests pass ## Breaking Changes none ## Types of changes - [x] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update ## Checklist - [x] I have read the [MCP Documentation](https://modelcontextprotocol.io) - [x] My code follows the repository's style guidelines - [x] New and existing tests pass locally - [x] I have added appropriate error handling - [x] I have added or updated documentation as needed
1 parent bce6811 commit e7b09d8

7 files changed

Lines changed: 109 additions & 25 deletions

File tree

kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/SseClientTransport.kt

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,12 @@ import kotlinx.coroutines.CoroutineName
2424
import kotlinx.coroutines.CoroutineScope
2525
import kotlinx.coroutines.ExperimentalCoroutinesApi
2626
import kotlinx.coroutines.Job
27+
import kotlinx.coroutines.NonCancellable
2728
import kotlinx.coroutines.SupervisorJob
2829
import kotlinx.coroutines.cancel
29-
import kotlinx.coroutines.cancelAndJoin
3030
import kotlinx.coroutines.ensureActive
3131
import kotlinx.coroutines.launch
32+
import kotlinx.coroutines.withContext
3233
import kotlinx.serialization.SerializationException
3334
import kotlin.concurrent.atomics.ExperimentalAtomicApi
3435
import kotlin.time.Duration
@@ -171,13 +172,15 @@ public class SseClientTransport(
171172
}
172173

173174
override suspend fun closeResources() {
174-
job?.cancelAndJoin()
175-
try {
176-
if (::session.isInitialized) session.cancel()
177-
if (::scope.isInitialized) scope.cancel()
178-
endpoint.cancel()
179-
} catch (e: Throwable) {
180-
_onError(e)
175+
withContext(NonCancellable) {
176+
job?.cancel()
177+
try {
178+
if (::session.isInitialized) session.cancel()
179+
if (::scope.isInitialized) scope.cancel()
180+
endpoint.cancel()
181+
} catch (e: Throwable) {
182+
_onError(e)
183+
}
181184
}
182185
}
183186
}

kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import kotlinx.coroutines.CoroutineName
2121
import kotlinx.coroutines.CoroutineScope
2222
import kotlinx.coroutines.Dispatchers
2323
import kotlinx.coroutines.Job
24+
import kotlinx.coroutines.NonCancellable
2425
import kotlinx.coroutines.SupervisorJob
2526
import kotlinx.coroutines.cancel
2627
import kotlinx.coroutines.cancelAndJoin
@@ -31,6 +32,7 @@ import kotlinx.coroutines.channels.consumeEach
3132
import kotlinx.coroutines.flow.channelFlow
3233
import kotlinx.coroutines.isActive
3334
import kotlinx.coroutines.launch
35+
import kotlinx.coroutines.withContext
3436
import kotlinx.coroutines.yield
3537
import kotlinx.io.Buffer
3638
import kotlinx.io.IOException
@@ -241,8 +243,10 @@ public class StdioClientTransport @JvmOverloads public constructor(
241243
}
242244

243245
override suspend fun closeResources() {
244-
scope.stopProcessing("Closed")
245-
scope.coroutineContext[Job]?.join() // Wait for all coroutines to complete
246+
withContext(NonCancellable) {
247+
scope.stopProcessing("Closed")
248+
scope.coroutineContext[Job]?.join() // Wait for all coroutines to complete
249+
}
246250
}
247251

248252
private fun sendOutboundMessage(message: JSONRPCMessage, sink: Sink, mainScope: CoroutineScope) {
@@ -264,6 +268,8 @@ public class StdioClientTransport @JvmOverloads public constructor(
264268
private suspend fun handleJSONRPCMessage(msg: JSONRPCMessage) {
265269
try {
266270
_onMessage.invoke(msg)
271+
} catch (e: CancellationException) {
272+
throw e
267273
} catch (e: Throwable) {
268274
logger.error(e) { "Error processing message." }
269275
runCatching { _onError.invoke(e) }

kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,12 +277,15 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
277277
}
278278
try {
279279
handler(notification)
280+
} catch (e: CancellationException) {
281+
throw e
280282
} catch (cause: Throwable) {
281283
logger.error(cause) { "Error handling notification: ${notification.method}" }
282284
onError(cause)
283285
}
284286
}
285287

288+
@Suppress("ThrowsCount")
286289
private suspend fun onRequest(request: JSONRPCRequest) {
287290
logger.trace { "Received request: ${request.method} (id: ${request.id})" }
288291

@@ -300,6 +303,8 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
300303
),
301304
),
302305
)
306+
} catch (e: CancellationException) {
307+
throw e
303308
} catch (cause: Throwable) {
304309
logger.error(cause) { "Error sending method not found response" }
305310
onError(cause)

kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/WebSocketMcpTransport.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import kotlinx.coroutines.job
1616
import kotlinx.coroutines.launch
1717
import kotlin.concurrent.atomics.AtomicBoolean
1818
import kotlin.concurrent.atomics.ExperimentalAtomicApi
19+
import kotlin.coroutines.cancellation.CancellationException
1920

2021
/** WebSocket subprotocol identifier for MCP connections. */
2122
public const val MCP_SUBPROTOCOL: String = "mcp"
@@ -44,6 +45,7 @@ public abstract class WebSocketMcpTransport : AbstractTransport() {
4445
*/
4546
protected abstract suspend fun initializeSession()
4647

48+
@Suppress("ThrowsCount")
4749
override suspend fun start() {
4850
logger.debug { "Starting websocket transport" }
4951

@@ -74,6 +76,8 @@ public abstract class WebSocketMcpTransport : AbstractTransport() {
7476
try {
7577
val message = McpJson.decodeFromString<JSONRPCMessage>(message.readText())
7678
_onMessage.invoke(message)
79+
} catch (e: CancellationException) {
80+
throw e
7781
} catch (e: Exception) {
7882
_onError.invoke(e)
7983
throw e

kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/ProtocolTest.kt

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
package io.modelcontextprotocol.kotlin.sdk.shared
22

3+
import io.kotest.assertions.throwables.shouldThrow
34
import io.kotest.matchers.collections.shouldContainExactly
5+
import io.kotest.matchers.collections.shouldHaveSize
46
import io.kotest.matchers.nulls.shouldNotBeNull
57
import io.kotest.matchers.shouldBe
68
import io.modelcontextprotocol.kotlin.sdk.types.CustomRequest
79
import io.modelcontextprotocol.kotlin.sdk.types.EmptyResult
810
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage
11+
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCNotification
912
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest
1013
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCResponse
1114
import io.modelcontextprotocol.kotlin.sdk.types.McpJson
@@ -24,6 +27,7 @@ import kotlinx.serialization.json.encodeToJsonElement
2427
import kotlinx.serialization.json.int
2528
import kotlinx.serialization.json.jsonObject
2629
import kotlinx.serialization.json.jsonPrimitive
30+
import kotlin.coroutines.cancellation.CancellationException
2731
import kotlin.test.BeforeTest
2832
import kotlin.test.Test
2933

@@ -126,6 +130,36 @@ class ProtocolTest {
126130
inFlight.await()
127131
}
128132

133+
@Test
134+
fun `should propagate CancellationException from notification handler without calling onError`() = runTest {
135+
protocol.connect(transport)
136+
137+
protocol.fallbackNotificationHandler = {
138+
throw CancellationException("test cancellation")
139+
}
140+
141+
shouldThrow<CancellationException> {
142+
transport.deliver(JSONRPCNotification(method = "test/notification"))
143+
}
144+
145+
protocol.errors shouldHaveSize 0
146+
}
147+
148+
@Test
149+
fun `should report non-cancellation exception from notification handler via onError`() = runTest {
150+
protocol.connect(transport)
151+
152+
protocol.fallbackNotificationHandler = {
153+
throw IllegalStateException("handler failed")
154+
}
155+
156+
// Non-CE exceptions are caught and reported, not propagated
157+
transport.deliver(JSONRPCNotification(method = "test/notification"))
158+
159+
protocol.errors shouldHaveSize 1
160+
protocol.errors[0].message shouldBe "handler failed"
161+
}
162+
129163
@Test
130164
fun `should create params object when request params are null`() = runTest {
131165
protocol.connect(transport)
@@ -154,6 +188,12 @@ class ProtocolTest {
154188
}
155189

156190
private class TestProtocol : Protocol(null) {
191+
val errors = mutableListOf<Throwable>()
192+
193+
override fun onError(error: Throwable) {
194+
errors.add(error)
195+
}
196+
157197
override fun assertCapabilityForMethod(method: Method) {
158198
// noop
159199
}

kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SSEServerTransport.kt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ public class SseServerTransport(private val endpoint: String, private val sessio
8484
}
8585

8686
call.receiveText()
87+
} catch (e: CancellationException) {
88+
throw e
8789
} catch (e: Exception) {
8890
call.respondText("Invalid message: ${e.message}", status = HttpStatusCode.BadRequest)
8991
_onError.invoke(e)
@@ -92,6 +94,8 @@ public class SseServerTransport(private val endpoint: String, private val sessio
9294

9395
try {
9496
handleMessage(body)
97+
} catch (e: CancellationException) {
98+
throw e
9599
} catch (e: Exception) {
96100
call.respondText("Error handling message $body: ${e.message}", status = HttpStatusCode.BadRequest)
97101
return
@@ -108,6 +112,8 @@ public class SseServerTransport(private val endpoint: String, private val sessio
108112
try {
109113
val parsedMessage = McpJson.decodeFromString<JSONRPCMessage>(message)
110114
_onMessage.invoke(parsedMessage)
115+
} catch (e: CancellationException) {
116+
throw e
111117
} catch (e: Exception) {
112118
_onError.invoke(e)
113119
throw e

kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,18 @@ import io.modelcontextprotocol.kotlin.sdk.types.RPCError
2828
import io.modelcontextprotocol.kotlin.sdk.types.RPCError.ErrorCode.REQUEST_TIMEOUT
2929
import io.modelcontextprotocol.kotlin.sdk.types.RequestId
3030
import io.modelcontextprotocol.kotlin.sdk.types.SUPPORTED_PROTOCOL_VERSIONS
31+
import kotlinx.coroutines.NonCancellable
3132
import kotlinx.coroutines.awaitCancellation
3233
import kotlinx.coroutines.job
3334
import kotlinx.coroutines.sync.Mutex
3435
import kotlinx.coroutines.sync.withLock
36+
import kotlinx.coroutines.withContext
3537
import kotlinx.serialization.json.JsonArray
3638
import kotlinx.serialization.json.JsonObject
3739
import kotlinx.serialization.json.decodeFromJsonElement
3840
import kotlin.concurrent.atomics.AtomicBoolean
3941
import kotlin.concurrent.atomics.ExperimentalAtomicApi
42+
import kotlin.coroutines.cancellation.CancellationException
4043
import kotlin.time.Duration
4144
import kotlin.time.Duration.Companion.milliseconds
4245
import kotlin.uuid.ExperimentalUuidApi
@@ -283,17 +286,19 @@ public class StreamableHttpServerTransport(private val configuration: Configurat
283286
}
284287

285288
override suspend fun close() {
286-
streamMutex.withLock {
287-
streamsMapping.values.forEach {
288-
try {
289-
it.session?.close()
290-
} catch (_: Exception) {
289+
withContext(NonCancellable) {
290+
streamMutex.withLock {
291+
streamsMapping.values.forEach {
292+
try {
293+
it.session?.close()
294+
} catch (_: Exception) {
295+
}
291296
}
297+
streamsMapping.clear()
298+
requestToStreamMapping.clear()
299+
requestToResponseMapping.clear()
300+
invokeOnCloseCallback()
292301
}
293-
streamsMapping.clear()
294-
requestToStreamMapping.clear()
295-
requestToResponseMapping.clear()
296-
invokeOnCloseCallback()
297302
}
298303
}
299304

@@ -420,6 +425,8 @@ public class StreamableHttpServerTransport(private val configuration: Configurat
420425
call.coroutineContext.job.invokeOnCompletion { streamsMapping.remove(streamId) }
421426

422427
messages.forEach { message -> _onMessage(message) }
428+
} catch (e: CancellationException) {
429+
throw e
423430
} catch (e: Exception) {
424431
call.reject(
425432
HttpStatusCode.BadRequest,
@@ -498,12 +505,14 @@ public class StreamableHttpServerTransport(private val configuration: Configurat
498505
val streamId = requestToStreamMapping[requestId] ?: return
499506
val sessionContext = streamsMapping[streamId] ?: return
500507

501-
try {
502-
sessionContext.session?.close()
503-
} catch (e: Exception) {
504-
_onError(e)
505-
} finally {
506-
streamsMapping.remove(streamId)
508+
withContext(NonCancellable) {
509+
try {
510+
sessionContext.session?.close()
511+
} catch (e: Exception) {
512+
_onError(e)
513+
} finally {
514+
streamsMapping.remove(streamId)
515+
}
507516
}
508517
}
509518

@@ -554,6 +563,8 @@ public class StreamableHttpServerTransport(private val configuration: Configurat
554563
id = eventId,
555564
data = McpJson.encodeToString(message),
556565
)
566+
} catch (e: CancellationException) {
567+
throw e
557568
} catch (e: Exception) {
558569
_onError(IllegalStateException("Failed to replay event: ${e.message}", e))
559570
}
@@ -565,6 +576,8 @@ public class StreamableHttpServerTransport(private val configuration: Configurat
565576
streamsMapping.remove(streamId)
566577
throwable?.let { _onError(it) }
567578
}
579+
} catch (e: CancellationException) {
580+
throw e
568581
} catch (e: Exception) {
569582
_onError(e)
570583
}
@@ -658,6 +671,8 @@ public class StreamableHttpServerTransport(private val configuration: Configurat
658671
private suspend fun flushSse(session: ServerSSESession?) {
659672
try {
660673
session?.send(data = "")
674+
} catch (e: CancellationException) {
675+
throw e
661676
} catch (e: Exception) {
662677
_onError(e)
663678
}
@@ -709,6 +724,9 @@ public class StreamableHttpServerTransport(private val configuration: Configurat
709724
val eventId = configuration.eventStore?.storeEvent(streamId, message)
710725
try {
711726
session?.send(event = "message", id = eventId, data = McpJson.encodeToString(message))
727+
} catch (e: CancellationException) {
728+
streamsMapping.remove(streamId)
729+
throw e
712730
} catch (_: Exception) {
713731
streamsMapping.remove(streamId)
714732
}
@@ -733,6 +751,8 @@ public class StreamableHttpServerTransport(private val configuration: Configurat
733751
retry = configuration.retryInterval?.inWholeMilliseconds,
734752
data = "",
735753
)
754+
} catch (e: CancellationException) {
755+
throw e
736756
} catch (e: Exception) {
737757
_onError(e)
738758
}

0 commit comments

Comments
 (0)