Skip to content

Commit 280fc89

Browse files
authored
Merge pull request #11 from Deep-CodeAI/fix/850-ollama-json-escape
fix(#850): escape every interpolated string in OllamaClient.buildRequ…
2 parents 693d244 + 8b82e01 commit 280fc89

2 files changed

Lines changed: 100 additions & 4 deletions

File tree

src/main/kotlin/agents_engine/model/OllamaClient.kt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,11 @@ open class OllamaClient(
137137
internal fun buildRequestJson(messages: List<LlmMessage>, includeTools: Boolean = true): String {
138138
val messagesJson = messages.joinToString(",") { msg ->
139139
buildString {
140-
append("""{"role":"${msg.role}","content":${msg.content.toJsonString()}""")
140+
append("""{"role":${msg.role.toJsonString()},"content":${msg.content.toJsonString()}""")
141141
if (!msg.toolCalls.isNullOrEmpty()) {
142142
append(""","tool_calls":[""")
143143
append(msg.toolCalls.joinToString(",") { tc ->
144-
"""{"function":{"name":"${tc.name}","arguments":${InlineToolCallParser.argsToJson(tc.arguments)}}}"""
144+
"""{"function":{"name":${tc.name.toJsonString()},"arguments":${InlineToolCallParser.argsToJson(tc.arguments)}}}"""
145145
})
146146
append("]")
147147
}
@@ -152,11 +152,11 @@ open class OllamaClient(
152152
val defs = tools.joinToString(",") { t ->
153153
val parametersJson = t.argsType?.jsonSchema()
154154
?: """{"type":"object","properties":{},"additionalProperties":true}"""
155-
"""{"type":"function","function":{"name":"${t.name}","description":${t.description.toJsonString()},"parameters":$parametersJson}}"""
155+
"""{"type":"function","function":{"name":${t.name.toJsonString()},"description":${t.description.toJsonString()},"parameters":$parametersJson}}"""
156156
}
157157
""","tools":[$defs]"""
158158
} else ""
159-
return """{"model":"$model","stream":false,"temperature":$temperature,"messages":[$messagesJson]$toolsJson}"""
159+
return """{"model":${model.toJsonString()},"stream":false,"temperature":$temperature,"messages":[$messagesJson]$toolsJson}"""
160160
}
161161

162162
internal fun parseResponse(body: String): LlmResponse {
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
package agents_engine.model
2+
3+
import agents_engine.generation.LenientJsonParser
4+
import kotlin.test.Test
5+
import kotlin.test.assertEquals
6+
import kotlin.test.assertNotNull
7+
import kotlin.test.assertTrue
8+
9+
// Tests for #850 — every interpolated string in OllamaClient.buildRequestJson
10+
// must go through toJsonString(). A value containing `"` or `\` must round-trip
11+
// through LenientJsonParser cleanly.
12+
class OllamaRequestJsonEscapingTest {
13+
14+
private fun parsedRequest(
15+
model: String = "test-model",
16+
tools: List<ToolDef> = emptyList(),
17+
messages: List<LlmMessage> = listOf(LlmMessage("user", "hi")),
18+
): Map<String, Any?> {
19+
val client = OllamaClient(host = "localhost", port = 11434, model = model, tools = tools)
20+
val body = client.buildRequestJson(messages)
21+
@Suppress("UNCHECKED_CAST")
22+
return LenientJsonParser.parse(body) as? Map<String, Any?>
23+
?: error("OllamaClient produced unparseable JSON: $body")
24+
}
25+
26+
@Test
27+
fun `tool name with embedded quote and backslash round-trips through escaping`() {
28+
val def = ToolDef(name = """foo"bar\baz""", description = "ok") { _ -> "x" }
29+
val body = parsedRequest(tools = listOf(def))
30+
@Suppress("UNCHECKED_CAST")
31+
val tools = body["tools"] as List<Map<String, Any?>>
32+
@Suppress("UNCHECKED_CAST")
33+
val fn = tools.single()["function"] as Map<String, Any?>
34+
assertEquals("""foo"bar\baz""", fn["name"])
35+
}
36+
37+
@Test
38+
fun `model name containing special chars round-trips through escaping`() {
39+
val body = parsedRequest(model = """my"weird\model""")
40+
assertEquals("""my"weird\model""", body["model"])
41+
}
42+
43+
@Test
44+
fun `assistant tool_call name with special chars does not corrupt next-turn JSON`() {
45+
// Simulates an LLM emitting a tool call whose name contains a quote — this is
46+
// the self-injection vector. The next-turn request must still parse.
47+
val toolCalls = listOf(ToolCall(name = """ev"il""", arguments = mapOf("x" to "y")))
48+
val body = parsedRequest(
49+
messages = listOf(
50+
LlmMessage("user", "hi"),
51+
LlmMessage("assistant", "", toolCalls),
52+
LlmMessage("tool", "result"),
53+
),
54+
)
55+
@Suppress("UNCHECKED_CAST")
56+
val msgs = body["messages"] as List<Map<String, Any?>>
57+
val assistant = msgs[1]
58+
@Suppress("UNCHECKED_CAST")
59+
val calls = assistant["tool_calls"] as List<Map<String, Any?>>
60+
@Suppress("UNCHECKED_CAST")
61+
val fn = calls.single()["function"] as Map<String, Any?>
62+
assertEquals("""ev"il""", fn["name"])
63+
}
64+
65+
@Test
66+
fun `message role is escaped (defensive against future role extensions)`() {
67+
// Role is currently framework-controlled but the escape path must still honor
68+
// weird inputs to prevent regressions if roles ever become user-extensible.
69+
val body = parsedRequest(messages = listOf(LlmMessage("""sys"tem""", "hi")))
70+
@Suppress("UNCHECKED_CAST")
71+
val msgs = body["messages"] as List<Map<String, Any?>>
72+
assertEquals("""sys"tem""", msgs.single()["role"])
73+
}
74+
75+
@Test
76+
fun `newline and tab in tool description still escape correctly (regression)`() {
77+
val def = ToolDef(name = "ok", description = "line1\nline2\tindent") { _ -> "x" }
78+
val body = parsedRequest(tools = listOf(def))
79+
@Suppress("UNCHECKED_CAST")
80+
val tools = body["tools"] as List<Map<String, Any?>>
81+
@Suppress("UNCHECKED_CAST")
82+
val fn = tools.single()["function"] as Map<String, Any?>
83+
assertEquals("line1\nline2\tindent", fn["description"])
84+
}
85+
86+
@Test
87+
fun `request JSON is parseable when nothing special needs escaping (sanity)`() {
88+
val body = parsedRequest(
89+
tools = listOf(ToolDef(name = "plain", description = "plain") { _ -> "x" }),
90+
)
91+
assertEquals("test-model", body["model"])
92+
assertNotNull(body["messages"])
93+
assertNotNull(body["tools"])
94+
assertTrue(body.containsKey("stream"))
95+
}
96+
}

0 commit comments

Comments
 (0)