Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ import io.modelcontextprotocol.kotlin.sdk.types.CreateMessageResult
import io.modelcontextprotocol.kotlin.sdk.types.DoubleSchema
import io.modelcontextprotocol.kotlin.sdk.types.ElicitRequestFormParams
import io.modelcontextprotocol.kotlin.sdk.types.ElicitRequestParams
import io.modelcontextprotocol.kotlin.sdk.types.ElicitRequestURLParams
import io.modelcontextprotocol.kotlin.sdk.types.ElicitResult
import io.modelcontextprotocol.kotlin.sdk.types.ElicitationCompleteNotification
import io.modelcontextprotocol.kotlin.sdk.types.ElicitationCompleteNotificationParams
import io.modelcontextprotocol.kotlin.sdk.types.EmptyJsonObject
import io.modelcontextprotocol.kotlin.sdk.types.Implementation
import io.modelcontextprotocol.kotlin.sdk.types.InitializeRequest
Expand Down Expand Up @@ -46,6 +49,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.Tool
import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema
import io.modelcontextprotocol.kotlin.sdk.types.UntitledMultiSelectEnumSchema
import io.modelcontextprotocol.kotlin.sdk.types.UntitledSingleSelectEnumSchema
import io.modelcontextprotocol.kotlin.sdk.types.UrlElicitationRequiredException
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.cancel
Expand Down Expand Up @@ -1251,6 +1255,189 @@ class ClientTest {
client.close()
}

// ── URL-mode elicitation (SEP-1036) ─────────────────────────────────

@Test
fun `should handle URL mode elicitation end-to-end`() = runTest {
val client = Client(
Implementation(name = "test client", version = "1.0"),
ClientOptions(
capabilities = ClientCapabilities(
elicitation = ClientCapabilities.Elicitation(url = EmptyJsonObject),
),
),
)

val elicitationId = "550e8400-e29b-41d4-a716-446655440000"
val url = "https://oauth.example.com/authorize"

client.setElicitationHandler { request ->
val params = assertIs<ElicitRequestURLParams>(request.params)
assertEquals(elicitationId, params.elicitationId)
assertEquals(url, params.url)
ElicitResult(action = ElicitResult.Action.Accept)
}

val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair()
val server = Server(
serverInfo = Implementation(name = "test server", version = "1.0"),
options = ServerOptions(capabilities = ServerCapabilities()),
)

val serverSessionResult = CompletableDeferred<ServerSession>()
listOf(
launch { client.connect(clientTransport) },
launch { serverSessionResult.complete(server.createSession(serverTransport)) },
).joinAll()
val serverSession = serverSessionResult.await()

val result = serverSession.createElicitation(
message = "Authorize access to continue",
elicitationId = elicitationId,
url = url,
)

assertEquals(ElicitResult.Action.Accept, result.action)
assertNull(result.content)

client.close()
}

@Test
fun `should reject URL mode elicitation when client supports only form mode`() = runTest {
val (client, serverSession) = setupElicitationPair {
ElicitResult(action = ElicitResult.Action.Accept)
}

val exception = assertFailsWith<IllegalArgumentException> {
serverSession.createElicitation(
message = "Authorize",
elicitationId = "id-1",
url = "https://example.com/auth",
)
}
assertTrue(exception.message!!.contains("elicitation.url"))

client.close()
}

@Test
fun `should deliver elicitation complete notification to client`() = runTest {
val received = CompletableDeferred<ElicitationCompleteNotification>()
val client = Client(
Implementation(name = "test client", version = "1.0"),
ClientOptions(
capabilities = ClientCapabilities(
elicitation = ClientCapabilities.Elicitation(url = EmptyJsonObject),
),
),
)
client.setElicitationCompleteHandler { received.complete(it) }

val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair()
val server = Server(
serverInfo = Implementation(name = "test server", version = "1.0"),
options = ServerOptions(capabilities = ServerCapabilities()),
)

val serverSessionResult = CompletableDeferred<ServerSession>()
listOf(
launch { client.connect(clientTransport) },
launch { serverSessionResult.complete(server.createSession(serverTransport)) },
).joinAll()
val serverSession = serverSessionResult.await()

val elicitationId = "complete-id-1"
serverSession.sendElicitationComplete(
ElicitationCompleteNotification(ElicitationCompleteNotificationParams(elicitationId = elicitationId)),
)

val notification = received.await()
assertEquals(elicitationId, notification.params.elicitationId)

client.close()
}

