Skip to content

Commit ac552f7

Browse files
committed
refactor(server): simplify StdioServerTransport constructor and remove Configuration class
- Replaced the builder-based setup with a streamlined primary constructor for `StdioServerTransport`. - Removed the `Configuration` class for reduced complexity and enhanced readability. - Updated `StdioServerTransportTest` to reflect the refactored initialization. - Refactor cleanup.
1 parent 570184d commit ac552f7

2 files changed

Lines changed: 59 additions & 98 deletions

File tree

kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransport.kt

Lines changed: 47 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import kotlinx.coroutines.Dispatchers
1414
import kotlinx.coroutines.Job
1515
import kotlinx.coroutines.NonCancellable
1616
import kotlinx.coroutines.SupervisorJob
17-
import kotlinx.coroutines.cancelAndJoin
17+
import kotlinx.coroutines.cancel
1818
import kotlinx.coroutines.channels.Channel
1919
import kotlinx.coroutines.isActive
2020
import kotlinx.coroutines.launch
@@ -25,7 +25,6 @@ import kotlinx.io.Source
2525
import kotlinx.io.buffered
2626
import kotlinx.io.readByteArray
2727
import kotlinx.io.writeString
28-
import kotlin.concurrent.Volatile
2928
import kotlin.concurrent.atomics.AtomicBoolean
3029
import kotlin.concurrent.atomics.ExperimentalAtomicApi
3130

