diff --git a/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt b/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt index b117dbfb4..9d512361e 100644 --- a/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt +++ b/integration-test/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientTest.kt @@ -417,13 +417,14 @@ class ClientTest { clientInfo = Implementation(name = "test client without capability", version = "1.0"), options = ClientOptions( capabilities = ClientCapabilities(), - // enforceStrictCapabilities = true // TODO() ), ) - clientWithoutCapability.connect(clientTransport) - // Using the same transport pair might not be realistic - in a real scenario you'd create another pair. - // Adjust if necessary. + val (clientTransport2, serverTransport2) = InMemoryTransport.createLinkedPair() + listOf( + launch { clientWithoutCapability.connect(clientTransport2) }, + launch { server.createSession(serverTransport2) }, + ).joinAll() // This should fail val ex = assertFailsWith { diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSessionInitializeTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSessionInitializeTest.kt new file mode 100644 index 000000000..5031a9c06 --- /dev/null +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSessionInitializeTest.kt @@ -0,0 +1,152 @@ +package io.modelcontextprotocol.kotlin.sdk.server + +import io.modelcontextprotocol.kotlin.sdk.shared.InMemoryTransport +import io.modelcontextprotocol.kotlin.sdk.types.ClientCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import io.modelcontextprotocol.kotlin.sdk.types.InitializeRequest +import io.modelcontextprotocol.kotlin.sdk.types.InitializeRequestParams +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCError +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCResponse +import io.modelcontextprotocol.kotlin.sdk.types.LATEST_PROTOCOL_VERSION +import io.modelcontextprotocol.kotlin.sdk.types.RPCError +import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.toJSON +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.joinAll +import kotlinx.coroutines.launch +import kotlinx.coroutines.test.runTest +import kotlinx.coroutines.withContext +import org.junit.jupiter.api.Test +import java.util.concurrent.CopyOnWriteArrayList +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertTrue + +class ServerSessionInitializeTest { + + private fun createSession(): ServerSession = ServerSession( + serverInfo = Implementation(name = "test-server", version = "1.0"), + options = ServerOptions(capabilities = ServerCapabilities()), + instructions = null, + ) + + private fun createInitializeRequest(clientName: String = "test-client"): InitializeRequest = InitializeRequest( + InitializeRequestParams( + protocolVersion = LATEST_PROTOCOL_VERSION, + capabilities = ClientCapabilities(), + clientInfo = Implementation(name = clientName, version = "1.0"), + ), + ) + + @Test + fun `should handle first initialize request successfully`() = runTest { + val session = createSession() + val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() + + assertNull(session.clientCapabilities) + assertNull(session.clientVersion) + + val responseDone = CompletableDeferred() + clientTransport.onMessage { message -> + if (message is JSONRPCResponse) { + responseDone.complete(message) + } + } + + session.connect(serverTransport) + clientTransport.send(createInitializeRequest().toJSON()) + + val response = responseDone.await() + assertNotNull(response.result) + assertNotNull(session.clientCapabilities) + assertEquals("test-client", session.clientVersion?.name) + } + + @Test + fun `should reject duplicate initialize request`() = runTest { + val session = createSession() + val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() + + val responses = CopyOnWriteArrayList() + val secondResponseDone = CompletableDeferred() + + clientTransport.onMessage { message -> + when (message) { + is JSONRPCResponse, is JSONRPCError -> { + responses.add(message) + if (responses.size == 2) secondResponseDone.complete(Unit) + } + + else -> {} + } + } + + session.connect(serverTransport) + + // First initialize should succeed + clientTransport.send(createInitializeRequest(clientName = "first-client").toJSON()) + + // Second initialize should be rejected + clientTransport.send(createInitializeRequest(clientName = "second-client").toJSON()) + + secondResponseDone.await() + + assertEquals(2, responses.size) + assertTrue(responses[0] is JSONRPCResponse, "First response should be success") + assertTrue(responses[1] is JSONRPCError, "Second response should be error") + assertEquals(RPCError.ErrorCode.INVALID_REQUEST, (responses[1] as JSONRPCError).error.code) + + // Capabilities still reflect the first client, not overwritten + assertEquals("first-client", session.clientVersion?.name) + } + + @Test + fun `should reject concurrent initialize requests - only first succeeds`() = runTest { + val session = createSession() + val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair() + + val n = 10 + val allResponsesDone = CompletableDeferred() + val successes = CopyOnWriteArrayList() + val errors = CopyOnWriteArrayList() + + clientTransport.onMessage { message -> + when (message) { + is JSONRPCResponse -> successes.add(message) + is JSONRPCError -> errors.add(message) + else -> {} + } + if (successes.size + errors.size == n) { + allResponsesDone.complete(Unit) + } + } + + session.connect(serverTransport) + + // Use Dispatchers.Default for true parallelism on JVM + withContext(Dispatchers.Default) { + val barrier = CompletableDeferred() + val jobs = (1..n).map { i -> + launch { + barrier.await() + clientTransport.send( + createInitializeRequest(clientName = "client-$i").toJSON(), + ) + } + } + barrier.complete(Unit) + jobs.joinAll() + } + + allResponsesDone.await() + + assertEquals(1, successes.size, "Exactly one initialize should succeed") + assertEquals(n - 1, errors.size, "All other initializes should be rejected") + errors.forEach { error -> + assertEquals(RPCError.ErrorCode.INVALID_REQUEST, error.error.code) + } + } +} diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt index 69d20d154..a8a34aff3 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/ServerSession.kt @@ -20,8 +20,10 @@ import io.modelcontextprotocol.kotlin.sdk.types.ListRootsRequest import io.modelcontextprotocol.kotlin.sdk.types.ListRootsResult import io.modelcontextprotocol.kotlin.sdk.types.LoggingLevel import io.modelcontextprotocol.kotlin.sdk.types.LoggingMessageNotification +import io.modelcontextprotocol.kotlin.sdk.types.McpException import io.modelcontextprotocol.kotlin.sdk.types.Method import io.modelcontextprotocol.kotlin.sdk.types.Method.Defined +import io.modelcontextprotocol.kotlin.sdk.types.RPCError import io.modelcontextprotocol.kotlin.sdk.types.RequestMeta import io.modelcontextprotocol.kotlin.sdk.types.ResourceUpdatedNotification import io.modelcontextprotocol.kotlin.sdk.types.SUPPORTED_PROTOCOL_VERSIONS @@ -52,17 +54,18 @@ public open class ServerSession( private var _onClose: () -> Unit = {} + private val _clientCapabilities: AtomicRef = atomic(null) + private val _clientVersion: AtomicRef = atomic(null) + /** * The client's reported capabilities after initialization. */ - public var clientCapabilities: ClientCapabilities? = null - private set + public val clientCapabilities: ClientCapabilities? get() = _clientCapabilities.value /** * The client's version information after initialization. */ - public var clientVersion: Implementation? = null - private set + public val clientVersion: Implementation? get() = _clientVersion.value /** * The capabilities supported by the server, related to the session. @@ -286,9 +289,15 @@ public open class ServerSession( } private fun handleInitialize(request: InitializeRequest): InitializeResult { + if (!_clientCapabilities.compareAndSet(null, request.params.capabilities)) { + throw McpException( + code = RPCError.ErrorCode.INVALID_REQUEST, + message = "Server already initialized", + ) + } + logger.debug { "Handling initialization request from client" } - clientCapabilities = request.params.capabilities - clientVersion = request.params.clientInfo + _clientVersion.value = request.params.clientInfo val requestedVersion = request.params.protocolVersion val protocolVersion = if (SUPPORTED_PROTOCOL_VERSIONS.contains(requestedVersion)) {