Skip to content

Commit d839b42

Browse files
committed
refactor(tests): enhance readability and type safety in integration tests
- Replaced `if` checks with `requireNotNull` for argument validation. - Updated formatting for response generation methods to improve clarity. - Sorted tools by name in `listTools` handler for consistent pagination testing. - Replaced `kotlin.test` prefix with direct imports for assertions.
1 parent 23e9bc1 commit d839b42

4 files changed

Lines changed: 47 additions & 20 deletions

File tree

integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractPromptIntegrationTest.kt

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.Method
1313
import io.modelcontextprotocol.kotlin.sdk.types.PaginatedRequestParams
1414
import io.modelcontextprotocol.kotlin.sdk.types.PromptArgument
1515
import io.modelcontextprotocol.kotlin.sdk.types.PromptMessage
16+
import io.modelcontextprotocol.kotlin.sdk.types.RPCError
1617
import io.modelcontextprotocol.kotlin.sdk.types.Role
1718
import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities
1819
import io.modelcontextprotocol.kotlin.sdk.types.TextContent
@@ -23,6 +24,7 @@ import kotlinx.coroutines.test.runTest
2324
import org.junit.jupiter.api.Test
2425
import org.junit.jupiter.api.assertThrows
2526
import kotlin.test.assertEquals
27+
import kotlin.test.assertFailsWith
2628
import kotlin.test.assertNotNull
2729
import kotlin.test.assertTrue
2830

