Skip to content

Commit 4a08959

Browse files
rnettkpavlov
andauthored
feat! Add context to server call handlers (#515)
## Motivation and Context Adds a `ClientConnection` context receiver to server call handlers, which provides access to the client communication channel. This is necessary to allow tools to send notifications, logging, elicitations, etc. Supersedes #144, discussed in #295. ## How Has This Been Tested? With the added integration test. ## Breaking Changes It breaks binary compatibility and breaks source compatibility for tool handlers that already use `this`. If you care a lot about maintaining compatibility, I could add some overloads to ease the transition. But the new methods would likely need to be named differently for resolution. Future compatibility should be easier to maintain - `Context` is an interface (with experimental opt-in subclassing), so new fields can be added without breaking compatibility. Different types could also be introduced for different handles without breaking compatibility, as long as they extend `ClientConnection`. ## Types of changes <!-- What types of changes does your code introduce? Put an `x` in all the boxes that apply: --> - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [x] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Documentation update ## Checklist <!-- Go over all the following points, and put an `x` in all the boxes that apply. --> - [x] I have read the [MCP Documentation](https://modelcontextprotocol.io) - [x] My code follows the repository's style guidelines - [x] New and existing tests pass locally - [x] I have added appropriate error handling - [x] I have added or updated documentation as needed ## Additional context <!-- Add any other context, implementation notes, or design decisions --> --------- Co-authored-by: Konstantin Pavlov <1517853+kpavlov@users.noreply.github.com>
1 parent 947d03d commit 4a08959

15 files changed

Lines changed: 1585 additions & 253 deletions

File tree

integration-test/detekt-baseline-test.xml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
<ID>LongMethod:KotlinServerForTsClientSse.kt:HttpServerTransport$suspend fun handleRequest</ID>
3333
<ID>LongMethod:KotlinServerForTsClientSse.kt:KotlinServerForTsClient$fun createMcpServer: Server</ID>
3434
<ID>LongMethod:KotlinServerForTsClientSse.kt:KotlinServerForTsClient$fun start</ID>
35-
<ID>LongMethod:KotlinTestBase.kt:KotlinTestBase$protected fun setupServer</ID>
3635
<ID>LongMethod:ServerResourcesNotificationSubscribeTest.kt:ServerResourcesNotificationSubscribeTest$@Test fun `should send resource notifications`</ID>
3736
<ID>LongMethod:TsEdgeCasesTestSse.kt:TsEdgeCasesTestSse$@Test @Timeout(30, unit = TimeUnit.SECONDS) fun testComplexConcurrentRequests: Unit</ID>
3837
<ID>MatchingDeclarationName:PromptIntegrationTestSse.kt:SchemaPromptIntegrationTestSse : AbstractPromptIntegrationTest</ID>

integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/AbstractServerFeaturesTest.kt

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
11
package io.modelcontextprotocol.kotlin.sdk.server
22

33
import io.modelcontextprotocol.kotlin.sdk.client.Client
4+
import io.modelcontextprotocol.kotlin.sdk.client.ClientOptions
45
import io.modelcontextprotocol.kotlin.sdk.shared.InMemoryTransport
6+
import io.modelcontextprotocol.kotlin.sdk.types.CallToolResult
7+
import io.modelcontextprotocol.kotlin.sdk.types.ClientCapabilities
8+
import io.modelcontextprotocol.kotlin.sdk.types.GetPromptResult
59
import io.modelcontextprotocol.kotlin.sdk.types.Implementation
10+
import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceResult
611
import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities
12+
import io.modelcontextprotocol.kotlin.sdk.types.TextContent
13+
import io.modelcontextprotocol.kotlin.sdk.types.TextResourceContents
714
import kotlinx.coroutines.launch
815
import kotlinx.coroutines.runBlocking
916
import org.junit.jupiter.api.BeforeEach
@@ -13,8 +20,31 @@ abstract class AbstractServerFeaturesTest {
1320
protected lateinit var server: Server
1421
protected lateinit var client: Client
1522

23+
protected fun addTool(name: String, block: suspend ClientConnection.() -> Unit) {
24+
server.addTool(name, "Test $name") {
25+
block()
26+
CallToolResult(listOf(TextContent("Success")))
27+
}
28+
}
29+
30+
protected fun addPrompt(name: String, block: suspend ClientConnection.() -> Unit) {
31+
server.addPrompt(name, "Test $name") {
32+
block()
33+
GetPromptResult(messages = emptyList())
34+
}
35+
}
36+
37+
protected fun addResource(uri: String, block: suspend ClientConnection.() -> Unit) {
38+
server.addResource(uri, uri, "Test resource $uri") {
39+
block()
40+
ReadResourceResult(contents = listOf(TextResourceContents(text = "content", uri = uri)))
41+
}
42+
}
43+
1644
abstract fun getServerCapabilities(): ServerCapabilities
1745

46+
protected open fun getClientCapabilities(): ClientCapabilities = ClientCapabilities()
47+
1848
protected open fun getServerInstructionsProvider(): (() -> String)? = null
1949

2050
@BeforeEach
@@ -33,6 +63,9 @@ abstract class AbstractServerFeaturesTest {
3363

3464
client = Client(
3565
clientInfo = Implementation(name = "test client", version = "1.0"),
66+
options = ClientOptions(
67+
capabilities = getClientCapabilities(),
68+
),
3669
)
3770

3871
runBlocking {
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
package io.modelcontextprotocol.kotlin.sdk.server
2+
3+
import io.kotest.matchers.collections.shouldBeEmpty
4+
import io.kotest.matchers.collections.shouldHaveSize
5+
import io.kotest.matchers.shouldBe
6+
import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequest
7+
import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequestParams
8+
import io.modelcontextprotocol.kotlin.sdk.types.LoggingLevel
9+
import io.modelcontextprotocol.kotlin.sdk.types.LoggingMessageNotification
10+
import io.modelcontextprotocol.kotlin.sdk.types.LoggingMessageNotificationParams
11+
import io.modelcontextprotocol.kotlin.sdk.types.Method
12+
import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities
13+
import kotlinx.coroutines.CompletableDeferred
14+
import kotlinx.coroutines.runBlocking
15+
import kotlinx.serialization.json.JsonObject
16+
import kotlinx.serialization.json.JsonPrimitive
17+
import org.junit.jupiter.api.Test
18+
import org.junit.jupiter.params.ParameterizedTest
19+
import org.junit.jupiter.params.provider.EnumSource
20+
21+
class ClientConnectionLoggingTest : AbstractServerFeaturesTest() {
22+
23+
override fun getServerCapabilities(): ServerCapabilities = ServerCapabilities(
24+
tools = ServerCapabilities.Tools(listChanged = true),
25+
logging = JsonObject(emptyMap()),
26+
)
27+
28+
@Test
29+
fun `notification should send logging message to client`(): Unit = runBlocking {
30+
val notificationReceived = CompletableDeferred<LoggingMessageNotification>()
31+
client.setNotificationHandler<LoggingMessageNotification>(Method.Defined.NotificationsMessage) {
32+
notificationReceived.complete(it)
33+
CompletableDeferred(Unit)
34+
}
35+
36+
val expectedLevel = LoggingLevel.Info
37+
val expectedData = JsonPrimitive("test-data-sample")
38+
39+
addTool("test-notification") {
40+
notification(
41+
LoggingMessageNotification(
42+
LoggingMessageNotificationParams(
43+
level = expectedLevel,
44+
data = expectedData,
45+
),
46+
),
47+
)
48+
}
49+
50+
client.callTool(CallToolRequest(CallToolRequestParams("test-notification")))
51+
val received = notificationReceived.await()
52+
received.params.level shouldBe expectedLevel
53+
received.params.data shouldBe expectedData
54+
}
55+
56+
@ParameterizedTest
57+
@EnumSource(LoggingLevel::class)
58+
fun `notification should filter logging messages below level`(minLevel: LoggingLevel): Unit = runBlocking {
59+
val receivedMessages = mutableListOf<LoggingMessageNotification>()
60+
client.setNotificationHandler<LoggingMessageNotification>(Method.Defined.NotificationsMessage) {
61+
receivedMessages.add(it)
62+
CompletableDeferred(Unit)
63+
}
64+
65+
client.setLoggingLevel(minLevel)
66+
67+
addTool("test-notification-filtered") {
68+
LoggingLevel.entries.forEach { level ->
69+
notification(
70+
LoggingMessageNotification(
71+
LoggingMessageNotificationParams(
72+
level = level,
73+
data = JsonPrimitive(level.name),
74+
),
75+
),
76+
)
77+
}
78+
}
79+
80+
client.callTool(CallToolRequest(CallToolRequestParams("test-notification-filtered")))
81+
82+
val expectedLevels = LoggingLevel.entries.filter { it >= minLevel }
83+
receivedMessages shouldHaveSize expectedLevels.size
84+
receivedMessages.map { it.params.level } shouldBe expectedLevels
85+
}
86+
87+
@ParameterizedTest
88+
@EnumSource(LoggingLevel::class)
89+
fun `sendLoggingMessage should send message at level`(expectedLevel: LoggingLevel): Unit = runBlocking {
90+
val notificationReceived = CompletableDeferred<LoggingMessageNotification>()
91+
client.setNotificationHandler<LoggingMessageNotification>(Method.Defined.NotificationsMessage) {
92+
notificationReceived.complete(it)
93+
CompletableDeferred(Unit)
94+
}
95+
96+
val expectedData = JsonObject(mapOf("key" to JsonPrimitive("value")))
97+
98+
addTool("test-logging") {
99+
sendLoggingMessage(
100+
LoggingMessageNotification(
101+
LoggingMessageNotificationParams(
102+
level = expectedLevel,
103+
data = expectedData,
104+
),
105+
),
106+
)
107+
}
108+
109+
client.callTool(CallToolRequest(CallToolRequestParams("test-logging")))
110+
val received = notificationReceived.await()
111+
received.params.level shouldBe expectedLevel
112+
received.params.data shouldBe expectedData
113+
}
114+
115+
@ParameterizedTest
116+
@EnumSource(LoggingLevel::class)
117+
fun `sendLoggingMessage should filter messages below level`(minLevel: LoggingLevel): Unit = runBlocking {
118+
val receivedMessages = mutableListOf<LoggingMessageNotification>()
119+
client.setNotificationHandler<LoggingMessageNotification>(Method.Defined.NotificationsMessage) {
120+
receivedMessages.add(it)
121+
CompletableDeferred(Unit)
122+
}
123+
124+
client.setLoggingLevel(minLevel)
125+
126+
addTool("test-logging-level") {
127+
LoggingLevel.entries.forEach { level ->
128+
sendLoggingMessage(
129+
LoggingMessageNotification(
130+
LoggingMessageNotificationParams(
131+
level = level,
132+
data = JsonPrimitive(level.name),
133+
),
134+
),
135+
)
136+
}
137+
}
138+
139+
client.callTool(CallToolRequest(CallToolRequestParams("test-logging-level")))
140+
141+
val expectedLevels = LoggingLevel.entries.filter { it >= minLevel }
142+
receivedMessages shouldHaveSize expectedLevels.size
143+
receivedMessages.map { it.params.level } shouldBe expectedLevels
144+
}
145+
146+
@Test
147+
fun `sendLoggingMessage should send no messages when level is set to highest`(): Unit = runBlocking {
148+
val receivedMessages = mutableListOf<LoggingMessageNotification>()
149+
client.setNotificationHandler<LoggingMessageNotification>(Method.Defined.NotificationsMessage) {
150+
receivedMessages.add(it)
151+
CompletableDeferred(Unit)
152+
}
153+
154+
client.setLoggingLevel(LoggingLevel.Emergency)
155+
156+
addTool("test-logging-highest") {
157+
LoggingLevel.entries.dropLast(1).forEach { level ->
158+
sendLoggingMessage(
159+
LoggingMessageNotification(
160+
LoggingMessageNotificationParams(
161+
level = level,
162+
data = JsonPrimitive(level.name),
163+
),
164+
),
165+
)
166+
}
167+
}
168+
169+
client.callTool(CallToolRequest(CallToolRequestParams("test-logging-highest")))
170+
receivedMessages.shouldBeEmpty()
171+
}
172+
}

0 commit comments

Comments
 (0)