Skip to content

Commit 23e9bc1

Browse files
jskjw157kpavlov
authored andcommitted
test(streamable-http): add missing integration tests for pagination, bad request, and logging
- Add cursor-based pagination tests for Prompts, Resources, and Tools with full page traversal until nextCursor is null - Add invalid cursor tests using assertFailsWith (no nested runBlocking) - Add LoggingIntegrationTestStreamableHttp for setLevel and notification tests - Use LoggingLevel.entries instead of values() for allocation-free iteration
1 parent 55afb63 commit 23e9bc1

4 files changed

Lines changed: 295 additions & 0 deletions

File tree

integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractPromptIntegrationTest.kt

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@ import io.kotest.matchers.string.shouldContain
66
import io.modelcontextprotocol.kotlin.sdk.types.GetPromptRequest
77
import io.modelcontextprotocol.kotlin.sdk.types.GetPromptRequestParams
88
import io.modelcontextprotocol.kotlin.sdk.types.GetPromptResult
9+
import io.modelcontextprotocol.kotlin.sdk.types.ListPromptsRequest
10+
import io.modelcontextprotocol.kotlin.sdk.types.ListPromptsResult
911
import io.modelcontextprotocol.kotlin.sdk.types.McpException
12+
import io.modelcontextprotocol.kotlin.sdk.types.Method
13+
import io.modelcontextprotocol.kotlin.sdk.types.PaginatedRequestParams
1014
import io.modelcontextprotocol.kotlin.sdk.types.PromptArgument
1115
import io.modelcontextprotocol.kotlin.sdk.types.PromptMessage
1216
import io.modelcontextprotocol.kotlin.sdk.types.Role
@@ -697,4 +701,55 @@ abstract class AbstractPromptIntegrationTest : KotlinTestBase() {
697701
exception.message shouldBe expectedMessage
698702
}
699703
}
704+
705+
@Test
706+
fun testListPromptsPagination() = runBlocking(Dispatchers.IO) {
707+
val pagePrefix = "paginated-prompt-"
708+
(0 until 5).forEach { i ->
709+
val name = "$pagePrefix$i"
710+
server.addPrompt(name = name, description = "desc", arguments = listOf()) { _ ->
711+
GetPromptResult(description = "desc", messages = listOf(PromptMessage(role = Role.Assistant, content = TextContent(text = name))))
712+
}
713+
}
714+
715+
server.sessions.forEach { (_, session) ->
716+
session.setRequestHandler<ListPromptsRequest>(Method.Defined.PromptsList) { request, _ ->
717+
val all = server.prompts.values.map { it.prompt }
718+
val cursor = request.cursor?.toIntOrNull() ?: 0
719+
val pageSize = 2
720+
val page = all.drop(cursor).take(pageSize)
721+
val next = if (cursor + page.size < all.size) (cursor + page.size).toString() else null
722+
ListPromptsResult(prompts = page, nextCursor = next)
723+
}
724+
}
725+
726+
val allPrompts = mutableListOf<io.modelcontextprotocol.kotlin.sdk.types.Prompt>()
727+
var currentCursor: String? = null
728+
do {
729+
val request = if (currentCursor == null) ListPromptsRequest() else ListPromptsRequest(PaginatedRequestParams(cursor = currentCursor))
730+
val response = client.listPrompts(request)
731+
allPrompts.addAll(response.prompts)
732+
currentCursor = response.nextCursor
733+
} while (currentCursor != null)
734+
735+
assertTrue(allPrompts.any { it.name.startsWith(pagePrefix) })
736+
}
737+
738+
@Test
739+
fun testListPromptsInvalidCursor() = runBlocking(Dispatchers.IO) {
740+
server.sessions.forEach { (_, session) ->
741+
session.setRequestHandler<ListPromptsRequest>(Method.Defined.PromptsList) { request, _ ->
742+
val cursor = request.cursor?.toIntOrNull() ?: throw IllegalArgumentException("Invalid cursor")
743+
val all = server.prompts.values.map { it.prompt }
744+
val page = all.drop(cursor).take(2)
745+
ListPromptsResult(prompts = page, nextCursor = null)
746+
}
747+
}
748+
749+
val exception = kotlin.test.assertFailsWith<McpException> {
750+
client.listPrompts(ListPromptsRequest(PaginatedRequestParams(cursor = "not-a-number")))
751+
}
752+
753+
kotlin.test.assertEquals(RPCError.ErrorCode.INTERNAL_ERROR, exception.code)
754+
}
700755
}

integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
package io.modelcontextprotocol.kotlin.sdk.integration.kotlin
22

33
import io.modelcontextprotocol.kotlin.sdk.types.BlobResourceContents
4+
import io.modelcontextprotocol.kotlin.sdk.types.ListResourcesRequest
5+
import io.modelcontextprotocol.kotlin.sdk.types.ListResourcesResult
46
import io.modelcontextprotocol.kotlin.sdk.types.McpException
7+
import io.modelcontextprotocol.kotlin.sdk.types.Method
8+
import io.modelcontextprotocol.kotlin.sdk.types.PaginatedRequestParams
59
import io.modelcontextprotocol.kotlin.sdk.types.RPCError
610
import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequest
711
import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequestParams
@@ -309,4 +313,61 @@ abstract class AbstractResourceIntegrationTest : KotlinTestBase() {
309313
assertTrue(result.contents.isNotEmpty(), "Result contents should not be empty")
310314
}
311315
}
316+
317+
@Test
318+
fun testListResourcesPagination() = runBlocking(Dispatchers.IO) {
319+
val prefix = "paginated-resource-"
320+
(0 until 6).forEach { i ->
321+
val uri = "test://$prefix$i.txt"
322+
server.addResource(uri = uri, name = "Name-$i", description = "desc", mimeType = "text/plain") { request ->
323+
ReadResourceResult(contents = listOf(TextResourceContents(text = uri, uri = request.params.uri, mimeType = "text/plain")))
324+
}
325+
}
326+
327+
server.sessions.forEach { (_, session) ->
328+
session.setRequestHandler<ListResourcesRequest>(Method.Defined.ResourcesList) { request, _ ->
329+
val all = server.resources.values.map { it.resource }
330+
val cursor = request.cursor?.toIntOrNull() ?: 0
331+
val pageSize = 3
332+
val page = all.drop(cursor).take(pageSize)
333+
val next = if (cursor + page.size < all.size) (cursor + page.size).toString() else null
334+
ListResourcesResult(resources = page, nextCursor = next)
335+
}
336+
}
337+
338+
val combinedUris = mutableListOf<String>()
339+
var currentCursor: String? = null
340+
341+
do {
342+
val request = if (currentCursor == null) {
343+
ListResourcesRequest()
344+
} else {
345+
ListResourcesRequest(PaginatedRequestParams(cursor = currentCursor))
346+
}
347+
348+
val response = client.listResources(request)
349+
combinedUris += response.resources.map { it.uri }
350+
currentCursor = response.nextCursor
351+
} while (currentCursor != null)
352+
353+
assertTrue(combinedUris.any { it.contains(prefix) })
354+
}
355+
356+
@Test
357+
fun testListResourcesInvalidCursor() = runBlocking(Dispatchers.IO) {
358+
server.sessions.forEach { (_, session) ->
359+
session.setRequestHandler<ListResourcesRequest>(Method.Defined.ResourcesList) { request, _ ->
360+
val cursor = request.cursor?.toIntOrNull() ?: throw IllegalArgumentException("Invalid cursor")
361+
val all = server.resources.values.map { it.resource }
362+
val page = all.drop(cursor).take(2)
363+
ListResourcesResult(resources = page, nextCursor = null)
364+
}
365+
}
366+
367+
val exception = kotlin.test.assertFailsWith<McpException> {
368+
client.listResources(ListResourcesRequest(PaginatedRequestParams(cursor = "bad")))
369+
}
370+
371+
kotlin.test.assertEquals(RPCError.ErrorCode.INTERNAL_ERROR, exception.code)
372+
}
312373
}

integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractToolIntegrationTest.kt

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequestParams
66
import io.modelcontextprotocol.kotlin.sdk.types.CallToolResult
77
import io.modelcontextprotocol.kotlin.sdk.types.ContentBlock
88
import io.modelcontextprotocol.kotlin.sdk.types.ImageContent
9+
import io.modelcontextprotocol.kotlin.sdk.types.ListToolsRequest
10+
import io.modelcontextprotocol.kotlin.sdk.types.ListToolsResult
11+
import io.modelcontextprotocol.kotlin.sdk.types.Method
12+
import io.modelcontextprotocol.kotlin.sdk.types.PaginatedRequestParams
913
import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities
1014
import io.modelcontextprotocol.kotlin.sdk.types.TextContent
1115
import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema
@@ -791,4 +795,61 @@ abstract class AbstractToolIntegrationTest : KotlinTestBase() {
791795
"Error message should indicate the tool was not found",
792796
)
793797
}
798+
799+
@Test
800+
fun testListToolsPagination() = runBlocking(Dispatchers.IO) {
801+
val prefix = "paginated-tool-"
802+
(0 until 5).forEach { i ->
803+
val name = "$prefix$i"
804+
server.addTool(name = name, description = "desc") { request ->
805+
CallToolResult(content = listOf(TextContent(text = name)), structuredContent = buildJsonObject { put("name", name) })
806+
}
807+
}
808+
809+
server.sessions.forEach { (_, session) ->
810+
session.setRequestHandler<ListToolsRequest>(Method.Defined.ToolsList) { request, _ ->
811+
val all = server.tools.values.map { it.tool }
812+
val cursor = request.cursor?.toIntOrNull() ?: 0
813+
val pageSize = 2
814+
val page = all.drop(cursor).take(pageSize)
815+
val next = if (cursor + page.size < all.size) (cursor + page.size).toString() else null
816+
ListToolsResult(tools = page, nextCursor = next)
817+
}
818+
}
819+
820+
val combinedNames = mutableListOf<String>()
821+
var currentCursor: String? = null
822+
823+
do {
824+
val request = if (currentCursor == null) {
825+
ListToolsRequest()
826+
} else {
827+
ListToolsRequest(PaginatedRequestParams(cursor = currentCursor))
828+
}
829+
830+
val response = client.listTools(request)
831+
combinedNames += response.tools.map { it.name }
832+
currentCursor = response.nextCursor
833+
} while (currentCursor != null)
834+
835+
assertTrue(combinedNames.any { it.startsWith(prefix) })
836+
}
837+
838+
@Test
839+
fun testListToolsInvalidCursor() = runBlocking(Dispatchers.IO) {
840+
server.sessions.forEach { (_, session) ->
841+
session.setRequestHandler<ListToolsRequest>(Method.Defined.ToolsList) { request, _ ->
842+
val cursor = request.cursor?.toIntOrNull() ?: throw IllegalArgumentException("Invalid cursor")
843+
val all = server.tools.values.map { it.tool }
844+
val page = all.drop(cursor).take(2)
845+
ListToolsResult(tools = page)
846+
}
847+
}
848+
849+
val exception = kotlin.test.assertFailsWith<io.modelcontextprotocol.kotlin.sdk.types.McpException> {
850+
client.listTools(ListToolsRequest(PaginatedRequestParams(cursor = "bad")))
851+
}
852+
853+
kotlin.test.assertEquals(io.modelcontextprotocol.kotlin.sdk.types.RPCError.ErrorCode.INTERNAL_ERROR, exception.code)
854+
}
794855
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
package io.modelcontextprotocol.kotlin.sdk.integration.kotlin.streamablehttp
2+
3+
import io.modelcontextprotocol.kotlin.sdk.integration.kotlin.KotlinTestBase
4+
import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequest
5+
import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequestParams
6+
import io.modelcontextprotocol.kotlin.sdk.types.LoggingLevel
7+
import io.modelcontextprotocol.kotlin.sdk.types.LoggingMessageNotification
8+
import io.modelcontextprotocol.kotlin.sdk.types.LoggingMessageNotificationParams
9+
import io.modelcontextprotocol.kotlin.sdk.types.Method
10+
import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities
11+
import kotlinx.coroutines.CompletableDeferred
12+
import kotlinx.coroutines.runBlocking
13+
import kotlinx.coroutines.delay
14+
import kotlinx.coroutines.withTimeout
15+
import kotlinx.serialization.json.JsonObject
16+
import kotlinx.serialization.json.JsonPrimitive
17+
import org.junit.jupiter.api.Test
18+
19+
class LoggingIntegrationTestStreamableHttp : KotlinTestBase() {
20+
21+
override val transportKind = TransportKind.STREAMABLE_HTTP
22+
23+
override fun configureServerCapabilities(): ServerCapabilities = ServerCapabilities(
24+
tools = ServerCapabilities.Tools(listChanged = true),
25+
logging = JsonObject(emptyMap()),
26+
)
27+
28+
override fun configureServer() {
29+
server.addTool(name = "test-notification", description = "test") { request ->
30+
notification(
31+
LoggingMessageNotification(
32+
LoggingMessageNotificationParams(
33+
level = LoggingLevel.Info,
34+
data = JsonPrimitive("test-data-sample"),
35+
),
36+
),
37+
)
38+
io.modelcontextprotocol.kotlin.sdk.types.CallToolResult(listOf(io.modelcontextprotocol.kotlin.sdk.types.TextContent("ok")))
39+
}
40+
41+
server.addTool(name = "test-logging", description = "test") { request ->
42+
sendLoggingMessage(
43+
LoggingMessageNotification(
44+
LoggingMessageNotificationParams(
45+
level = LoggingLevel.Info,
46+
data = JsonObject(mapOf("key" to JsonPrimitive("value"))),
47+
),
48+
),
49+
)
50+
io.modelcontextprotocol.kotlin.sdk.types.CallToolResult(listOf(io.modelcontextprotocol.kotlin.sdk.types.TextContent("ok")))
51+
}
52+
53+
server.addTool(name = "test-logging-level", description = "test") { request ->
54+
LoggingLevel.entries.forEach { level ->
55+
sendLoggingMessage(
56+
LoggingMessageNotification(
57+
LoggingMessageNotificationParams(
58+
level = level,
59+
data = JsonPrimitive(level.name),
60+
),
61+
),
62+
)
63+
}
64+
io.modelcontextprotocol.kotlin.sdk.types.CallToolResult(listOf(io.modelcontextprotocol.kotlin.sdk.types.TextContent("ok")))
65+
}
66+
}
67+
68+
@Test
69+
fun `notification should send logging message to client`() = runBlocking {
70+
val notificationReceived = CompletableDeferred<LoggingMessageNotification>()
71+
client.setNotificationHandler<LoggingMessageNotification>(Method.Defined.NotificationsMessage) {
72+
notificationReceived.complete(it)
73+
CompletableDeferred(Unit)
74+
}
75+
76+
client.callTool(CallToolRequest(CallToolRequestParams("test-notification")))
77+
val received = notificationReceived.await()
78+
kotlin.test.assertEquals(LoggingLevel.Info, received.params.level)
79+
kotlin.test.assertEquals(JsonPrimitive("test-data-sample"), received.params.data)
80+
}
81+
82+
@Test
83+
fun `sendLoggingMessage should send message at level`() = runBlocking {
84+
val notificationReceived = CompletableDeferred<LoggingMessageNotification>()
85+
client.setNotificationHandler<LoggingMessageNotification>(Method.Defined.NotificationsMessage) {
86+
notificationReceived.complete(it)
87+
CompletableDeferred(Unit)
88+
}
89+
90+
client.callTool(CallToolRequest(CallToolRequestParams("test-logging")))
91+
val received = notificationReceived.await()
92+
kotlin.test.assertEquals(LoggingLevel.Info, received.params.level)
93+
kotlin.test.assertEquals(JsonObject(mapOf("key" to JsonPrimitive("value"))), received.params.data)
94+
}
95+
96+
@Test
97+
fun `sendLoggingMessage should filter messages below level`() = runBlocking {
98+
val receivedMessages = mutableListOf<LoggingMessageNotification>()
99+
client.setNotificationHandler<LoggingMessageNotification>(Method.Defined.NotificationsMessage) {
100+
receivedMessages.add(it)
101+
CompletableDeferred(Unit)
102+
}
103+
104+
client.setLoggingLevel(LoggingLevel.Warning)
105+
106+
client.callTool(CallToolRequest(CallToolRequestParams("test-logging-level")))
107+
108+
val expectedLevels = LoggingLevel.entries.filter { it >= LoggingLevel.Warning }
109+
// wait for expected notifications to arrive (transport may deliver asynchronously)
110+
withTimeout(2000) {
111+
while (receivedMessages.size < expectedLevels.size) {
112+
delay(10)
113+
}
114+
}
115+
kotlin.test.assertEquals(expectedLevels.size, receivedMessages.size)
116+
kotlin.test.assertEquals(expectedLevels.toList(), receivedMessages.map { it.params.level })
117+
}
118+
}

0 commit comments

Comments
 (0)