Skip to content

Commit f57e988

Browse files
committed
feat(cast): add JSON control character sanitization for tool call arguments
- Implemented `SanitizeJSONControlChars` to escape literal control characters in JSON string values, ensuring compliance with the JSON specification. - Enhanced `SanitizeToolCallArguments` method to apply sanitization across tool call arguments in the chain. - Added comprehensive tests for both sanitization functions to validate behavior with various input scenarios.
1 parent 9cd52b1 commit f57e988

4 files changed

Lines changed: 279 additions & 0 deletions

File tree

backend/pkg/cast/chain_ast.go

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package cast
22

33
import (
4+
"encoding/json"
45
"fmt"
56
"pentagi/pkg/templates"
67
"sort"
@@ -1056,6 +1057,104 @@ func ContainsToolCallReasoning(messages []llms.MessageContent) bool {
10561057
return false
10571058
}
10581059

1060+
// SanitizeJSONControlChars escapes any literal control characters (0x00–0x1F, 0x7F) that
1061+
// appear inside JSON string values. Such characters are disallowed by the JSON spec and
1062+
// must be encoded as \n, \r, \t, or \uXXXX escape sequences.
1063+
//
1064+
// The scanner is context-aware: it tracks whether it is currently inside a JSON string and
1065+
// skips over already-escaped sequences so they are never double-escaped.
1066+
func SanitizeJSONControlChars(s string) string {
1067+
if json.Valid([]byte(s)) {
1068+
return s
1069+
}
1070+
1071+
var buf strings.Builder
1072+
buf.Grow(len(s) + 32)
1073+
1074+
inString := false
1075+
escaped := false
1076+
1077+
for i := 0; i < len(s); {
1078+
b := s[i]
1079+
1080+
if escaped {
1081+
buf.WriteByte(b)
1082+
i++
1083+
escaped = false
1084+
continue
1085+
}
1086+
1087+
if inString && b == '\\' {
1088+
buf.WriteByte(b)
1089+
i++
1090+
escaped = true
1091+
continue
1092+
}
1093+
1094+
if b == '"' {
1095+
buf.WriteByte(b)
1096+
i++
1097+
inString = !inString
1098+
continue
1099+
}
1100+
1101+
if inString && b < 0x20 {
1102+
switch b {
1103+
case '\n':
1104+
buf.WriteString(`\n`)
1105+
case '\r':
1106+
buf.WriteString(`\r`)
1107+
case '\t':
1108+
buf.WriteString(`\t`)
1109+
case '\b':
1110+
buf.WriteString(`\b`)
1111+
case '\f':
1112+
buf.WriteString(`\f`)
1113+
default:
1114+
fmt.Fprintf(&buf, `\u%04x`, b)
1115+
}
1116+
i++
1117+
continue
1118+
}
1119+
1120+
buf.WriteByte(b)
1121+
i++
1122+
}
1123+
1124+
return buf.String()
1125+
}
1126+
1127+
// SanitizeToolCallArguments scans every tool call in the chain and ensures its Arguments
1128+
// field is valid JSON. Literal control characters that some LLM providers occasionally
1129+
// emit inside string values are escaped so that downstream consumers (e.g. vLLM's
1130+
// _postprocess_messages) can parse the arguments without errors.
1131+
func (ast *ChainAST) SanitizeToolCallArguments() {
1132+
for _, section := range ast.Sections {
1133+
for _, bodyPair := range section.Body {
1134+
if bodyPair.Type != RequestResponse && bodyPair.Type != Summarization {
1135+
continue
1136+
}
1137+
1138+
if bodyPair.AIMessage == nil {
1139+
continue
1140+
}
1141+
1142+
for pdx, part := range bodyPair.AIMessage.Parts {
1143+
toolCall, ok := part.(llms.ToolCall)
1144+
if !ok || toolCall.FunctionCall == nil {
1145+
continue
1146+
}
1147+
1148+
sanitized := SanitizeJSONControlChars(toolCall.FunctionCall.Arguments)
1149+
if sanitized != toolCall.FunctionCall.Arguments {
1150+
toolCall.FunctionCall.Arguments = sanitized
1151+
bodyPair.AIMessage.Parts[pdx] = toolCall
1152+
}
1153+
}
1154+
}
1155+
}
1156+
}
1157+
10591158
// ExtractReasoningMessage extracts the first AI message that contains reasoning content
10601159
// in a TextContent part. This is useful for preserving reasoning messages when summarizing
10611160
// content for providers like Kimi (Moonshot) that require reasoning_content before ToolCall.

backend/pkg/cast/chain_ast_test.go

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3015,3 +3015,178 @@ func TestClearReasoning_IntegrationWithNormalize(t *testing.T) {
30153015

30163016
t.Log("Successfully normalized IDs and cleared reasoning for provider switch")
30173017
}
3018+
3019+
func TestSanitizeJSONControlChars(t *testing.T) {
3020+
tests := []struct {
3021+
name string
3022+
input string
3023+
want string
3024+
}{
3025+
{
3026+
name: "already valid JSON - no changes",
3027+
input: `{"input": "hello world"}`,
3028+
want: `{"input": "hello world"}`,
3029+
},
3030+
{
3031+
name: "already valid JSON with escaped newline - no changes",
3032+
input: `{"input": "line1\nline2"}`,
3033+
want: `{"input": "line1\nline2"}`,
3034+
},
3035+
{
3036+
name: "literal newline inside string value",
3037+
input: "{\"input\": \"line1\nline2\"}",
3038+
want: `{"input": "line1\nline2"}`,
3039+
},
3040+
{
3041+
name: "literal carriage return inside string value",
3042+
input: "{\"cmd\": \"echo\rtest\"}",
3043+
want: `{"cmd": "echo\rtest"}`,
3044+
},
3045+
{
3046+
name: "literal tab inside string value",
3047+
input: "{\"v\": \"a\tb\"}",
3048+
want: `{"v": "a\tb"}`,
3049+
},
3050+
{
3051+
name: "literal backspace and form feed inside string value",
3052+
input: "{\"v\": \"a\x08\x0Cb\"}",
3053+
want: `{"v": "a\b\fb"}`,
3054+
},
3055+
{
3056+
name: `other control character (SOH) gets \uXXXX encoding`,
3057+
input: "{\"v\": \"a\x01b\"}",
3058+
want: `{"v": "a\u0001b"}`,
3059+
},
3060+
{
3061+
name: "control character outside string - not touched",
3062+
input: "{\n\"k\": \"v\"}",
3063+
want: "{\n\"k\": \"v\"}",
3064+
},
3065+
{
3066+
name: "multiple string fields, only affected one fixed",
3067+
input: "{\"a\": \"ok\", \"b\": \"bad\nval\"}",
3068+
want: `{"a": "ok", "b": "bad\nval"}`,
3069+
},
3070+
{
3071+
name: "escaped backslash before quote not confused with string end",
3072+
input: "{\"v\": \"path\\\\dir\"}",
3073+
want: `{"v": "path\\dir"}`,
3074+
},
3075+
{
3076+
name: "real-world vLLM crash case: index.html newline in arguments",
3077+
input: "{\"input\": \"index.html] =\ncurl -s http://example.com\", \"cwd\": \"/work\"}",
3078+
want: `{"input": "index.html] =\ncurl -s http://example.com", "cwd": "/work"}`,
3079+
},
3080+
{
3081+
name: "empty string",
3082+
input: "",
3083+
want: "",
3084+
},
3085+
{
3086+
name: "plain invalid JSON without control chars - returned as-is",
3087+
input: `{broken`,
3088+
want: `{broken`,
3089+
},
3090+
}
3091+
3092+
for _, tt := range tests {
3093+
t.Run(tt.name, func(t *testing.T) {
3094+
got := SanitizeJSONControlChars(tt.input)
3095+
assert.Equal(t, tt.want, got)
3096+
})
3097+
}
3098+
}
3099+
3100+
func TestSanitizeToolCallArguments(t *testing.T) {
3101+
makeChain := func(args string) []llms.MessageContent {
3102+
return []llms.MessageContent{
3103+
{
3104+
Role: llms.ChatMessageTypeSystem,
3105+
Parts: []llms.ContentPart{llms.TextContent{Text: "system"}},
3106+
},
3107+
{
3108+
Role: llms.ChatMessageTypeHuman,
3109+
Parts: []llms.ContentPart{llms.TextContent{Text: "do it"}},
3110+
},
3111+
{
3112+
Role: llms.ChatMessageTypeAI,
3113+
Parts: []llms.ContentPart{
3114+
llms.ToolCall{
3115+
ID: "call_abc",
3116+
Type: "function",
3117+
FunctionCall: &llms.FunctionCall{
3118+
Name: "terminal",
3119+
Arguments: args,
3120+
},
3121+
},
3122+
},
3123+
},
3124+
{
3125+
Role: llms.ChatMessageTypeTool,
3126+
Parts: []llms.ContentPart{
3127+
llms.ToolCallResponse{ToolCallID: "call_abc", Name: "terminal", Content: "ok"},
3128+
},
3129+
},
3130+
}
3131+
}
3132+
3133+
getArgs := func(ast *ChainAST) string {
3134+
for _, section := range ast.Sections {
3135+
for _, pair := range section.Body {
3136+
if pair.AIMessage == nil {
3137+
continue
3138+
}
3139+
for _, part := range pair.AIMessage.Parts {
3140+
if tc, ok := part.(llms.ToolCall); ok && tc.FunctionCall != nil {
3141+
return tc.FunctionCall.Arguments
3142+
}
3143+
}
3144+
}
3145+
}
3146+
return ""
3147+
}
3148+
3149+
tests := []struct {
3150+
name string
3151+
args string
3152+
wantArgs string
3153+
}{
3154+
{
3155+
name: "valid JSON - unchanged",
3156+
args: `{"input": "ls -la", "cwd": "/work"}`,
3157+
wantArgs: `{"input": "ls -la", "cwd": "/work"}`,
3158+
},
3159+
{
3160+
name: "literal newline in argument value - gets escaped",
3161+
args: "{\"input\": \"curl -s http://example.com\ncurl -s http://other.com\", \"cwd\": \"/work\"}",
3162+
wantArgs: `{"input": "curl -s http://example.com\ncurl -s http://other.com", "cwd": "/work"}`,
3163+
},
3164+
{
3165+
name: "multiple control chars - all escaped",
3166+
args: "{\"input\": \"a\nb\rc\", \"cwd\": \"/work\"}",
3167+
wantArgs: `{"input": "a\nb\rc", "cwd": "/work"}`,
3168+
},
3169+
{
3170+
name: "already escaped newline in value - unchanged",
3171+
args: `{"input": "line1\nline2", "cwd": "/work"}`,
3172+
wantArgs: `{"input": "line1\nline2", "cwd": "/work"}`,
3173+
},
3174+
{
3175+
name: "nil FunctionCall tool call does not panic",
3176+
args: `{}`,
3177+
wantArgs: `{}`,
3178+
},
3179+
}
3180+
3181+
for _, tt := range tests {
3182+
t.Run(tt.name, func(t *testing.T) {
3183+
chain := makeChain(tt.args)
3184+
ast, err := NewChainAST(chain, false)
3185+
assert.NoError(t, err)
3186+
3187+
ast.SanitizeToolCallArguments()
3188+
3189+
assert.Equal(t, tt.wantArgs, getArgs(ast))
3190+
})
3191+
}
3192+
}

backend/pkg/providers/helpers.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,8 @@ func (fp *flowProvider) restoreChain(
523523
return wrapErrorWithEvent("failed to normalize tool call IDs", err)
524524
}
525525

526+
ast.SanitizeToolCallArguments()
527+
526528
if err := ast.ClearReasoning(); err != nil {
527529
return wrapErrorWithEvent("failed to clear reasoning", err)
528530
}
@@ -631,6 +633,8 @@ func (fp *flowProvider) processChain(
631633
logger.WithError(err).Warn("failed to normalize tool call IDs")
632634
}
633635

636+
ast.SanitizeToolCallArguments()
637+
634638
// Clear provider-specific reasoning signatures
635639
if err := ast.ClearReasoning(); err != nil {
636640
logger.WithError(err).Warn("failed to clear reasoning")

backend/pkg/providers/performer.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,7 @@ func (fp *flowProvider) callWithRetries(
444444
if toolCall.FunctionCall == nil {
445445
continue
446446
}
447+
toolCall.FunctionCall.Arguments = cast.SanitizeJSONControlChars(toolCall.FunctionCall.Arguments)
447448
result.funcCalls = append(result.funcCalls, toolCall)
448449
}
449450
}

0 commit comments

Comments
 (0)