@@ -161,8 +163,8 @@ abstract class AbstractPromptIntegrationTest : KotlinTestBase() {
161163
// validate required arguments
162164
val requiredArgs = listOf("arg1", "arg2", "arg3")
163165
for (argName in requiredArgs) {
164-
if (request.params.arguments?.get(argName) == null) {
165-
throw IllegalArgumentException("Missing required argument: $argName")
166+
requireNotNull(request.params.arguments?.get(argName)) {
167+
"Missing required argument: $argName"
166168
}
167169
}
168170

@@ -708,7 +710,10 @@ abstract class AbstractPromptIntegrationTest : KotlinTestBase() {
708710
(0 until 5).forEach { i ->
709711
val name = "$pagePrefix$i"
710712
server.addPrompt(name = name, description = "desc", arguments = listOf()) { _ ->
711-
GetPromptResult(description = "desc", messages = listOf(PromptMessage(role = Role.Assistant, content = TextContent(text = name))))
713+
GetPromptResult(
714+
description = "desc",
715+
messages = listOf(PromptMessage(role = Role.Assistant, content = TextContent(text = name))),
716+
)
712717
}
713718
}
714719

@@ -726,7 +731,11 @@ abstract class AbstractPromptIntegrationTest : KotlinTestBase() {
726731
val allPrompts = mutableListOf<io.modelcontextprotocol.kotlin.sdk.types.Prompt>()
727732
var currentCursor: String? = null
728733
do {
729-
val request = if (currentCursor == null) ListPromptsRequest() else ListPromptsRequest(PaginatedRequestParams(cursor = currentCursor))
734+
val request = if (currentCursor == null) {
735+
ListPromptsRequest()
736+
} else {
737+
ListPromptsRequest(PaginatedRequestParams(cursor = currentCursor))
738+
}
730739
val response = client.listPrompts(request)
731740
allPrompts.addAll(response.prompts)
732741
currentCursor = response.nextCursor
@@ -739,17 +748,17 @@ abstract class AbstractPromptIntegrationTest : KotlinTestBase() {
739748
fun testListPromptsInvalidCursor() = runBlocking(Dispatchers.IO) {
740749
server.sessions.forEach { (_, session) ->
741750
session.setRequestHandler<ListPromptsRequest>(Method.Defined.PromptsList) { request, _ ->
742-
val cursor = request.cursor?.toIntOrNull() ?: throw IllegalArgumentException("Invalid cursor")
751+
val cursor = requireNotNull(request.cursor?.toIntOrNull()) { "Invalid cursor" }
743752
val all = server.prompts.values.map { it.prompt }
744753
val page = all.drop(cursor).take(2)
745754
ListPromptsResult(prompts = page, nextCursor = null)
746755
}
747756
}
748757

749-
val exception = kotlin.test.assertFailsWith<McpException> {
758+
val exception = assertFailsWith<McpException> {
750759
client.listPrompts(ListPromptsRequest(PaginatedRequestParams(cursor = "not-a-number")))
751760
}
752761

753-
kotlin.test.assertEquals(RPCError.ErrorCode.INTERNAL_ERROR, exception.code)
762+
assertEquals(RPCError.ErrorCode.INTERNAL_ERROR, exception.code)
754763
}
755764
}

integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractResourceIntegrationTest.kt

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,15 @@ abstract class AbstractResourceIntegrationTest : KotlinTestBase() {
320320
(0 until 6).forEach { i ->
321321
val uri = "test://$prefix$i.txt"
322322
server.addResource(uri = uri, name = "Name-$i", description = "desc", mimeType = "text/plain") { request ->
323-
ReadResourceResult(contents = listOf(TextResourceContents(text = uri, uri = request.params.uri, mimeType = "text/plain")))
323+
ReadResourceResult(
324+
contents = listOf(
325+
TextResourceContents(
326+
text = uri,
327+
uri = request.params.uri,
328+
mimeType = "text/plain",
329+
),
330+
),
331+
)
324332
}
325333
}
326334

@@ -357,7 +365,7 @@ abstract class AbstractResourceIntegrationTest : KotlinTestBase() {
357365
fun testListResourcesInvalidCursor() = runBlocking(Dispatchers.IO) {
358366
server.sessions.forEach { (_, session) ->
359367
session.setRequestHandler<ListResourcesRequest>(Method.Defined.ResourcesList) { request, _ ->
360-
val cursor = request.cursor?.toIntOrNull() ?: throw IllegalArgumentException("Invalid cursor")
368+
val cursor = requireNotNull(request.cursor?.toIntOrNull()) { "Invalid cursor" }
361369
val all = server.resources.values.map { it.resource }
362370
val page = all.drop(cursor).take(2)
363371
ListResourcesResult(resources = page, nextCursor = null)
@@ -368,6 +376,6 @@ abstract class AbstractResourceIntegrationTest : KotlinTestBase() {
368376
client.listResources(ListResourcesRequest(PaginatedRequestParams(cursor = "bad")))
369377
}
370378

371-
kotlin.test.assertEquals(RPCError.ErrorCode.INTERNAL_ERROR, exception.code)
379+
assertEquals(RPCError.ErrorCode.INTERNAL_ERROR, exception.code)
372380
}
373381
}

integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/AbstractToolIntegrationTest.kt

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@ import io.modelcontextprotocol.kotlin.sdk.types.ContentBlock
88
import io.modelcontextprotocol.kotlin.sdk.types.ImageContent
99
import io.modelcontextprotocol.kotlin.sdk.types.ListToolsRequest
1010
import io.modelcontextprotocol.kotlin.sdk.types.ListToolsResult
11+
import io.modelcontextprotocol.kotlin.sdk.types.McpException
1112
import io.modelcontextprotocol.kotlin.sdk.types.Method
1213
import io.modelcontextprotocol.kotlin.sdk.types.PaginatedRequestParams
14+
import io.modelcontextprotocol.kotlin.sdk.types.RPCError
1315
import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities
1416
import io.modelcontextprotocol.kotlin.sdk.types.TextContent
1517
import io.modelcontextprotocol.kotlin.sdk.types.ToolSchema
@@ -802,13 +804,16 @@ abstract class AbstractToolIntegrationTest : KotlinTestBase() {
802804
(0 until 5).forEach { i ->
803805
val name = "$prefix$i"
804806
server.addTool(name = name, description = "desc") { request ->
805-
CallToolResult(content = listOf(TextContent(text = name)), structuredContent = buildJsonObject { put("name", name) })
807+
CallToolResult(
808+
content = listOf(TextContent(text = name)),
809+
structuredContent = buildJsonObject { put("name", name) },
810+
)
806811
}
807812
}
808813

809814
server.sessions.forEach { (_, session) ->
810815
session.setRequestHandler<ListToolsRequest>(Method.Defined.ToolsList) { request, _ ->
811-
val all = server.tools.values.map { it.tool }
816+
val all = server.tools.values.map { it.tool }.sortedBy { it.name }
812817
val cursor = request.cursor?.toIntOrNull() ?: 0
813818
val pageSize = 2
814819
val page = all.drop(cursor).take(pageSize)
@@ -832,24 +837,27 @@ abstract class AbstractToolIntegrationTest : KotlinTestBase() {
832837
currentCursor = response.nextCursor
833838
} while (currentCursor != null)
834839

835-
assertTrue(combinedNames.any { it.startsWith(prefix) })
840+
val paginatedNames = combinedNames.filter { it.startsWith(prefix) }
841+
assertEquals(5, paginatedNames.size, "All 5 paginated tools should appear")
842+
assertEquals(combinedNames.size, combinedNames.distinct().size, "No duplicate tools across pages")
843+
assertEquals(server.tools.size, combinedNames.size, "Total tools should match server registry")
836844
}
837845

838846
@Test
839847
fun testListToolsInvalidCursor() = runBlocking(Dispatchers.IO) {
840848
server.sessions.forEach { (_, session) ->
841849
session.setRequestHandler<ListToolsRequest>(Method.Defined.ToolsList) { request, _ ->
842-
val cursor = request.cursor?.toIntOrNull() ?: throw IllegalArgumentException("Invalid cursor")
850+
val cursor = requireNotNull(request.cursor?.toIntOrNull()) { "Invalid cursor" }
843851
val all = server.tools.values.map { it.tool }
844852
val page = all.drop(cursor).take(2)
845853
ListToolsResult(tools = page)
846854
}
847855
}
848856

849-
val exception = kotlin.test.assertFailsWith<io.modelcontextprotocol.kotlin.sdk.types.McpException> {
857+
val exception = kotlin.test.assertFailsWith<McpException> {
850858
client.listTools(ListToolsRequest(PaginatedRequestParams(cursor = "bad")))
851859
}
852860

853-
kotlin.test.assertEquals(io.modelcontextprotocol.kotlin.sdk.types.RPCError.ErrorCode.INTERNAL_ERROR, exception.code)
861+
assertEquals(RPCError.ErrorCode.INTERNAL_ERROR, exception.code)
854862
}
855863
}

integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/kotlin/streamablehttp/LoggingIntegrationTestStreamableHttp.kt

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,16 @@ package io.modelcontextprotocol.kotlin.sdk.integration.kotlin.streamablehttp
33
import io.modelcontextprotocol.kotlin.sdk.integration.kotlin.KotlinTestBase
44
import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequest
55
import io.modelcontextprotocol.kotlin.sdk.types.CallToolRequestParams
6+
import io.modelcontextprotocol.kotlin.sdk.types.CallToolResult
67
import io.modelcontextprotocol.kotlin.sdk.types.LoggingLevel
78
import io.modelcontextprotocol.kotlin.sdk.types.LoggingMessageNotification
89
import io.modelcontextprotocol.kotlin.sdk.types.LoggingMessageNotificationParams
910
import io.modelcontextprotocol.kotlin.sdk.types.Method
1011
import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities
12+
import io.modelcontextprotocol.kotlin.sdk.types.TextContent
1113
import kotlinx.coroutines.CompletableDeferred
12-
import kotlinx.coroutines.runBlocking
1314
import kotlinx.coroutines.delay
15+
import kotlinx.coroutines.runBlocking
1416
import kotlinx.coroutines.withTimeout
1517
import kotlinx.serialization.json.JsonObject
1618
import kotlinx.serialization.json.JsonPrimitive
@@ -35,7 +37,7 @@ class LoggingIntegrationTestStreamableHttp : KotlinTestBase() {
3537
),
3638
),
3739
)
38-
io.modelcontextprotocol.kotlin.sdk.types.CallToolResult(listOf(io.modelcontextprotocol.kotlin.sdk.types.TextContent("ok")))
40+
CallToolResult(listOf(TextContent("ok")))
3941
}
4042

4143
server.addTool(name = "test-logging", description = "test") { request ->
@@ -47,7 +49,7 @@ class LoggingIntegrationTestStreamableHttp : KotlinTestBase() {
4749
),
4850
),
4951
)
50-
io.modelcontextprotocol.kotlin.sdk.types.CallToolResult(listOf(io.modelcontextprotocol.kotlin.sdk.types.TextContent("ok")))
52+
CallToolResult(listOf(TextContent("ok")))
5153
}
5254

5355
server.addTool(name = "test-logging-level", description = "test") { request ->
@@ -61,7 +63,7 @@ class LoggingIntegrationTestStreamableHttp : KotlinTestBase() {
6163
),
6264
)
6365
}
64-
io.modelcontextprotocol.kotlin.sdk.types.CallToolResult(listOf(io.modelcontextprotocol.kotlin.sdk.types.TextContent("ok")))
66+
CallToolResult(listOf(TextContent("ok")))
6567
}
6668
}
6769

0 commit comments

Comments
 (0)