@@ -34,115 +33,80 @@ private const val READ_BUFFER_SIZE = 8192L
3433
/**
3534
* A server transport that communicates with a client via standard I/O.
3635
*
37-
* Reads from input [Source] and writes to output [Sink].
36+
* [StdioServerTransport] manages the communication between a JSON-RPC server and its clients
37+
* by reading incoming messages from the specified [Source] (input stream) and writing outgoing
38+
* messages to the [Sink] (output stream).
3839
*
3940
* Example:
4041
* ```kotlin
41-
* val transport = StdioServerTransport {
42+
* val transport = StdioServerTransport(
4243
* source = System.`in`.asInput(),
4344
* sink = System.out.asSink(),
44-
* }
45+
* )
4546
* ```
4647
*
47-
* @constructor Initializes the transport using the provided block for [Configuration].
48-
* The configuration includes specifying the input and output streams, buffer sizes,
49-
* and dispatchers for I/O and processing tasks and coroutine scope.
48+
* @constructor Creates an instance of [StdioServerTransport] with the specified parameters.
49+
* @param source The source for reading incoming messages (e.g., stdin or other readable stream).
50+
* @param sink The sink for writing outgoing messages (e.g., stdout or other writable stream).
51+
* @param readBufferSize The maximum size of the read buffer, defaults to a pre-configured constant.
52+
* @param readChannel The channel for receiving raw byte arrays from the input stream.
53+
* @param writeChannel The channel for sending serialized JSON-RPC messages to the output stream.
54+
* @param readingJobDispatcher The dispatcher to use for the message-reading coroutine.
55+
* @param writingJobDispatcher The dispatcher to use for the message-writing coroutine.
56+
* @param processingJobDispatcher The dispatcher to handle processing of read messages.
57+
* @param coroutineScope Optional coroutine scope to use for managing internal jobs. A new scope
58+
* will be created if not provided.
5059
*/
5160
@OptIn(ExperimentalAtomicApi::class)
52-
public class StdioServerTransport(block: Configuration.() -> Unit) : AbstractTransport() {
61+
@Suppress("LongParameterList")
62+
public class StdioServerTransport(
63+
private val source: Source,
64+
sink: Sink,
65+
private val readBufferSize: Long = READ_BUFFER_SIZE,
66+
private val readChannel: Channel<ByteArray> = Channel(Channel.UNLIMITED),
67+
private val writeChannel: Channel<JSONRPCMessage> = Channel(Channel.UNLIMITED),
68+
private var readingJobDispatcher: CoroutineDispatcher = IODispatcher,
69+
private var writingJobDispatcher: CoroutineDispatcher = IODispatcher,
70+
private var processingJobDispatcher: CoroutineDispatcher = Dispatchers.Default,
71+
coroutineScope: CoroutineScope? = null,
72+
) : AbstractTransport() {
5373

54-
/**
55-
* Configuration for [StdioServerTransport].
56-
*
57-
* @property source The input [Source] used to receive data.
58-
* @property sink The output [Sink] used to send data.
59-
* @property readBufferSize The buffer size for the read channel.
60-
* @property readingJobDispatcher The [CoroutineDispatcher] used for reading jobs.
61-
* Defaults to [IODispatcher].
62-
* @property writingJobDispatcher The [CoroutineDispatcher] used for writing jobs.
63-
* Defaults to [IODispatcher].
64-
* @property processingJobDispatcher The [CoroutineDispatcher] used for processing jobs.
65-
* Defaults to [Dispatchers.Default].
66-
* @property readChannelBufferSize The buffer size for the read channel.
67-
* @property writeChannelBufferSize The buffer size for the write channel.
68-
* @property coroutineScope The [CoroutineScope] used for managing coroutines.
69-
*/
70-
@Suppress("LongParameterList")
71-
public class Configuration internal constructor(
72-
public var source: Source? = null,
73-
public var sink: Sink? = null,
74-
public var readBufferSize: Long = READ_BUFFER_SIZE,
75-
public var readingJobDispatcher: CoroutineDispatcher = IODispatcher,
76-
public var writingJobDispatcher: CoroutineDispatcher = IODispatcher,
77-
public var processingJobDispatcher: CoroutineDispatcher = Dispatchers.Default,
78-
public var readChannelBufferSize: Int = Channel.UNLIMITED,
79-
public var writeChannelBufferSize: Int = Channel.UNLIMITED,
80-
public var coroutineScope: CoroutineScope? = null,
81-
)
82-
83-
private val source: Source
84-
private val sink: Sink
85-
private val processingJobDispatcher: CoroutineDispatcher
86-
private val readingJobDispatcher: CoroutineDispatcher
87-
private val writingJobDispatcher: CoroutineDispatcher
8874
private val scope: CoroutineScope
89-
private val readBufferSize: Long
90-
private val readChannel: Channel<ByteArray>
91-
private val writeChannel: Channel<JSONRPCMessage>
75+
private val sink: Sink
9276

9377
init {
94-
val config = Configuration().apply(block)
95-
val input = requireNotNull(config.source) { "source is required" }
96-
val output = requireNotNull(config.sink) { "sink is required" }
97-
require(config.readBufferSize > 0) { "readBufferSize must be > 0" }
98-
99-
source = input
100-
processingJobDispatcher = config.processingJobDispatcher
101-
readingJobDispatcher = config.readingJobDispatcher
102-
writingJobDispatcher = config.writingJobDispatcher
103-
val parentJob = config.coroutineScope?.coroutineContext?.get(Job)
78+
require(readBufferSize > 0) { "readBufferSize must be > 0" }
79+
val parentJob = coroutineScope?.coroutineContext?.get(Job)
10480
scope = CoroutineScope(SupervisorJob(parentJob))
105-
readBufferSize = config.readBufferSize
106-
readChannel = Channel(config.readChannelBufferSize)
107-
writeChannel = Channel(config.writeChannelBufferSize)
108-
sink = output.buffered()
81+
this.sink = sink.buffered()
10982
}
11083

11184
/**
11285
* Creates a new instance of [StdioServerTransport]
11386
* with the given [inputStream] [Source] and [outputStream] [Sink].
11487
*/
115-
public constructor(inputStream: Source, outputStream: Sink) : this({
116-
source = inputStream
117-
sink = outputStream
118-
})
88+
public constructor(inputStream: Source, outputStream: Sink) : this(
89+
source = inputStream,
90+
sink = outputStream,
91+
)
11992

12093
private val logger = KotlinLogging.logger {}
12194
private val readBuffer = ReadBuffer()
12295
private val initialized: AtomicBoolean = AtomicBoolean(false)
12396

124-
@Volatile
125-
private var readingJob: Job? = null
126-
127-
@Volatile
128-
private var sendingJob: Job? = null
129-
130-
@Volatile
131-
private var processingJob: Job? = null
132-
13397
override suspend fun start() {
13498
if (!initialized.compareAndSet(expectedValue = false, newValue = true)) {
13599
error("StdioServerTransport already started!")
136100
}
137101

138102
// Launch a coroutine to read from stdin
139-
readingJob = launchReadingJob()
103+
launchReadingJob()
140104

141105
// Launch a coroutine to process messages from readChannel
142-
processingJob = launchProcessingJob()
106+
launchProcessingJob()
143107

144108
// Launch a coroutine to handle message sending
145-
sendingJob = launchSendingJob()
109+
launchSendingJob()
146110
}
147111

148112
private fun launchReadingJob(): Job = scope.launch(readingJobDispatcher) {
@@ -186,7 +150,9 @@ public class StdioServerTransport(block: Configuration.() -> Unit) : AbstractTra
186150
_onError.invoke(e)
187151
}
188152
}.apply {
189-
invokeOnCompletion { logJobCompletion("Processing", it) }
153+
invokeOnCompletion { cause ->
154+
logJobCompletion("Processing", cause)
155+
}
190156
}
191157

192158
private fun launchSendingJob(): Job = scope.launch(writingJobDispatcher) {
@@ -207,7 +173,7 @@ public class StdioServerTransport(block: Configuration.() -> Unit) : AbstractTra
207173
invokeOnCompletion { cause ->
208174
logJobCompletion("Message sending", cause)
209175
if (cause is CancellationException) {
210-
readingJob?.cancel(cause)
176+
readChannel.cancel(cause)
211177
}
212178
}
213179
}
@@ -255,23 +221,22 @@ public class StdioServerTransport(block: Configuration.() -> Unit) : AbstractTra
255221

256222
withContext(NonCancellable) {
257223
writeChannel.close()
258-
sendingJob?.cancelAndJoin()
259224

260225
runCatching {
261226
source.close()
262227
}.onFailure { logger.warn(it) { "Failed to close stdin" } }
263228

264-
readingJob?.cancel()
265229
readChannel.close()
266230

267-
processingJob?.cancelAndJoin()
268-
269231
readBuffer.clear()
270232
runCatching {
271233
sink.flush()
272234
sink.close()
273235
}.onFailure { logger.warn(it) { "Failed to close stdout" } }
274236

237+
scope.cancel()
238+
scope.coroutineContext[Job]?.join()
239+
275240
invokeOnCloseCallback()
276241
}
277242
}

kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StdioServerTransportTest.kt

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ import io.kotest.assertions.throwables.shouldThrow
55
import io.kotest.assertions.withClue
66
import io.kotest.matchers.collections.shouldContain
77
import io.kotest.matchers.shouldBe
8-
import io.ktor.client.utils.clientDispatcher
9-
import io.ktor.utils.io.InternalAPI
108
import io.modelcontextprotocol.kotlin.sdk.shared.ReadBuffer
119
import io.modelcontextprotocol.kotlin.sdk.shared.serializeMessage
1210
import io.modelcontextprotocol.kotlin.sdk.types.InitializedNotification
@@ -19,6 +17,7 @@ import kotlinx.coroutines.CompletableDeferred
1917
import kotlinx.coroutines.CoroutineScope
2018
import kotlinx.coroutines.Dispatchers
2119
import kotlinx.coroutines.TimeoutCancellationException
20+
import kotlinx.coroutines.channels.BufferOverflow
2221
import kotlinx.coroutines.channels.Channel
2322
import kotlinx.coroutines.withTimeout
2423
import kotlinx.io.Buffer
@@ -74,26 +73,23 @@ class StdioServerTransportTest {
7473
printOutput = output.asSink().buffered()
7574
}
7675

77-
@OptIn(InternalAPI::class)
7876
@Test
7977
fun `should construct with builder`() = runIntegrationTest {
8078
val received = CompletableDeferred<JSONRPCMessage>()
8179

82-
val ioDispatcher = Dispatchers.IO.limitedParallelism(4)
83-
8480
// Set every configuration parameter explicitly with non-default values,
8581
// then verify a message round-trips correctly.
86-
val server = StdioServerTransport {
87-
source = bufferedInput
88-
sink = printOutput
89-
readBufferSize = 16L // non-default: smaller read chunk
90-
readingJobDispatcher = ioDispatcher // non-default: limited parallelism
91-
writingJobDispatcher = ioDispatcher // non-default: limited parallelism
92-
processingJobDispatcher = Dispatchers.clientDispatcher(2, "Worker") // non-default
93-
readChannelBufferSize = Channel.BUFFERED // non-default: bounded
94-
writeChannelBufferSize = Channel.BUFFERED // non-default: bounded
95-
coroutineScope = CoroutineScope(Dispatchers.Default) // non-default: parent scope provided
96-
}
82+
val server = StdioServerTransport(
83+
source = bufferedInput,
84+
sink = printOutput,
85+
readBufferSize = 16L, // non-default: smaller read chunk
86+
readingJobDispatcher = Dispatchers.IO.limitedParallelism(4, "Read"), // non-default: limited parallelism
87+
writingJobDispatcher = Dispatchers.IO.limitedParallelism(4, "Write"), // non-default: limited parallelism
88+
processingJobDispatcher = Dispatchers.IO.limitedParallelism(2, name = "Worker"), // non-default
89+
readChannel = Channel(capacity = 8, onBufferOverflow = BufferOverflow.SUSPEND), // non-default: bounded
90+
writeChannel = Channel(capacity = 16, onBufferOverflow = BufferOverflow.SUSPEND), // non-default: bounded
91+
coroutineScope = CoroutineScope(Dispatchers.Default), // non-default: parent scope provided
92+
)
9793
server.onError { throw it }
9894
server.onMessage { received.complete(it) }
9995

0 commit comments

Comments
 (0)