Skip to content

Commit 75a944f

Browse files
fix: polish SEP-1577 sampling handling (#765)
## Summary - reject empty sampling content in SamplingMessage and CreateMessageResult constructors - clarify sampling validation docs and emit an explicit error when tool_use is not followed by tool_result - return sampled text from the conformance test_sampling tool instead of content object strings Fixes #730 ## Testing - ./gradlew -Dorg.gradle.java.installations.paths="/Users/ubl/Documents/New project/oss-contributions/.jdks/jdk-21.0.11+10/Contents/Home" :kotlin-sdk-core:jvmTest :kotlin-sdk-server:jvmTest :conformance-test:compileKotlin Co-authored-by: Pavel Gorgulov <devcrocod@gmail.com>
1 parent d65c507 commit 75a944f

5 files changed

Lines changed: 56 additions & 3 deletions

File tree

  • conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance
  • kotlin-sdk-core/src
    • commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types
    • commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types
  • kotlin-sdk-server/src

conformance-test/src/main/kotlin/io/modelcontextprotocol/kotlin/sdk/conformance/ConformanceTools.kt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,13 @@ fun Server.registerConformanceTools() {
179179
),
180180
),
181181
)
182-
CallToolResult(listOf(TextContent(result.content.joinToString("\n") { it.toString() })))
182+
val sampledText = result.content.joinToString("\n") { content ->
183+
when (content) {
184+
is TextContent -> content.text
185+
else -> "Non-text sampling content: ${content::class.simpleName}"
186+
}
187+
}
188+
CallToolResult(listOf(TextContent(sampledText)))
183189
}
184190

185191
// 9. Elicitation

kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/types/sampling.kt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ public data class SamplingMessage(
9191
@SerialName("_meta")
9292
val meta: JsonObject? = null,
9393
) {
94+
init {
95+
require(content.isNotEmpty()) { "content must contain at least one block" }
96+
}
97+
9498
/**
9599
* Convenience constructor for a single-block message. Wraps [content] in a
96100
* singleton list so call sites can write `SamplingMessage(Role.User, TextContent("hi"))`
@@ -273,6 +277,10 @@ public data class CreateMessageResult(
273277
@SerialName("_meta")
274278
override val meta: JsonObject? = null,
275279
) : ClientResult {
280+
init {
281+
require(content.isNotEmpty()) { "content must contain at least one block" }
282+
}
283+
276284
/**
277285
* Convenience constructor for a single-block response. Wraps [content] in a
278286
* singleton list so call sites can write

kotlin-sdk-core/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/types/SamplingTest.kt

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,13 @@ class SamplingTest {
302302
(m.content[0] as TextContent).text shouldBe "hi"
303303
}
304304

305+
@Test
306+
fun `SamplingMessage rejects empty content`() {
307+
assertFailsWith<IllegalArgumentException> {
308+
SamplingMessage(role = Role.User, content = emptyList())
309+
}
310+
}
311+
305312
@Test
306313
fun `SamplingMessage single-element content serialises as single object`() {
307314
val m = SamplingMessage(role = Role.User, content = listOf(TextContent("hi")))
@@ -408,6 +415,17 @@ class SamplingTest {
408415
(decoded.content[0] as TextContent).text shouldBe "hi"
409416
}
410417

418+
@Test
419+
fun `CreateMessageResult rejects empty content`() {
420+
assertFailsWith<IllegalArgumentException> {
421+
CreateMessageResult(
422+
role = Role.Assistant,
423+
content = emptyList(),
424+
model = "test-model",
425+
)
426+
}
427+
}
428+
411429
// ============================================================================
412430
// SamplingContentSerializer (single-or-array wire heuristic)
413431
// ============================================================================

kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SamplingValidation.kt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ import io.modelcontextprotocol.kotlin.sdk.types.ToolUseContent
2020
* 3. If the previous message contains `tool_use` blocks, the last message's
2121
* `tool_result` ids MUST form exactly the same set.
2222
*
23-
* On the first violation throws [IllegalArgumentException]. No-op when there are fewer
24-
* than two messages or no tool_use / tool_result blocks are involved.
23+
* On the first violation throws [IllegalArgumentException]. No-op when no
24+
* tool_use / tool_result blocks are involved.
2525
*/
2626
internal fun validateSamplingMessages(messages: List<SamplingMessage>) {
2727
if (messages.isEmpty()) return
@@ -44,6 +44,9 @@ internal fun validateSamplingMessages(messages: List<SamplingMessage>) {
4444
if (hasPreviousToolUse) {
4545
val toolUseIds = previous.filterIsInstance<ToolUseContent>().map { it.id }.toSet()
4646
val toolResultIds = last.filterIsInstance<ToolResultContent>().map { it.toolUseId }.toSet()
47+
require(toolResultIds.isNotEmpty()) {
48+
"tool_use blocks from previous message must be followed by matching tool_result blocks"
49+
}
4750
require(toolUseIds == toolResultIds) {
4851
"ids of tool_result blocks and tool_use blocks from previous message do not match"
4952
}

kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/SamplingTest.kt

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import io.modelcontextprotocol.kotlin.sdk.types.ToolUseContent
88
import kotlinx.serialization.json.JsonObject
99
import org.junit.jupiter.api.assertDoesNotThrow
1010
import kotlin.test.Test
11+
import kotlin.test.assertEquals
1112
import kotlin.test.assertFailsWith
1213

1314
/**
@@ -83,4 +84,21 @@ class SamplingTest {
8384
)
8485
}
8586
}
87+
88+
@Test
89+
fun `validate tool_use requires explicit tool_result in last message`() {
90+
val error = assertFailsWith<IllegalArgumentException> {
91+
validateSamplingMessages(
92+
listOf(
93+
SamplingMessage(Role.Assistant, toolUse("c1")),
94+
SamplingMessage(Role.User, TextContent("missing result")),
95+
),
96+
)
97+
}
98+
99+
assertEquals(
100+
"tool_use blocks from previous message must be followed by matching tool_result blocks",
101+
error.message,
102+
)
103+
}
86104
}

0 commit comments

Comments
 (0)