diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransportTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransportTest.kt index 585af3f8d..fc3c1d3dd 100644 --- a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransportTest.kt +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransportTest.kt @@ -7,9 +7,13 @@ import io.modelcontextprotocol.kotlin.sdk.types.McpException import io.modelcontextprotocol.kotlin.test.utils.createSleepyProcessBuilder import io.modelcontextprotocol.kotlin.test.utils.createTeeProcessBuilder import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.async import kotlinx.coroutines.runBlocking import kotlinx.coroutines.test.runTest import kotlinx.coroutines.withTimeout +import kotlinx.coroutines.withTimeoutOrNull +import kotlinx.io.Buffer +import kotlinx.io.RawSource import kotlinx.io.asSink import kotlinx.io.asSource import kotlinx.io.buffered @@ -18,10 +22,12 @@ import org.junit.jupiter.api.Timeout import org.junit.jupiter.api.assertThrows import org.junit.jupiter.api.condition.DisabledOnOs import org.junit.jupiter.api.condition.OS +import java.util.concurrent.CountDownLatch import java.util.concurrent.TimeUnit import kotlin.concurrent.atomics.AtomicBoolean import kotlin.concurrent.atomics.ExperimentalAtomicApi import kotlin.test.assertFalse +import kotlin.test.assertNotNull import kotlin.test.assertTrue import kotlin.test.fail import kotlin.time.Duration.Companion.seconds @@ -109,6 +115,43 @@ class StdioClientTransportTest : BaseTransportTest() { } } + @OptIn(ExperimentalAtomicApi::class) + @Test + fun `should close cleanly while stdin read is blocked`(): Unit = runBlocking(Dispatchers.IO) { + val input = BlockingRawSource() + + val transport = StdioClientTransport( + input = input.buffered(), + output = Buffer(), + ) + + val didClose = AtomicBoolean(false) + transport.onClose { didClose.store(true) } + transport.onError { error -> + fail("Unexpected error while closing transport: $error") + } + + transport.start() + + eventually(2.seconds) { + assertTrue(input.readStarted, "Transport should start reading stdin before close") + } + + val closeJob = async { + transport.close() + } + + val closed = withTimeoutOrNull(1.seconds) { + closeJob.await() + } + + input.close() + closeJob.await() + + assertNotNull(closed, "Transport.close() should not wait for stdin to produce data or EOF") + assertTrue(didClose.load(), "Transport should be closed after close() call") + } + @Test fun `should read messages`() = runTest { val processBuilder = createTeeProcessBuilder() @@ -147,4 +190,22 @@ class StdioClientTransportTest : BaseTransportTest() { process.waitFor() process.destroyForcibly() } + + private class BlockingRawSource : RawSource { + private val closed = CountDownLatch(1) + + @Volatile + var readStarted: Boolean = false + private set + + override fun readAtMostTo(sink: Buffer, byteCount: Long): Long { + readStarted = true + closed.await() + return -1L + } + + override fun close() { + closed.countDown() + } + } } diff --git a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt index 63f681c5c..5e990eea0 100644 --- a/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt +++ b/kotlin-sdk-client/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/client/StdioClientTransport.kt @@ -244,10 +244,20 @@ public class StdioClientTransport @JvmOverloads public constructor( override suspend fun closeResources() { withContext(NonCancellable) { scope.stopProcessing("Closed") + closeReadSources() scope.coroutineContext[Job]?.join() // Wait for all coroutines to complete } } + private fun closeReadSources() { + runCatching { input.close() } + .onFailure { logger.debug(it) { "Error closing stdin source" } } + error?.let { source -> + runCatching { source.close() } + .onFailure { logger.debug(it) { "Error closing stderr source" } } + } + } + private fun sendOutboundMessage(message: JSONRPCMessage, sink: Sink, mainScope: CoroutineScope) { try { val json = serializeMessage(message)