Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -417,13 +417,14 @@ class ClientTest {
clientInfo = Implementation(name = "test client without capability", version = "1.0"),
options = ClientOptions(
capabilities = ClientCapabilities(),
// enforceStrictCapabilities = true // TODO()
),
)

clientWithoutCapability.connect(clientTransport)
// Using the same transport pair might not be realistic - in a real scenario you'd create another pair.
// Adjust if necessary.
val (clientTransport2, serverTransport2) = InMemoryTransport.createLinkedPair()
listOf(
launch { clientWithoutCapability.connect(clientTransport2) },
launch { server.createSession(serverTransport2) },
).joinAll()

// This should fail
val ex = assertFailsWith<IllegalStateException> {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
package io.modelcontextprotocol.kotlin.sdk.server

import io.modelcontextprotocol.kotlin.sdk.shared.InMemoryTransport
import io.modelcontextprotocol.kotlin.sdk.types.ClientCapabilities
import io.modelcontextprotocol.kotlin.sdk.types.Implementation
import io.modelcontextprotocol.kotlin.sdk.types.InitializeRequest
import io.modelcontextprotocol.kotlin.sdk.types.InitializeRequestParams
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCError
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCResponse
import io.modelcontextprotocol.kotlin.sdk.types.LATEST_PROTOCOL_VERSION
import io.modelcontextprotocol.kotlin.sdk.types.RPCError
import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities
import io.modelcontextprotocol.kotlin.sdk.types.toJSON
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.joinAll
import kotlinx.coroutines.launch
import kotlinx.coroutines.test.runTest
import kotlinx.coroutines.withContext
import org.junit.jupiter.api.Test
import java.util.concurrent.CopyOnWriteArrayList
import kotlin.test.assertEquals
import kotlin.test.assertNotNull
import kotlin.test.assertNull
import kotlin.test.assertTrue

class ServerSessionInitializeTest {

private fun createSession(): ServerSession = ServerSession(
serverInfo = Implementation(name = "test-server", version = "1.0"),
options = ServerOptions(capabilities = ServerCapabilities()),
instructions = null,
)

private fun createInitializeRequest(clientName: String = "test-client"): InitializeRequest = InitializeRequest(
InitializeRequestParams(
protocolVersion = LATEST_PROTOCOL_VERSION,
capabilities = ClientCapabilities(),
clientInfo = Implementation(name = clientName, version = "1.0"),
),
)

@Test
fun `should handle first initialize request successfully`() = runTest {
val session = createSession()
val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair()

assertNull(session.clientCapabilities)
assertNull(session.clientVersion)

val responseDone = CompletableDeferred<JSONRPCResponse>()
clientTransport.onMessage { message ->
if (message is JSONRPCResponse) {
responseDone.complete(message)
}
}

session.connect(serverTransport)
clientTransport.send(createInitializeRequest().toJSON())

val response = responseDone.await()
assertNotNull(response.result)
assertNotNull(session.clientCapabilities)
assertEquals("test-client", session.clientVersion?.name)
}

@Test
fun `should reject duplicate initialize request`() = runTest {
val session = createSession()
val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair()

val responses = CopyOnWriteArrayList<JSONRPCMessage>()
val secondResponseDone = CompletableDeferred<Unit>()

clientTransport.onMessage { message ->
when (message) {
is JSONRPCResponse, is JSONRPCError -> {
responses.add(message)
if (responses.size == 2) secondResponseDone.complete(Unit)
}

else -> {}
}
}

session.connect(serverTransport)

// First initialize should succeed
clientTransport.send(createInitializeRequest(clientName = "first-client").toJSON())

// Second initialize should be rejected
clientTransport.send(createInitializeRequest(clientName = "second-client").toJSON())

secondResponseDone.await()

assertEquals(2, responses.size)
assertTrue(responses[0] is JSONRPCResponse, "First response should be success")
assertTrue(responses[1] is JSONRPCError, "Second response should be error")
assertEquals(RPCError.ErrorCode.INVALID_REQUEST, (responses[1] as JSONRPCError).error.code)

// Capabilities still reflect the first client, not overwritten
assertEquals("first-client", session.clientVersion?.name)
}
Comment thread
devcrocod marked this conversation as resolved.

@Test
fun `should reject concurrent initialize requests - only first succeeds`() = runTest {
val session = createSession()
val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair()

val n = 10
val allResponsesDone = CompletableDeferred<Unit>()
val successes = CopyOnWriteArrayList<JSONRPCResponse>()
val errors = CopyOnWriteArrayList<JSONRPCError>()

clientTransport.onMessage { message ->
when (message) {
is JSONRPCResponse -> successes.add(message)
is JSONRPCError -> errors.add(message)
else -> {}
}
if (successes.size + errors.size == n) {
allResponsesDone.complete(Unit)
}
}

session.connect(serverTransport)

// Use Dispatchers.Default for true parallelism on JVM
withContext(Dispatchers.Default) {
val barrier = CompletableDeferred<Unit>()
val jobs = (1..n).map { i ->
launch {
barrier.await()
clientTransport.send(
createInitializeRequest(clientName = "client-$i").toJSON(),
)
}
}
barrier.complete(Unit)
jobs.joinAll()
}

allResponsesDone.await()

assertEquals(1, successes.size, "Exactly one initialize should succeed")
assertEquals(n - 1, errors.size, "All other initializes should be rejected")
errors.forEach { error ->
assertEquals(RPCError.ErrorCode.INVALID_REQUEST, error.error.code)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ import io.modelcontextprotocol.kotlin.sdk.types.ListRootsRequest
import io.modelcontextprotocol.kotlin.sdk.types.ListRootsResult
import io.modelcontextprotocol.kotlin.sdk.types.LoggingLevel
import io.modelcontextprotocol.kotlin.sdk.types.LoggingMessageNotification
import io.modelcontextprotocol.kotlin.sdk.types.McpException
import io.modelcontextprotocol.kotlin.sdk.types.Method
import io.modelcontextprotocol.kotlin.sdk.types.Method.Defined
import io.modelcontextprotocol.kotlin.sdk.types.RPCError
import io.modelcontextprotocol.kotlin.sdk.types.RequestMeta
import io.modelcontextprotocol.kotlin.sdk.types.ResourceUpdatedNotification
import io.modelcontextprotocol.kotlin.sdk.types.SUPPORTED_PROTOCOL_VERSIONS
Expand Down Expand Up @@ -52,17 +54,18 @@ public open class ServerSession(

private var _onClose: () -> Unit = {}

private val _clientCapabilities: AtomicRef<ClientCapabilities?> = atomic(null)
private val _clientVersion: AtomicRef<Implementation?> = atomic(null)

/**
* The client's reported capabilities after initialization.
*/
public var clientCapabilities: ClientCapabilities? = null
private set
public val clientCapabilities: ClientCapabilities? get() = _clientCapabilities.value

/**
* The client's version information after initialization.
*/
public var clientVersion: Implementation? = null
private set
public val clientVersion: Implementation? get() = _clientVersion.value

/**
* The capabilities supported by the server, related to the session.
Expand Down Expand Up @@ -286,9 +289,15 @@ public open class ServerSession(
}

private fun handleInitialize(request: InitializeRequest): InitializeResult {
if (!_clientCapabilities.compareAndSet(null, request.params.capabilities)) {
throw McpException(
code = RPCError.ErrorCode.INVALID_REQUEST,
message = "Server already initialized",
)
}

logger.debug { "Handling initialization request from client" }
clientCapabilities = request.params.capabilities
clientVersion = request.params.clientInfo
_clientVersion.value = request.params.clientInfo

val requestedVersion = request.params.protocolVersion
val protocolVersion = if (SUPPORTED_PROTOCOL_VERSIONS.contains(requestedVersion)) {
Expand Down
Loading