@Test
fun `should reject elicitation complete when client supports only form mode`() = runTest {
val (client, serverSession) = setupElicitationPair {
ElicitResult(action = ElicitResult.Action.Accept)
}

val exception = assertFailsWith<IllegalArgumentException> {
serverSession.sendElicitationComplete(
ElicitationCompleteNotification(ElicitationCompleteNotificationParams(elicitationId = "id-1")),
)
}
assertTrue(exception.message!!.contains("elicitation.url"))

client.close()
}

@Test
fun `setElicitationCompleteHandler should require url capability`() = runTest {
val client = Client(
Implementation(name = "test client", version = "1.0"),
ClientOptions(
capabilities = ClientCapabilities(
elicitation = ClientCapabilities.Elicitation(),
),
),
)

assertFailsWith<IllegalStateException> {
client.setElicitationCompleteHandler { }
}
}

@Test
fun `should surface URL elicitation required error to client as typed exception`() = runTest {
val client = Client(
Implementation(name = "test client", version = "1.0"),
ClientOptions(
capabilities = ClientCapabilities(
elicitation = ClientCapabilities.Elicitation(url = EmptyJsonObject),
),
),
)

val elicitationId = "auth-required-1"
val url = "https://oauth.example.com/authorize"

val (clientTransport, serverTransport) = InMemoryTransport.createLinkedPair()
val server = Server(
serverInfo = Implementation(name = "test server", version = "1.0"),
options = ServerOptions(capabilities = ServerCapabilities(tools = ServerCapabilities.Tools(true))),
)
server.addTool("needs-auth", "Requires URL elicitation") {
throw UrlElicitationRequiredException(
listOf(
ElicitRequestURLParams(
message = "Authorize to continue",
elicitationId = elicitationId,
url = url,
),
),
)
}

val serverSessionResult = CompletableDeferred<ServerSession>()
listOf(
launch { client.connect(clientTransport) },
launch { serverSessionResult.complete(server.createSession(serverTransport)) },
).joinAll()
serverSessionResult.await()

val exception = assertFailsWith<UrlElicitationRequiredException> {
client.callTool(name = "needs-auth", arguments = emptyMap())
}
val elicitation = exception.elicitations.single()
assertEquals(elicitationId, elicitation.elicitationId)
assertEquals(url, elicitation.url)

client.close()
}

