Skip to content

Commit f23fb12

Browse files
committed
feat(translator): ensure tool uses stay adjacent to tool results in message generation
- Refactored `ConvertOpenAIResponsesRequestToClaude` logic to align tool use with corresponding tool results. - Introduced helper functions for appending and flushing pending reasoning and tool use messages. - Expanded tests to validate message order and content consistency when processing tool calls and results.
1 parent 644ba74 commit f23fb12

2 files changed

Lines changed: 86 additions & 5 deletions

File tree

internal/translator/claude/openai/responses/claude_openai-responses_request.go

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,14 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
170170

171171
// input array processing
172172
var pendingReasoningParts []string
173+
type pendingToolUseMessage struct {
174+
callID string
175+
raw []byte
176+
}
177+
var pendingToolUseMessages []pendingToolUseMessage
178+
appendMessage := func(msg []byte) {
179+
out, _ = sjson.SetRawBytes(out, "messages.-1", msg)
180+
}
173181
flushPendingReasoning := func() {
174182
if len(pendingReasoningParts) == 0 {
175183
return
@@ -178,9 +186,28 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
178186
for _, partJSON := range pendingReasoningParts {
179187
asst, _ = sjson.SetRawBytes(asst, "content.-1", []byte(partJSON))
180188
}
181-
out, _ = sjson.SetRawBytes(out, "messages.-1", asst)
189+
appendMessage(asst)
182190
pendingReasoningParts = nil
183191
}
192+
flushPendingToolUses := func() {
193+
for _, pending := range pendingToolUseMessages {
194+
appendMessage(pending.raw)
195+
}
196+
pendingToolUseMessages = nil
197+
}
198+
flushPendingToolUseFor := func(callID string) {
199+
if len(pendingToolUseMessages) == 0 {
200+
return
201+
}
202+
for i, pending := range pendingToolUseMessages {
203+
if pending.callID == callID {
204+
appendMessage(pending.raw)
205+
pendingToolUseMessages = append(pendingToolUseMessages[:i], pendingToolUseMessages[i+1:]...)
206+
return
207+
}
208+
}
209+
flushPendingToolUses()
210+
}
184211

185212
if input := root.Get("input"); input.Exists() && input.IsArray() {
186213
input.ForEach(func(_, item gjson.Result) bool {
@@ -294,6 +321,9 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
294321
}
295322

296323
hasReasoningParts := false
324+
if role != "assistant" {
325+
flushPendingToolUses()
326+
}
297327
if len(pendingReasoningParts) > 0 {
298328
if role == "assistant" {
299329
if len(partsJSON) == 0 && textAggregate.Len() > 0 {
@@ -322,12 +352,12 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
322352
msg, _ = sjson.SetRawBytes(msg, "content.-1", []byte(partJSON))
323353
}
324354
}
325-
out, _ = sjson.SetRawBytes(out, "messages.-1", msg)
355+
appendMessage(msg)
326356
} else if textAggregate.Len() > 0 || role == "system" {
327357
msg := []byte(`{"role":"","content":""}`)
328358
msg, _ = sjson.SetBytes(msg, "role", role)
329359
msg, _ = sjson.SetBytes(msg, "content", textAggregate.String())
330-
out, _ = sjson.SetRawBytes(out, "messages.-1", msg)
360+
appendMessage(msg)
331361
}
332362

333363
case "reasoning":
@@ -360,25 +390,30 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
360390
}
361391
pendingReasoningParts = nil
362392
asst, _ = sjson.SetRawBytes(asst, "content.-1", toolUse)
363-
out, _ = sjson.SetRawBytes(out, "messages.-1", asst)
393+
pendingToolUseMessages = append(pendingToolUseMessages, pendingToolUseMessage{
394+
callID: callID,
395+
raw: asst,
396+
})
364397

365398
case "function_call_output":
366399
flushPendingReasoning()
367400
// Map to user tool_result
368401
callID := item.Get("call_id").String()
402+
flushPendingToolUseFor(callID)
369403
outputStr := item.Get("output").String()
370404
toolResult := []byte(`{"type":"tool_result","tool_use_id":"","content":""}`)
371405
toolResult, _ = sjson.SetBytes(toolResult, "tool_use_id", callID)
372406
toolResult, _ = sjson.SetBytes(toolResult, "content", outputStr)
373407

374408
usr := []byte(`{"role":"user","content":[]}`)
375409
usr, _ = sjson.SetRawBytes(usr, "content.-1", toolResult)
376-
out, _ = sjson.SetRawBytes(out, "messages.-1", usr)
410+
appendMessage(usr)
377411
}
378412
return true
379413
})
380414
}
381415
flushPendingReasoning()
416+
flushPendingToolUses()
382417

