Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@ import io.kotest.matchers.string.shouldContain
import io.modelcontextprotocol.kotlin.sdk.types.GetPromptRequest
import io.modelcontextprotocol.kotlin.sdk.types.GetPromptRequestParams
import io.modelcontextprotocol.kotlin.sdk.types.GetPromptResult
import io.modelcontextprotocol.kotlin.sdk.types.ListPromptsRequest
import io.modelcontextprotocol.kotlin.sdk.types.ListPromptsResult
import io.modelcontextprotocol.kotlin.sdk.types.McpException
import io.modelcontextprotocol.kotlin.sdk.types.Method
import io.modelcontextprotocol.kotlin.sdk.types.RPCError
import io.modelcontextprotocol.kotlin.sdk.types.PaginatedRequestParams
import io.modelcontextprotocol.kotlin.sdk.types.Prompt
import io.modelcontextprotocol.kotlin.sdk.types.PromptArgument
import io.modelcontextprotocol.kotlin.sdk.types.PromptMessage
import io.modelcontextprotocol.kotlin.sdk.types.Role
Expand All @@ -19,6 +25,7 @@ import kotlinx.coroutines.test.runTest
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertNotNull
import kotlin.test.assertTrue

Expand Down Expand Up @@ -697,4 +704,56 @@ abstract class AbstractPromptIntegrationTest : KotlinTestBase() {
exception.message shouldBe expectedMessage
}
}

@Test
fun testListPromptsPagination() = runBlocking(Dispatchers.IO) {
val pagePrefix = "paginated-prompt-"
(0 until 5).forEach { i ->
val name = "$pagePrefix$i"
server.addPrompt(name = name, description = "desc", arguments = listOf()) { _ ->
GetPromptResult(description = "desc", messages = listOf(PromptMessage(role = Role.Assistant, content = TextContent(text = name))))
}
}

server.sessions.forEach { (_, session) ->
session.setRequestHandler<ListPromptsRequest>(Method.Defined.PromptsList) { request, _ ->
val all = server.prompts.values.map { it.prompt }
val cursor = request.cursor?.toIntOrNull() ?: 0
val pageSize = 2
val page = all.drop(cursor).take(pageSize)
val next = if (cursor + page.size < all.size) (cursor + page.size).toString() else null
ListPromptsResult(prompts = page, nextCursor = next)
}
Comment on lines +717 to +725
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not entirely clear to me what exactly is being verified in such tests, since the test effectively reproduces the same logic that it is supposed to validate

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point — the original test was just re-running the same drop/take logic the client was supposed to be exercising. Refactored to a hardcoded when(cursor) handler that returns three pre-built pages with explicit cursor-2 / cursor-3 tokens, and the test now records every cursor the client sends and asserts the full sequence ([null, cursor-2, cursor-3]) plus the exact accumulated list. The server side is no longer the logic under test.

}

val allPrompts = mutableListOf<Prompt>()
var currentCursor: String? = null
do {
val request = if (currentCursor == null) ListPromptsRequest() else ListPromptsRequest(PaginatedRequestParams(cursor = currentCursor))
val response = client.listPrompts(request)
allPrompts.addAll(response.prompts)
currentCursor = response.nextCursor
} while (currentCursor != null)

val paginatedPrompts = allPrompts.filter { it.name.startsWith(pagePrefix) }
assertEquals(5, paginatedPrompts.size, "Should have collected all 5 paginated prompts")
Comment thread
devcrocod marked this conversation as resolved.
Outdated
}

@Test
fun testListPromptsInvalidCursor() = runBlocking(Dispatchers.IO) {
server.sessions.forEach { (_, session) ->
session.setRequestHandler<ListPromptsRequest>(Method.Defined.PromptsList) { request, _ ->
val cursor = request.cursor?.toIntOrNull() ?: throw IllegalArgumentException("Invalid cursor")
val all = server.prompts.values.map { it.prompt }
val page = all.drop(cursor).take(2)
ListPromptsResult(prompts = page, nextCursor = null)
}
}

val exception = assertFailsWith<McpException> {
client.listPrompts(ListPromptsRequest(PaginatedRequestParams(cursor = "not-a-number")))
}

assertEquals(RPCError.ErrorCode.INTERNAL_ERROR, exception.code)
Comment thread
devcrocod marked this conversation as resolved.
Outdated
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package io.modelcontextprotocol.kotlin.sdk.integration.kotlin

import io.modelcontextprotocol.kotlin.sdk.types.BlobResourceContents
import io.modelcontextprotocol.kotlin.sdk.types.ListResourcesRequest
import io.modelcontextprotocol.kotlin.sdk.types.ListResourcesResult
import io.modelcontextprotocol.kotlin.sdk.types.McpException
import io.modelcontextprotocol.kotlin.sdk.types.Method
import io.modelcontextprotocol.kotlin.sdk.types.PaginatedRequestParams
import io.modelcontextprotocol.kotlin.sdk.types.RPCError
import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequest
import io.modelcontextprotocol.kotlin.sdk.types.ReadResourceRequestParams
Expand All @@ -20,6 +24,7 @@ import org.junit.jupiter.api.Test
import org.junit.jupiter.api.assertThrows
import java.util.concurrent.atomic.AtomicBoolean
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertNotNull
import kotlin.test.assertTrue

Expand Down Expand Up @@ -309,4 +314,62 @@ abstract class AbstractResourceIntegrationTest : KotlinTestBase() {
assertTrue(result.contents.isNotEmpty(), "Result contents should not be empty")
}
}