private fun defaultsTestSchema(): ElicitRequestParams.RequestedSchema = ElicitRequestParams.RequestedSchema(
properties = mapOf(
"name" to StringSchema(description = "User name", default = "John Doe"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequestParams
import io.modelcontextprotocol.kotlin.sdk.types.ClientCapabilities
import io.modelcontextprotocol.kotlin.sdk.types.ElicitationCompleteNotification
import io.modelcontextprotocol.kotlin.sdk.types.ElicitationCompleteNotificationParams
import io.modelcontextprotocol.kotlin.sdk.types.EmptyJsonObject
import io.modelcontextprotocol.kotlin.sdk.types.GetPromptRequest
import io.modelcontextprotocol.kotlin.sdk.types.GetPromptRequestParams
import io.modelcontextprotocol.kotlin.sdk.types.ListRootsRequest
Expand Down Expand Up @@ -61,6 +62,7 @@ class ClientConnectionTest : AbstractServerFeaturesTest() {

override fun getClientCapabilities(): ClientCapabilities = ClientCapabilities(
roots = ClientCapabilities.Roots(listChanged = true),
elicitation = ClientCapabilities.Elicitation(url = EmptyJsonObject),
)

private val sampleRoots = listOf(Root("file:///project", "Project Root"))
Expand Down
1 change: 1 addition & 0 deletions kotlin-sdk-client/api/kotlin-sdk-client.api
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public class io/modelcontextprotocol/kotlin/sdk/client/Client : io/modelcontextp
public final fun removeRoot (Ljava/lang/String;)Z
public final fun removeRoots (Ljava/util/List;)I
public final fun sendRootsListChanged (Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public final fun setElicitationCompleteHandler (Lkotlin/jvm/functions/Function1;)V
public final fun setElicitationHandler (Lkotlin/jvm/functions/Function1;)V
public final fun setLoggingLevel (Lio/modelcontextprotocol/kotlin/sdk/types/LoggingLevel;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static synthetic fun setLoggingLevel$default (Lio/modelcontextprotocol/kotlin/sdk/client/Client;Lio/modelcontextprotocol/kotlin/sdk/types/LoggingLevel;Lio/modelcontextprotocol/kotlin/sdk/shared/RequestOptions;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.DoubleSchema
import io.modelcontextprotocol.kotlin.sdk.types.ElicitRequest
import io.modelcontextprotocol.kotlin.sdk.types.ElicitRequestFormParams
import io.modelcontextprotocol.kotlin.sdk.types.ElicitResult
import io.modelcontextprotocol.kotlin.sdk.types.ElicitationCompleteNotification
import io.modelcontextprotocol.kotlin.sdk.types.EmptyResult
import io.modelcontextprotocol.kotlin.sdk.types.GetPromptRequest
import io.modelcontextprotocol.kotlin.sdk.types.GetPromptResult
Expand Down Expand Up @@ -63,13 +64,15 @@ import io.modelcontextprotocol.kotlin.sdk.types.TitledSingleSelectEnumSchema
import io.modelcontextprotocol.kotlin.sdk.types.UnsubscribeRequest
import io.modelcontextprotocol.kotlin.sdk.types.UntitledMultiSelectEnumSchema
import io.modelcontextprotocol.kotlin.sdk.types.UntitledSingleSelectEnumSchema
import io.modelcontextprotocol.kotlin.sdk.types.supportsUrl
import io.modelcontextprotocol.kotlin.sdk.types.toJson
import kotlinx.atomicfu.atomic
import kotlinx.atomicfu.getAndUpdate
import kotlinx.atomicfu.update
import kotlinx.collections.immutable.minus
import kotlinx.collections.immutable.persistentMapOf
import kotlinx.collections.immutable.toPersistentSet
import kotlinx.coroutines.CompletableDeferred
import kotlinx.serialization.SerializationException
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonObject
Expand Down Expand Up @@ -643,9 +646,18 @@ public open class Client(private val clientInfo: Implementation, options: Client
/**
* Sets the elicitation handler.
*
* When the handler returns [ElicitResult.Action.Accept], any properties missing from
* The handler receives both form-mode ([ElicitRequestFormParams]) and URL-mode
* ([io.modelcontextprotocol.kotlin.sdk.types.ElicitRequestURLParams]) requests;
* branch on `request.params` to tell them apart. For URL mode,
* the host application must obtain explicit user consent and display the target domain before
* navigating — the SDK never opens or validates the URL — and should return
* [ElicitResult.Action.Decline] or [ElicitResult.Action.Cancel] when it cannot or will not proceed.
* A URL-mode [ElicitResult.Action.Accept] only signals consent; the outcome arrives out-of-band via
* [setElicitationCompleteHandler].
*
* When a form-mode handler returns [ElicitResult.Action.Accept], any properties missing from
* [ElicitResult.content] are automatically populated with default values defined in the
* elicitation schema.
* elicitation schema. URL-mode responses carry no content.
*
* @param handler The elicitation handler.
* @throws IllegalStateException if the client does not support elicitation.
Expand All @@ -663,6 +675,35 @@ public open class Client(private val clientInfo: Implementation, options: Client
}
}

/**
* Sets the handler invoked when the server reports that a URL-mode elicitation has completed.
*
* The handler is called for every `notifications/elicitation/complete` notification. Because the
* server only sends this for an out-of-band (URL-mode) interaction, the client must support url-mode
* elicitation. The client is responsible for correlating the notification's `elicitationId` with a
* pending elicitation, ignoring unknown or already-completed identifiers, and providing a manual way
* to continue if a notification never arrives.
*
* @param handler Invoked with each completion notification.
* @throws IllegalStateException if the client does not support url-mode elicitation.
*/
public fun setElicitationCompleteHandler(handler: (ElicitationCompleteNotification) -> Unit) {
check(capabilities.elicitation.supportsUrl) {
logger.error {
"Failed to set elicitation-complete handler: client does not support url-mode elicitation"
}
"Client does not support url-mode elicitation."
}
logger.info { "Setting the elicitation-complete handler" }

setNotificationHandler<ElicitationCompleteNotification>(
Method.Defined.NotificationsElicitationComplete,
) { notification ->
handler(notification)
CompletableDeferred(Unit)
}
}

// --- Internal Handlers ---

private fun applyElicitationDefaults(request: ElicitRequest, result: ElicitResult): ElicitResult {
Expand Down
40 changes: 39 additions & 1 deletion kotlin-sdk-core/api/kotlin-sdk-core.api
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,10 @@ public final class io/modelcontextprotocol/kotlin/sdk/types/CancelledNotificatio
public final fun serializer ()Lkotlinx/serialization/KSerializer;
}

public final class io/modelcontextprotocol/kotlin/sdk/types/CapabilitiesKt {
public static final fun getSupportsUrl (Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Elicitation;)Z
}

public final class io/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities {
public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/types/ClientCapabilities$Companion;
public fun <init> ()V
Expand Down Expand Up @@ -3128,7 +3132,7 @@ public final class io/modelcontextprotocol/kotlin/sdk/types/Logging_dslKt {
public abstract interface annotation class io/modelcontextprotocol/kotlin/sdk/types/McpDsl : java/lang/annotation/Annotation {
}

public final class io/modelcontextprotocol/kotlin/sdk/types/McpException : java/lang/Exception {
public class io/modelcontextprotocol/kotlin/sdk/types/McpException : java/lang/Exception {
public fun <init> (I)V
public fun <init> (ILjava/lang/String;)V
public fun <init> (ILjava/lang/String;Lkotlinx/serialization/json/JsonElement;)V
Expand Down Expand Up @@ -3744,6 +3748,7 @@ public final class io/modelcontextprotocol/kotlin/sdk/types/RPCError$ErrorCode {
public static final field PARSE_ERROR I
public static final field REQUEST_TIMEOUT I
public static final field RESOURCE_NOT_FOUND I
public static final field URL_ELICITATION_REQUIRED I
}

public final class io/modelcontextprotocol/kotlin/sdk/types/ReadResourceRequest : io/modelcontextprotocol/kotlin/sdk/types/ClientRequest {
Expand Down Expand Up @@ -5814,6 +5819,39 @@ public final class io/modelcontextprotocol/kotlin/sdk/types/UntitledSingleSelect
public final fun serializer ()Lkotlinx/serialization/KSerializer;
}

public final class io/modelcontextprotocol/kotlin/sdk/types/UrlElicitationRequiredData {
public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/types/UrlElicitationRequiredData$Companion;
public fun <init> (Ljava/util/List;)V
public final fun component1 ()Ljava/util/List;
public final fun copy (Ljava/util/List;)Lio/modelcontextprotocol/kotlin/sdk/types/UrlElicitationRequiredData;
public static synthetic fun copy$default (Lio/modelcontextprotocol/kotlin/sdk/types/UrlElicitationRequiredData;Ljava/util/List;ILjava/lang/Object;)Lio/modelcontextprotocol/kotlin/sdk/types/UrlElicitationRequiredData;
public fun equals (Ljava/lang/Object;)Z
public final fun getElicitations ()Ljava/util/List;
public fun hashCode ()I
public fun toString ()Ljava/lang/String;
}

public final synthetic class io/modelcontextprotocol/kotlin/sdk/types/UrlElicitationRequiredData$$serializer : kotlinx/serialization/internal/GeneratedSerializer {
public static final field INSTANCE Lio/modelcontextprotocol/kotlin/sdk/types/UrlElicitationRequiredData$$serializer;
public final fun childSerializers ()[Lkotlinx/serialization/KSerializer;
public final fun deserialize (Lkotlinx/serialization/encoding/Decoder;)Lio/modelcontextprotocol/kotlin/sdk/types/UrlElicitationRequiredData;
public synthetic fun deserialize (Lkotlinx/serialization/encoding/Decoder;)Ljava/lang/Object;
public final fun getDescriptor ()Lkotlinx/serialization/descriptors/SerialDescriptor;
public final fun serialize (Lkotlinx/serialization/encoding/Encoder;Lio/modelcontextprotocol/kotlin/sdk/types/UrlElicitationRequiredData;)V
public synthetic fun serialize (Lkotlinx/serialization/encoding/Encoder;Ljava/lang/Object;)V
public fun typeParametersSerializers ()[Lkotlinx/serialization/KSerializer;
}

public final class io/modelcontextprotocol/kotlin/sdk/types/UrlElicitationRequiredData$Companion {
public final fun serializer ()Lkotlinx/serialization/KSerializer;
}

public final class io/modelcontextprotocol/kotlin/sdk/types/UrlElicitationRequiredException : io/modelcontextprotocol/kotlin/sdk/types/McpException {
public fun <init> (Ljava/util/List;Ljava/lang/String;)V
public synthetic fun <init> (Ljava/util/List;Ljava/lang/String;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public final fun getElicitations ()Ljava/util/List;
}

public abstract interface class io/modelcontextprotocol/kotlin/sdk/types/WithMeta {
public static final field Companion Lio/modelcontextprotocol/kotlin/sdk/types/WithMeta$Companion;
public abstract fun getMeta ()Lkotlinx/serialization/json/JsonObject;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio
handler(response, null)
} else {
checkNotNull(error)
val mcpException = McpException(
val mcpException = McpException.fromError(
code = error.error.code,
message = error.error.message,
data = error.error.data,
Expand Down
Loading
Loading