Skip to content

Commit 4d0d1bd

Browse files
committed
fix(server): #573, #574: add customizable builder to StdioServerTransport
Scope lifecycle is fixed (#574 resolved), processing dispatcher defaults to `Dispatchers.Default` (#573 resolved), scope context is clean (no spurious dispatcher stacking), and @volatile on the three job vars addresses the visibility race. - Introduce a `Configuration` class for `StdioServerTransport` to improve API flexibility and readability. - Updated transport initialization to use a builder block for configuring parameters such as I/O streams, buffer sizes, dispatchers, and parent coroutine scope. Test: Added integration test validating the builder functionality.
1 parent 3e7ea39 commit 4d0d1bd

3 files changed

Lines changed: 193 additions & 74 deletions

File tree

kotlin-sdk-server/api/kotlin-sdk-server.api

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,12 +186,35 @@ public final class io/modelcontextprotocol/kotlin/sdk/server/SseServerTransport
186186
}
187187

188188
public final class io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport {
189+
public fun <init> (Lkotlin/jvm/functions/Function1;)V
189190
public fun <init> (Lkotlinx/io/Source;Lkotlinx/io/Sink;)V
190191
public fun close (Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
191192
public fun send (Lio/modelcontextprotocol/kotlin/sdk/types/JSONRPCMessage;Lio/modelcontextprotocol/kotlin/sdk/shared/TransportSendOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
192193
public fun start (Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
193194
}
194195

196+
public final class io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport$Configuration {
197+
public fun <init> ()V
198+
public final fun getCoroutineScope ()Lkotlinx/coroutines/CoroutineScope;
199+
public final fun getProcessingJobDispatcher ()Lkotlinx/coroutines/CoroutineDispatcher;
200+
public final fun getReadBufferSize ()J
201+
public final fun getReadChannelBufferSize ()I
202+
public final fun getReadingJobDispatcher ()Lkotlinx/coroutines/CoroutineDispatcher;
203+
public final fun getSink ()Lkotlinx/io/Sink;
204+
public final fun getSource ()Lkotlinx/io/Source;
205+
public final fun getWriteChannelBufferSize ()I
206+
public final fun getWritingJobDispatcher ()Lkotlinx/coroutines/CoroutineDispatcher;
207+
public final fun setCoroutineScope (Lkotlinx/coroutines/CoroutineScope;)V
208+
public final fun setProcessingJobDispatcher (Lkotlinx/coroutines/CoroutineDispatcher;)V
209+
public final fun setReadBufferSize (J)V
210+
public final fun setReadChannelBufferSize (I)V
211+
public final fun setReadingJobDispatcher (Lkotlinx/coroutines/CoroutineDispatcher;)V
212+
public final fun setSink (Lkotlinx/io/Sink;)V
213+
public final fun setSource (Lkotlinx/io/Source;)V
214+
public final fun setWriteChannelBufferSize (I)V
215+
public final fun setWritingJobDispatcher (Lkotlinx/coroutines/CoroutineDispatcher;)V
216+
}
217+
195218
public final class io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport : io/modelcontextprotocol/kotlin/sdk/shared/AbstractTransport {
196219
public static final field STANDALONE_SSE_STREAM_ID Ljava/lang/String;
197220
public fun <init> ()V

kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt

Lines changed: 131 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ import io.modelcontextprotocol.kotlin.sdk.shared.TransportSendOptions
88
import io.modelcontextprotocol.kotlin.sdk.shared.serializeMessage
99
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage
1010
import kotlinx.coroutines.CancellationException
11+
import kotlinx.coroutines.CoroutineDispatcher
1112
import kotlinx.coroutines.CoroutineScope
13+
import kotlinx.coroutines.Dispatchers
1214
import kotlinx.coroutines.Job
1315
import kotlinx.coroutines.NonCancellable
1416
import kotlinx.coroutines.SupervisorJob
@@ -23,9 +25,9 @@ import kotlinx.io.Source
2325
import kotlinx.io.buffered
2426
import kotlinx.io.readByteArray
2527
import kotlinx.io.writeString
28+
import kotlin.concurrent.Volatile
2629
import kotlin.concurrent.atomics.AtomicBoolean
2730
import kotlin.concurrent.atomics.ExperimentalAtomicApi
28-
import kotlin.coroutines.CoroutineContext
2931

3032
private const val READ_BUFFER_SIZE = 8192L
3133

@@ -34,25 +36,91 @@ private const val READ_BUFFER_SIZE = 8192L
3436
*
3537
* Reads from input [Source] and writes to output [Sink].
3638
*
37-
* @constructor Creates a new instance of [StdioServerTransport].
38-
* @param inputStream The input [Source] used to receive data.
39-
* @param outputStream The output [Sink] used to send data.
39+
* @constructor Initializes the transport using the provided block for configuration.
40+
* The configuration includes specifying the input and output streams, buffer sizes,
41+
* and dispatchers for I/O and processing tasks.
4042
*/
4143
@OptIn(ExperimentalAtomicApi::class)
42-
public class StdioServerTransport(private val inputStream: Source, outputStream: Sink) : AbstractTransport() {
44+
public class StdioServerTransport(block: Configuration.() -> Unit) : AbstractTransport() {
45+
46+
/**
47+
* Configuration for [StdioServerTransport].
48+
*
49+
* @property source The input [Source] used to receive data.
50+
* @property sink The output [Sink] used to send data.
51+
* @property readBufferSize The buffer size for the read channel.
52+
* @property readingJobDispatcher The [CoroutineDispatcher] used for reading jobs.
53+
* Defaults to [IODispatcher].
54+
* @property writingJobDispatcher The [CoroutineDispatcher] used for writing jobs.
55+
* Defaults to [IODispatcher].
56+
* @property processingJobDispatcher The [CoroutineDispatcher] used for processing jobs.
57+
* Defaults to [Dispatchers.Default].
58+
* @property readChannelBufferSize The buffer size for the read channel.
59+
* @property writeChannelBufferSize The buffer size for the write channel.
60+
* @property coroutineScope The [CoroutineScope] used for managing coroutines.
61+
*/
62+
@Suppress("LongParameterList")
63+
public class Configuration internal constructor(
64+
public var source: Source? = null,
65+
public var sink: Sink? = null,
66+
public var readBufferSize: Long = READ_BUFFER_SIZE,
67+
public var readingJobDispatcher: CoroutineDispatcher = IODispatcher,
68+
public var writingJobDispatcher: CoroutineDispatcher = IODispatcher,
69+
public var processingJobDispatcher: CoroutineDispatcher = Dispatchers.Default,
70+
public var readChannelBufferSize: Int = Channel.UNLIMITED,
71+
public var writeChannelBufferSize: Int = Channel.UNLIMITED,
72+
public var coroutineScope: CoroutineScope? = null,
73+
)
74+
75+
private val source: Source
76+
private val sink: Sink
77+
private val processingJobDispatcher: CoroutineDispatcher
78+
private val readingJobDispatcher: CoroutineDispatcher
79+
private val writingJobDispatcher: CoroutineDispatcher
80+
private val scope: CoroutineScope
81+
private val readBufferSize: Long
82+
private val readChannel: Channel<ByteArray>
83+
private val writeChannel: Channel<JSONRPCMessage>
84+
85+
init {
86+
val config = Configuration().apply(block)
87+
val input = requireNotNull(config.source) { "source is required" }
88+
val output = requireNotNull(config.sink) { "sink is required" }
89+
require(config.readBufferSize > 0) { "readBufferSize must be > 0" }
90+
91+
source = input
92+
processingJobDispatcher = config.processingJobDispatcher
93+
readingJobDispatcher = config.readingJobDispatcher
94+
writingJobDispatcher = config.writingJobDispatcher
95+
val parentJob = config.coroutineScope?.coroutineContext?.get(Job)
96+
scope = CoroutineScope(SupervisorJob(parentJob))
97+
readBufferSize = config.readBufferSize
98+
readChannel = Channel(config.readChannelBufferSize)
99+
writeChannel = Channel(config.writeChannelBufferSize)
100+
sink = output.buffered()
101+
}
102+
103+
/**
104+
* Creates a new instance of [StdioServerTransport]
105+
* with the given [inputStream] [Source] and [outputStream] [Sink].
106+
*/
107+
public constructor(inputStream: Source, outputStream: Sink) : this({
108+
source = inputStream
109+
sink = outputStream
110+
})
43111

44112
private val logger = KotlinLogging.logger {}
45113
private val readBuffer = ReadBuffer()
46114
private val initialized: AtomicBoolean = AtomicBoolean(false)
115+
116+
@Volatile
47117
private var readingJob: Job? = null
118+
119+
@Volatile
48120
private var sendingJob: Job? = null
49-
private var processingJob: Job? = null
50121

51-
private val coroutineContext: CoroutineContext = IODispatcher + SupervisorJob()
52-
private val scope = CoroutineScope(coroutineContext)
53-
private val readChannel = Channel<ByteArray>(Channel.UNLIMITED)
54-
private val writeChannel = Channel<JSONRPCMessage>(Channel.UNLIMITED)
55-
private val outputSink = outputStream.buffered()
122+
@Volatile
123+
private var processingJob: Job? = null
56124

57125
override suspend fun start() {
58126
if (!initialized.compareAndSet(expectedValue = false, newValue = true)) {
@@ -69,81 +137,71 @@ public class StdioServerTransport(private val inputStream: Source, outputStream:
69137
sendingJob = launchSendingJob()
70138
}
71139

72-
private fun launchReadingJob(): Job {
73-
val job = scope.launch {
74-
val buf = Buffer()
75-
@Suppress("TooGenericExceptionCaught")
76-
try {
77-
while (isActive) {
78-
val bytesRead = inputStream.readAtMostTo(buf, READ_BUFFER_SIZE)
79-
if (bytesRead == -1L) {
80-
// EOF reached
81-
break
82-
}
83-
if (bytesRead > 0) {
84-
val chunk = buf.readByteArray()
85-
readChannel.send(chunk)
86-
}
140+
private fun launchReadingJob(): Job = scope.launch(readingJobDispatcher) {
141+
val buf = Buffer()
142+
@Suppress("TooGenericExceptionCaught")
143+
try {
144+
while (isActive) {
145+
val bytesRead = source.readAtMostTo(buf, readBufferSize)
146+
if (bytesRead == -1L) {
147+
// EOF reached
148+
break
149+
}
150+
if (bytesRead > 0) {
151+
val chunk = buf.readByteArray()
152+
readChannel.send(chunk)
87153
}
88-
} catch (e: CancellationException) {
89-
throw e
90-
} catch (e: Throwable) {
91-
logger.error(e) { "Error reading from stdin" }
92-
_onError.invoke(e)
93-
} finally {
94-
// Reached EOF or error, close connection
95-
close()
96154
}
155+
} catch (e: CancellationException) {
156+
throw e
157+
} catch (e: Throwable) {
158+
logger.error(e) { "Error reading from stdin" }
159+
_onError.invoke(e)
160+
} finally {
161+
// Reached EOF or error, close connection
162+
close()
97163
}
98-
job.invokeOnCompletion { cause ->
99-
logJobCompletion("Message reading", cause)
100-
}
101-
return job
164+
}.apply {
165+
invokeOnCompletion { logJobCompletion("Message reading", it) }
102166
}
103167

104-
private fun launchProcessingJob(): Job {
105-
val job = scope.launch {
106-
@Suppress("TooGenericExceptionCaught")
107-
try {
108-
for (chunk in readChannel) {
109-
readBuffer.append(chunk)
110-
processReadBuffer()
111-
}
112-
} catch (e: CancellationException) {
113-
throw e
114-
} catch (e: Throwable) {
115-
_onError.invoke(e)
168+
private fun launchProcessingJob(): Job = scope.launch(processingJobDispatcher) {
169+
@Suppress("TooGenericExceptionCaught")
170+
try {
171+
for (chunk in readChannel) {
172+
readBuffer.append(chunk)
173+
processReadBuffer()
116174
}
175+
} catch (e: CancellationException) {
176+
throw e
177+
} catch (e: Throwable) {
178+
_onError.invoke(e)
117179
}
118-
job.invokeOnCompletion { cause ->
119-
logJobCompletion("Processing", cause)
120-
}
121-
return job
180+
}.apply {
181+
invokeOnCompletion { logJobCompletion("Processing", it) }
122182
}
123183

124-
private fun launchSendingJob(): Job {
125-
val job = scope.launch {
126-
@Suppress("TooGenericExceptionCaught")
127-
try {
128-
for (message in writeChannel) {
129-
val json = serializeMessage(message)
130-
outputSink.writeString(json)
131-
outputSink.flush()
132-
}
133-
} catch (e: CancellationException) {
134-
throw e
135-
} catch (e: Throwable) {
136-
logger.error(e) { "Error writing to stdout" }
137-
_onError.invoke(e)
184+
private fun launchSendingJob(): Job = scope.launch(writingJobDispatcher) {
185+
@Suppress("TooGenericExceptionCaught")
186+
try {
187+
for (message in writeChannel) {
188+
val json = serializeMessage(message)
189+
sink.writeString(json)
190+
sink.flush()
138191
}
192+
} catch (e: CancellationException) {
193+
throw e
194+
} catch (e: Throwable) {
195+
logger.error(e) { "Error writing to stdout" }
196+
_onError.invoke(e)
139197
}
140-
job.invokeOnCompletion { cause ->
198+
}.apply {
199+
invokeOnCompletion { cause ->
141200
logJobCompletion("Message sending", cause)
142201
if (cause is CancellationException) {
143202
readingJob?.cancel(cause)
144203
}
145204
}
146-
return job
147205
}
148206

149207
private suspend fun processReadBuffer() {
@@ -192,7 +250,7 @@ public class StdioServerTransport(private val inputStream: Source, outputStream:
192250
sendingJob?.cancelAndJoin()
193251

194252
runCatching {
195-
inputStream.close()
253+
source.close()
196254
}.onFailure { logger.warn(it) { "Failed to close stdin" } }
197255

198256
readingJob?.cancel()
@@ -201,10 +259,9 @@ public class StdioServerTransport(private val inputStream: Source, outputStream:
201259
processingJob?.cancelAndJoin()
202260

203261
readBuffer.clear()
204-
205262
runCatching {
206-
outputSink.flush()
207-
outputSink.close()
263+
sink.flush()
264+
sink.close()
208265
}.onFailure { logger.warn(it) { "Failed to close stdout" } }
209266

210267
invokeOnCloseCallback()

kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ import io.kotest.assertions.throwables.shouldThrow
55
import io.kotest.assertions.withClue
66
import io.kotest.matchers.collections.shouldContain
77
import io.kotest.matchers.shouldBe
8+
import io.ktor.client.utils.clientDispatcher
9+
import io.ktor.utils.io.InternalAPI
810
import io.modelcontextprotocol.kotlin.sdk.shared.ReadBuffer
911
import io.modelcontextprotocol.kotlin.sdk.shared.serializeMessage
1012
import io.modelcontextprotocol.kotlin.sdk.types.InitializedNotification
@@ -14,7 +16,10 @@ import io.modelcontextprotocol.kotlin.sdk.types.toJSON
1416
import io.modelcontextprotocol.kotlin.test.utils.runIntegrationTest
1517
import kotlinx.coroutines.CancellationException
1618
import kotlinx.coroutines.CompletableDeferred
19+
import kotlinx.coroutines.CoroutineScope
20+
import kotlinx.coroutines.Dispatchers
1721
import kotlinx.coroutines.TimeoutCancellationException
22+
import kotlinx.coroutines.channels.Channel
1823
import kotlinx.coroutines.withTimeout
1924
import kotlinx.io.Buffer
2025
import kotlinx.io.RawSink
@@ -69,6 +74,40 @@ class StdioServerTransportTest {
6974
printOutput = output.asSink().buffered()
7075
}
7176

77+
@OptIn(InternalAPI::class)
78+
@Test
79+
fun `should construct with builder`() = runIntegrationTest {
80+
val received = CompletableDeferred<JSONRPCMessage>()
81+
82+
val ioDispatcher = Dispatchers.IO.limitedParallelism(4)
83+
84+
// Set every configuration parameter explicitly with non-default values,
85+
// then verify a message round-trips correctly.
86+
val server = StdioServerTransport {
87+
source = bufferedInput
88+
sink = printOutput
89+
readBufferSize = 16L // non-default: smaller read chunk
90+
readingJobDispatcher = ioDispatcher // non-default: limited parallelism
91+
writingJobDispatcher = ioDispatcher // non-default: limited parallelism
92+
processingJobDispatcher = Dispatchers.clientDispatcher(2, "Worker") // non-default
93+
readChannelBufferSize = Channel.BUFFERED // non-default: bounded
94+
writeChannelBufferSize = Channel.BUFFERED // non-default: bounded
95+
coroutineScope = CoroutineScope(Dispatchers.Default) // non-default: parent scope provided
96+
}
97+
server.onError { throw it }
98+
server.onMessage { received.complete(it) }
99+
100+
server.start()
101+
102+
val message = PingRequest().toJSON()
103+
inputWriter.write(serializeMessage(message))
104+
inputWriter.flush()
105+
106+
received.await() shouldBe message
107+
108+
server.close()
109+
}
110+
72111
@Test
73112
fun `should be safe to close before start`() = runIntegrationTest {
74113
val server = StdioServerTransport(bufferedInput, printOutput)

0 commit comments

Comments
 (0)