@Test
fun testListResourcesPagination() = runBlocking(Dispatchers.IO) {
val prefix = "paginated-resource-"
(0 until 6).forEach { i ->
val uri = "test://$prefix$i.txt"
server.addResource(uri = uri, name = "Name-$i", description = "desc", mimeType = "text/plain") { request ->
ReadResourceResult(contents = listOf(TextResourceContents(text = uri, uri = request.params.uri, mimeType = "text/plain")))
}
}

server.sessions.forEach { (_, session) ->
session.setRequestHandler<ListResourcesRequest>(Method.Defined.ResourcesList) { request, _ ->
val all = server.resources.values.map { it.resource }
val cursor = request.cursor?.toIntOrNull() ?: 0
val pageSize = 3
val page = all.drop(cursor).take(pageSize)
val next = if (cursor + page.size < all.size) (cursor + page.size).toString() else null
ListResourcesResult(resources = page, nextCursor = next)
}
}

val combinedUris = mutableListOf<String>()
var currentCursor: String? = null

do {
val request = if (currentCursor == null) {
ListResourcesRequest()
} else {
ListResourcesRequest(PaginatedRequestParams(cursor = currentCursor))
}

val response = client.listResources(request)
combinedUris += response.resources.map { it.uri }
currentCursor = response.nextCursor
} while (currentCursor != null)

val paginatedResources = combinedUris.filter { it.contains(prefix) }
assertEquals(6, paginatedResources.size, "Should have collected all 6 paginated resources")
}

@Test
fun testListResourcesInvalidCursor() = runBlocking(Dispatchers.IO) {
server.sessions.forEach { (_, session) ->
session.setRequestHandler<ListResourcesRequest>(Method.Defined.ResourcesList) { request, _ ->
val cursor = request.cursor?.toIntOrNull() ?: throw IllegalArgumentException("Invalid cursor")
val all = server.resources.values.map { it.resource }
val page = all.drop(cursor).take(2)
ListResourcesResult(resources = page, nextCursor = null)
}
}

val exception = assertFailsWith<McpException> {
client.listResources(ListResourcesRequest(PaginatedRequestParams(cursor = "bad")))
}

assertEquals(RPCError.ErrorCode.INTERNAL_ERROR, exception.code)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@ import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequestParams
import io.modelcontextprotocol.kotlin.sdk.types.CallToolResult
import io.modelcontextprotocol.kotlin.sdk.types.ContentBlock
import io.modelcontextprotocol.kotlin.sdk.types.ImageContent
import io.modelcontextprotocol.kotlin.sdk.types.ListToolsRequest
import io.modelcontextprotocol.kotlin.sdk.types.ListToolsResult
import io.modelcontextprotocol.kotlin.sdk.types.McpException
import io.modelcontextprotocol.kotlin.sdk.types.Method
import io.modelcontextprotocol.kotlin.sdk.types.PaginatedRequestParams
import io.modelcontextprotocol.kotlin.sdk.types.RPCError
import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities
import io.modelcontextprotocol.kotlin.sdk.types.TextContent
import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema
Expand All @@ -25,6 +31,7 @@ import java.text.DecimalFormat
import java.text.DecimalFormatSymbols
import java.util.Locale
import kotlin.test.assertEquals
import kotlin.test.assertFailsWith
import kotlin.test.assertNotNull
import kotlin.test.assertTrue

Expand Down Expand Up @@ -791,4 +798,62 @@ abstract class AbstractToolIntegrationTest : KotlinTestBase() {
"Error message should indicate the tool was not found",
)
}

@Test
fun testListToolsPagination() = runBlocking(Dispatchers.IO) {
val prefix = "paginated-tool-"
(0 until 5).forEach { i ->
val name = "$prefix$i"
server.addTool(name = name, description = "desc") { request ->
CallToolResult(content = listOf(TextContent(text = name)), structuredContent = buildJsonObject { put("name", name) })
}
}

server.sessions.forEach { (_, session) ->
session.setRequestHandler<ListToolsRequest>(Method.Defined.ToolsList) { request, _ ->
val all = server.tools.values.map { it.tool }
val cursor = request.cursor?.toIntOrNull() ?: 0
val pageSize = 2
val page = all.drop(cursor).take(pageSize)
val next = if (cursor + page.size < all.size) (cursor + page.size).toString() else null
ListToolsResult(tools = page, nextCursor = next)
}
}

val combinedNames = mutableListOf<String>()
var currentCursor: String? = null

do {
val request = if (currentCursor == null) {
ListToolsRequest()
} else {
ListToolsRequest(PaginatedRequestParams(cursor = currentCursor))
}

val response = client.listTools(request)
combinedNames += response.tools.map { it.name }
currentCursor = response.nextCursor
} while (currentCursor != null)

val paginatedTools = combinedNames.filter { it.startsWith(prefix) }
assertEquals(5, paginatedTools.size, "Should have collected all 5 paginated tools")
}

@Test
fun testListToolsInvalidCursor() = runBlocking(Dispatchers.IO) {
server.sessions.forEach { (_, session) ->
session.setRequestHandler<ListToolsRequest>(Method.Defined.ToolsList) { request, _ ->
val cursor = request.cursor?.toIntOrNull() ?: throw IllegalArgumentException("Invalid cursor")
val all = server.tools.values.map { it.tool }
val page = all.drop(cursor).take(2)
ListToolsResult(tools = page)
}
}

val exception = assertFailsWith<McpException> {
client.listTools(ListToolsRequest(PaginatedRequestParams(cursor = "bad")))
}

assertEquals(RPCError.ErrorCode.INTERNAL_ERROR, exception.code)
}
}
Loading