383418
includedToolNames := map[string]struct{}{}
384419
toolNameMap := map[string]string{}

internal/translator/claude/openai/responses/claude_openai-responses_request_test.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,52 @@ func TestConvertOpenAIResponsesRequestToClaude_DropsIncompatibleReasoningSignatu
125125
}
126126
}
127127

128+
func TestConvertOpenAIResponsesRequestToClaude_KeepsToolUseAdjacentToToolResult(t *testing.T) {
129+
raw := []byte(`{
130+
"model":"claude-test",
131+
"input":[
132+
{
133+
"type":"function_call",
134+
"call_id":"call_00_awGuheXs4aRbtedNK8LE3743",
135+
"name":"js",
136+
"arguments":"{\"code\":\"nodeRepl.write('ok')\",\"title\":\"List Obsidian vault contents\"}"
137+
},
138+
{
139+
"type":"message",
140+
"role":"assistant",
141+
"content":[{"type":"output_text","text":"I'll check your Obsidian vault for articles."}]
142+
},
143+
{
144+
"type":"function_call_output",
145+
"call_id":"call_00_awGuheXs4aRbtedNK8LE3743",
146+
"output":"Wall time: 0.1963 seconds\nOutput:\n[{\"type\":\"text\",\"text\":\"\"}]"
147+
}
148+
]
149+
}`)
150+
151+
out := ConvertOpenAIResponsesRequestToClaude("claude-test", raw, false)
152+
root := gjson.ParseBytes(out)
153+
154+
if got := root.Get("messages.0.role").String(); got != "assistant" {
155+
t.Fatalf("first message role = %q, want assistant. Output: %s", got, string(out))
156+
}
157+
if got := root.Get("messages.0.content").String(); got != "I'll check your Obsidian vault for articles." {
158+
t.Fatalf("first message content = %q, want assistant text. Output: %s", got, string(out))
159+
}
160+
if got := root.Get("messages.1.content.0.type").String(); got != "tool_use" {
161+
t.Fatalf("second message first content type = %q, want tool_use. Output: %s", got, string(out))
162+
}
163+
if got := root.Get("messages.1.content.0.id").String(); got != "call_00_awGuheXs4aRbtedNK8LE3743" {
164+
t.Fatalf("tool_use id = %q, want call_00_awGuheXs4aRbtedNK8LE3743. Output: %s", got, string(out))
165+
}
166+
if got := root.Get("messages.2.content.0.type").String(); got != "tool_result" {
167+
t.Fatalf("third message first content type = %q, want tool_result. Output: %s", got, string(out))
168+
}
169+
if got := root.Get("messages.2.content.0.tool_use_id").String(); got != "call_00_awGuheXs4aRbtedNK8LE3743" {
170+
t.Fatalf("tool_result id = %q, want call_00_awGuheXs4aRbtedNK8LE3743. Output: %s", got, string(out))
171+
}
172+
}
173+
128174
func testClaudeResponsesThinkingSignature(t *testing.T) (string, string) {
129175
t.Helper()
130176
channelBlock := []byte{}

0 commit comments

Comments
 (0)