Skip to content

Commit 185bc6a

Browse files
committed
fix: reject duplicate initialize requests in ServerSession
1 parent bd4cf0c commit 185bc6a

3 files changed

Lines changed: 172 additions & 10 deletions

File tree

integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -417,13 +417,14 @@ class ClientTest {
417417
clientInfo = Implementation(name = "test client without capability", version = "1.0"),
418418
options = ClientOptions(
419419
capabilities = ClientCapabilities(),
420-
// enforceStrictCapabilities = true // TODO()
421420
),
422421
)
423422

424-
clientWithoutCapability.connect(clientTransport)
425-
// Using the same transport pair might not be realistic - in a real scenario you'd create another pair.
426-
// Adjust if necessary.
423+
val (clientTransport2, serverTransport2) = InMemoryTransport.createLinkedPair()
424+
listOf(
425+
launch { clientWithoutCapability.connect(clientTransport2) },
426+
launch { server.createSession(serverTransport2) },
427+
).joinAll()
427428

428429
// This should fail
429430
val ex = assertFailsWith<IllegalStateException> {
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
package io.modelcontextprotocol.kotlin.sdk.server
2+
3+
import io.modelcontextprotocol.kotlin.sdk.shared.InMemoryTransport
4+
import io.modelcontextprotocol.kotlin.sdk.types.ClientCapabilities
5+
import io.modelcontextprotocol.kotlin.sdk.types.Implementation
6+
import io.modelcontextprotocol.kotlin.sdk.types.InitializeRequest
7+
import io.modelcontextprotocol.kotlin.sdk.types.InitializeRequestParams
8+
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCError
9+
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage
10+
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCResponse
11+
import io.modelcontextprotocol.kotlin.sdk.types.LATEST_PROTOCOL_VERSION
12+
import io.modelcontextprotocol.kotlin.sdk.types.RPCError
13+
import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities
14+
import io.modelcontextprotocol.kotlin.sdk.types.toJSON
15+
import kotlinx.coroutines.CompletableDeferred
16+
import kotlinx.coroutines.Dispatchers
17+
import kotlinx.coroutines.joinAll
18+
import kotlinx.coroutines.launch
19+
import kotlinx.coroutines.test.runTest
20+
import kotlinx.coroutines.withContext
21+
import org.junit.jupiter.api.Test
22+
import java.util.concurrent.CopyOnWriteArrayList
23+
import kotlin.test.assertEquals
24+
import kotlin.test.assertNotNull
25+
import kotlin.test.assertNull
26+
import kotlin.test.assertTrue
27+
28+
class ServerSessionInitializeTest {
29+
30+
private fun createSession(): ServerSession = ServerSession(
31+
serverInfo = Implementation(name = "test-server", version = "1.0"),
32+
options = ServerOptions(capabilities = ServerCapabilities()),
33+
instructions = null,
34+
)
35+
36+
private fun createInitializeRequest(clientName: String = "test-client"): InitializeRequest = InitializeRequest(
37+
InitializeRequestParams(
38+
protocolVersion = LATEST_PROTOCOL_VERSION,
39+
capabilities = ClientCapabilities(),
40+
clientInfo = Implementation(name = clientName, version = "1.0"),
41+
),
42+
)
43+
44+
@Test
45+
fun `should handle first initialize request successfully`() = runTest {
46+
val session = createSession()
47+
val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair()
48+
49+
assertNull(session.clientCapabilities)
50+
assertNull(session.clientVersion)
51+
52+
val responseDone = CompletableDeferred<JSONRPCResponse>()
53+
clientTransport.onMessage { message ->
54+
if (message is JSONRPCResponse) {
55+
responseDone.complete(message)
56+
}
57+
}
58+
59+
session.connect(serverTransport)
60+
clientTransport.send(createInitializeRequest().toJSON())
61+
62+
val response = responseDone.await()
63+
assertNotNull(response.result)
64+
assertNotNull(session.clientCapabilities)
65+
assertEquals("test-client", session.clientVersion?.name)
66+
}
67+
68+
@Test
69+
fun `should reject duplicate initialize request`() = runTest {
70+
val session = createSession()
71+
val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair()
72+
73+
val responses = CopyOnWriteArrayList<JSONRPCMessage>()
74+
val secondResponseDone = CompletableDeferred<Unit>()
75+
76+
clientTransport.onMessage { message ->
77+
when (message) {
78+
is JSONRPCResponse, is JSONRPCError -> {
79+
responses.add(message)
80+
if (responses.size == 2) secondResponseDone.complete(Unit)
81+
}
82+
83+
else -> {}
84+
}
85+
}
86+
87+
session.connect(serverTransport)
88+
89+
// First initialize should succeed
90+
clientTransport.send(createInitializeRequest(clientName = "first-client").toJSON())
91+
92+
// Second initialize should be rejected
93+
clientTransport.send(createInitializeRequest(clientName = "second-client").toJSON())
94+
95+
secondResponseDone.await()
96+
97+
assertEquals(2, responses.size)
98+
assertTrue(responses[0] is JSONRPCResponse, "First response should be success")
99+
assertTrue(responses[1] is JSONRPCError, "Second response should be error")
100+
assertEquals(RPCError.ErrorCode.INVALID_REQUEST, (responses[1] as JSONRPCError).error.code)
101+
102+
// Capabilities still reflect the first client, not overwritten
103+
assertEquals("first-client", session.clientVersion?.name)
104+
}
105+
106+
@Test
107+
fun `should reject concurrent initialize requests - only first succeeds`() = runTest {
108+
val session = createSession()
109+
val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair()
110+
111+
val n = 10
112+
val allResponsesDone = CompletableDeferred<Unit>()
113+
val successes = CopyOnWriteArrayList<JSONRPCResponse>()
114+
val errors = CopyOnWriteArrayList<JSONRPCError>()
115+
116+
clientTransport.onMessage { message ->
117+
when (message) {
118+
is JSONRPCResponse -> successes.add(message)
119+
is JSONRPCError -> errors.add(message)
120+
else -> {}
121+
}
122+
if (successes.size + errors.size == n) {
123+
allResponsesDone.complete(Unit)
124+
}
125+
}
126+
127+
session.connect(serverTransport)
128+
129+
// Use Dispatchers.Default for true parallelism on JVM
130+
withContext(Dispatchers.Default) {
131+
val barrier = CompletableDeferred<Unit>()
132+
val jobs = (1..n).map { i ->
133+
launch {
134+
barrier.await()
135+
clientTransport.send(
136+
createInitializeRequest(clientName = "client-$i").toJSON(),
137+
)
138+
}
139+
}
140+
barrier.complete(Unit)
141+
jobs.joinAll()
142+
}
143+
144+
allResponsesDone.await()
145+
146+
assertEquals(1, successes.size, "Exactly one initialize should succeed")
147+
assertEquals(n - 1, errors.size, "All other initializes should be rejected")
148+
errors.forEach { error ->
149+
assertEquals(RPCError.ErrorCode.INVALID_REQUEST, error.error.code)
150+
}
151+
}
152+
}

kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ import io.modelcontextprotocol.kotlin.sdk.types.ListRootsRequest
2020
import io.modelcontextprotocol.kotlin.sdk.types.ListRootsResult
2121
import io.modelcontextprotocol.kotlin.sdk.types.LoggingLevel
2222
import io.modelcontextprotocol.kotlin.sdk.types.LoggingMessageNotification
23+
import io.modelcontextprotocol.kotlin.sdk.types.McpException
2324
import io.modelcontextprotocol.kotlin.sdk.types.Method
2425
import io.modelcontextprotocol.kotlin.sdk.types.Method.Defined
26+
import io.modelcontextprotocol.kotlin.sdk.types.RPCError
2527
import io.modelcontextprotocol.kotlin.sdk.types.RequestMeta
2628
import io.modelcontextprotocol.kotlin.sdk.types.ResourceUpdatedNotification
2729
import io.modelcontextprotocol.kotlin.sdk.types.SUPPORTED_PROTOCOL_VERSIONS
@@ -52,17 +54,18 @@ public open class ServerSession(
5254

5355
private var _onClose: () -> Unit = {}
5456

57+
private val _clientCapabilities: AtomicRef<ClientCapabilities?> = atomic(null)
58+
private val _clientVersion: AtomicRef<Implementation?> = atomic(null)
59+
5560
/**
5661
* The client's reported capabilities after initialization.
5762
*/
58-
public var clientCapabilities: ClientCapabilities? = null
59-
private set
63+
public val clientCapabilities: ClientCapabilities? get() = _clientCapabilities.value
6064

6165
/**
6266
* The client's version information after initialization.
6367
*/
64-
public var clientVersion: Implementation? = null
65-
private set
68+
public val clientVersion: Implementation? get() = _clientVersion.value
6669

6770
/**
6871
* The capabilities supported by the server, related to the session.
@@ -286,9 +289,15 @@ public open class ServerSession(
286289
}
287290

288291
private fun handleInitialize(request: InitializeRequest): InitializeResult {
292+
if (!_clientCapabilities.compareAndSet(null, request.params.capabilities)) {
293+
throw McpException(
294+
code = RPCError.ErrorCode.INVALID_REQUEST,
295+
message = "Server already initialized",
296+
)
297+
}
298+
289299
logger.debug { "Handling initialization request from client" }
290-
clientCapabilities = request.params.capabilities
291-
clientVersion = request.params.clientInfo
300+
_clientVersion.value = request.params.clientInfo
292301

293302
val requestedVersion = request.params.protocolVersion
294303
val protocolVersion = if (SUPPORTED_PROTOCOL_VERSIONS.contains(requestedVersion)) {

0 commit comments

Comments
 (0)