Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@ import io.modelcontextprotocol.kotlin.sdk.types.McpException
import io.modelcontextprotocol.kotlin.test.utils.createSleepyProcessBuilder
import io.modelcontextprotocol.kotlin.test.utils.createTeeProcessBuilder
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.async
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.test.runTest
import kotlinx.coroutines.withTimeout
import kotlinx.coroutines.withTimeoutOrNull
import kotlinx.io.Buffer
import kotlinx.io.RawSource
import kotlinx.io.asSink
import kotlinx.io.asSource
import kotlinx.io.buffered
Expand All @@ -18,10 +22,12 @@ import org.junit.jupiter.api.Timeout
import org.junit.jupiter.api.assertThrows
import org.junit.jupiter.api.condition.DisabledOnOs
import org.junit.jupiter.api.condition.OS
import java.util.concurrent.CountDownLatch
import java.util.concurrent.TimeUnit
import kotlin.concurrent.atomics.AtomicBoolean
import kotlin.concurrent.atomics.ExperimentalAtomicApi
import kotlin.test.assertFalse
import kotlin.test.assertNotNull
import kotlin.test.assertTrue
import kotlin.test.fail
import kotlin.time.Duration.Companion.seconds
Expand Down Expand Up @@ -109,6 +115,43 @@ class StdioClientTransportTest : BaseTransportTest() {
}
}

@OptIn(ExperimentalAtomicApi::class)
@Test
fun `should close cleanly while stdin read is blocked`(): Unit = runBlocking(Dispatchers.IO) {
val input = BlockingRawSource()

val transport = StdioClientTransport(
input = input.buffered(),
output = Buffer(),
)

val didClose = AtomicBoolean(false)
transport.onClose { didClose.store(true) }
transport.onError { error ->
fail("Unexpected error while closing transport: $error")
}

transport.start()

eventually(2.seconds) {
assertTrue(input.readStarted, "Transport should start reading stdin before close")
}

val closeJob = async {
transport.close()
}

val closed = withTimeoutOrNull(1.seconds) {
closeJob.await()
}

input.close()
closeJob.await()

assertNotNull(closed, "Transport.close() should not wait for stdin to produce data or EOF")
assertTrue(didClose.load(), "Transport should be closed after close() call")
}

@Test
fun `should read messages`() = runTest {
val processBuilder = createTeeProcessBuilder()
Expand Down Expand Up @@ -147,4 +190,22 @@ class StdioClientTransportTest : BaseTransportTest() {
process.waitFor()
process.destroyForcibly()
}

private class BlockingRawSource : RawSource {
private val closed = CountDownLatch(1)

@Volatile
var readStarted: Boolean = false
private set

override fun readAtMostTo(sink: Buffer, byteCount: Long): Long {
readStarted = true
closed.await()
return -1L
}

override fun close() {
closed.countDown()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -244,10 +244,20 @@ public class StdioClientTransport @JvmOverloads public constructor(
override suspend fun closeResources() {
withContext(NonCancellable) {
scope.stopProcessing("Closed")
closeReadSources()
scope.coroutineContext[Job]?.join() // Wait for all coroutines to complete
}
}

private fun closeReadSources() {
runCatching { input.close() }
.onFailure { logger.debug(it) { "Error closing stdin source" } }
error?.let { source ->
runCatching { source.close() }
.onFailure { logger.debug(it) { "Error closing stderr source" } }
}
}

private fun sendOutboundMessage(message: JSONRPCMessage, sink: Sink, mainScope: CoroutineScope) {
try {
val json = serializeMessage(message)
Expand Down