Skip to content

Commit de280d9

Browse files
committed
feat(websockets): refine incremental repair logic for tool call responses
- Updated WebSocket response repair tests to validate incremental preservation of response calls and outputs. - Added new test cases for custom tool responses ensuring accurate handling of output cache and call cache. - Refactored `repairResponsesWebsocketToolCallsWithCaches` to handle orphan outputs more consistently. - Adjusted input filtering logic for clearer incremental repair behavior. Closes: router-for-me#3569
1 parent e399edd commit de280d9

2 files changed

Lines changed: 66 additions & 14 deletions

File tree

sdk/api/handlers/openai/openai_responses_websocket_test.go

Lines changed: 56 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -691,7 +691,7 @@ func TestRepairResponsesWebsocketToolCallsInsertsCachedCallForOrphanOutput(t *te
691691
}
692692
}
693693

694-
func TestRepairResponsesWebsocketToolCallsInsertsCachedCallForPreviousResponseOutput(t *testing.T) {
694+
func TestRepairResponsesWebsocketToolCallsKeepsPreviousResponseOutputIncremental(t *testing.T) {
695695
outputCache := newWebsocketToolOutputCache(time.Minute, 10)
696696
callCache := newWebsocketToolOutputCache(time.Minute, 10)
697697
sessionKey := "session-1"
@@ -705,17 +705,39 @@ func TestRepairResponsesWebsocketToolCallsInsertsCachedCallForPreviousResponseOu
705705
t.Fatalf("previous_response_id = %q, want resp-latest", got)
706706
}
707707
input := gjson.GetBytes(repaired, "input").Array()
708-
if len(input) != 3 {
709-
t.Fatalf("repaired input len = %d, want 3: %s", len(input), repaired)
708+
if len(input) != 2 {
709+
t.Fatalf("repaired input len = %d, want 2: %s", len(input), repaired)
710710
}
711-
if input[0].Get("type").String() != "function_call" || input[0].Get("call_id").String() != "call-1" {
712-
t.Fatalf("missing inserted call: %s", input[0].Raw)
711+
if input[0].Get("type").String() != "function_call_output" || input[0].Get("call_id").String() != "call-1" {
712+
t.Fatalf("unexpected output item: %s", input[0].Raw)
713713
}
714-
if input[1].Get("type").String() != "function_call_output" || input[1].Get("call_id").String() != "call-1" {
715-
t.Fatalf("unexpected output item: %s", input[1].Raw)
714+
if input[1].Get("type").String() != "message" || input[1].Get("id").String() != "msg-1" {
715+
t.Fatalf("unexpected trailing item: %s", input[1].Raw)
716716
}
717-
if input[2].Get("type").String() != "message" || input[2].Get("id").String() != "msg-1" {
718-
t.Fatalf("unexpected trailing item: %s", input[2].Raw)
717+
}
718+
719+
func TestRepairResponsesWebsocketToolCallsKeepsPreviousResponseCallIncremental(t *testing.T) {
720+
outputCache := newWebsocketToolOutputCache(time.Minute, 10)
721+
callCache := newWebsocketToolOutputCache(time.Minute, 10)
722+
sessionKey := "session-1"
723+
724+
outputCache.record(sessionKey, "call-1", []byte(`{"type":"function_call_output","call_id":"call-1","id":"tool-out-1","output":"ok"}`))
725+
726+
raw := []byte(`{"previous_response_id":"resp-latest","input":[{"type":"function_call","id":"fc-1","call_id":"call-1","name":"tool"},{"type":"message","id":"msg-1"}]}`)
727+
repaired := repairResponsesWebsocketToolCallsWithCaches(outputCache, callCache, sessionKey, raw)
728+
729+
if got := gjson.GetBytes(repaired, "previous_response_id").String(); got != "resp-latest" {
730+
t.Fatalf("previous_response_id = %q, want resp-latest", got)
731+
}
732+
input := gjson.GetBytes(repaired, "input").Array()
733+
if len(input) != 2 {
734+
t.Fatalf("repaired input len = %d, want 2: %s", len(input), repaired)
735+
}
736+
if input[0].Get("type").String() != "function_call" || input[0].Get("call_id").String() != "call-1" {
737+
t.Fatalf("unexpected call item: %s", input[0].Raw)
738+
}
739+
if input[1].Get("type").String() != "message" || input[1].Get("id").String() != "msg-1" {
740+
t.Fatalf("unexpected trailing item: %s", input[1].Raw)
719741
}
720742
}
721743

@@ -805,6 +827,31 @@ func TestRepairResponsesWebsocketToolCallsInsertsCachedCustomToolCallForOrphanOu
805827
}
806828
}
807829

830+
func TestRepairResponsesWebsocketToolCallsKeepsPreviousResponseCustomToolOutputIncremental(t *testing.T) {
831+
outputCache := newWebsocketToolOutputCache(time.Minute, 10)
832+
callCache := newWebsocketToolOutputCache(time.Minute, 10)
833+
sessionKey := "session-1"
834+
835+
callCache.record(sessionKey, "call-1", []byte(`{"type":"custom_tool_call","call_id":"call-1","name":"apply_patch"}`))
836+
837+
raw := []byte(`{"previous_response_id":"resp-latest","input":[{"type":"custom_tool_call_output","call_id":"call-1","output":"ok"},{"type":"message","id":"msg-1"}]}`)
838+
repaired := repairResponsesWebsocketToolCallsWithCaches(outputCache, callCache, sessionKey, raw)
839+
840+
if got := gjson.GetBytes(repaired, "previous_response_id").String(); got != "resp-latest" {
841+
t.Fatalf("previous_response_id = %q, want resp-latest", got)
842+
}
843+
input := gjson.GetBytes(repaired, "input").Array()
844+
if len(input) != 2 {
845+
t.Fatalf("repaired input len = %d, want 2: %s", len(input), repaired)
846+
}
847+
if input[0].Get("type").String() != "custom_tool_call_output" || input[0].Get("call_id").String() != "call-1" {
848+
t.Fatalf("unexpected output item: %s", input[0].Raw)
849+
}
850+
if input[1].Get("type").String() != "message" || input[1].Get("id").String() != "msg-1" {
851+
t.Fatalf("unexpected trailing item: %s", input[1].Raw)
852+
}
853+
}
854+
808855
func TestRepairResponsesWebsocketToolCallsDropsOrphanCustomToolOutputWhenCallMissing(t *testing.T) {
809856
outputCache := newWebsocketToolOutputCache(time.Minute, 10)
810857
callCache := newWebsocketToolOutputCache(time.Minute, 10)

sdk/api/handlers/openai/openai_responses_websocket_toolcall_repair.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,11 @@ func repairResponsesToolCallsArray(outputCache, callCache *websocketToolOutputCa
305305
continue
306306
}
307307

308+
if allowOrphanOutputs {
309+
filtered = append(filtered, item)
310+
continue
311+
}
312+
308313
if callCache != nil {
309314
if cached, ok := callCache.get(sessionKey, callID); ok {
310315
if _, already := insertedCalls[callID]; !already {
@@ -317,11 +322,6 @@ func repairResponsesToolCallsArray(outputCache, callCache *websocketToolOutputCa
317322
}
318323
}
319324

320-
if allowOrphanOutputs {
321-
filtered = append(filtered, item)
322-
continue
323-
}
324-
325325
// Drop orphaned function_call_output items; upstream rejects transcripts with missing calls.
326326
continue
327327
}
@@ -341,6 +341,11 @@ func repairResponsesToolCallsArray(outputCache, callCache *websocketToolOutputCa
341341
continue
342342
}
343343

344+
if allowOrphanOutputs {
345+
filtered = append(filtered, item)
346+
continue
347+
}
348+
344349
if cached, ok := outputCache.get(sessionKey, callID); ok {
345350
filtered = append(filtered, item)
346351
filtered = append(filtered, cached)

0 commit comments

Comments
 (0)