Skip to content

Commit 55901f0

Browse files
authored
Merge pull request router-for-me#3620 from iBenzene/fix/responses-input-id-dedupe
fix(openai): dedupe response websocket input item IDs
2 parents fc0615b + e9dafc7 commit 55901f0

2 files changed

Lines changed: 103 additions & 0 deletions

File tree

sdk/api/handlers/openai/openai_responses_websocket.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
381381
}
382382

383383
requestJSON = repairResponsesWebsocketToolCalls(downstreamSessionKey, requestJSON)
384+
requestJSON = dedupeResponsesWebsocketInputItemsByID(requestJSON)
384385
updatedLastRequest = bytes.Clone(requestJSON)
385386
previousLastRequest := bytes.Clone(lastRequest)
386387
previousLastResponseOutput := bytes.Clone(lastResponseOutput)
@@ -582,6 +583,10 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last
582583
if errDedupeFunctionCalls == nil {
583584
mergedInput = dedupedInput
584585
}
586+
dedupedInput, errDedupeItemIDs := dedupeInputItemsByID(mergedInput)
587+
if errDedupeItemIDs == nil {
588+
mergedInput = dedupedInput
589+
}
585590

586591
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
587592
if errDelete != nil {
@@ -697,6 +702,64 @@ func dedupeFunctionCallsByCallID(rawArray string) (string, error) {
697702
return string(out), nil
698703
}
699704

705+
func dedupeResponsesWebsocketInputItemsByID(payload []byte) []byte {
706+
input := gjson.GetBytes(payload, "input")
707+
if !input.Exists() || !input.IsArray() {
708+
return payload
709+
}
710+
dedupedInput, errDedupe := dedupeInputItemsByID(input.Raw)
711+
if errDedupe != nil || dedupedInput == input.Raw {
712+
return payload
713+
}
714+
updated, errSet := sjson.SetRawBytes(payload, "input", []byte(dedupedInput))
715+
if errSet != nil {
716+
return payload
717+
}
718+
return updated
719+
}
720+
721+
func dedupeInputItemsByID(rawArray string) (string, error) {
722+
rawArray = strings.TrimSpace(rawArray)
723+
if rawArray == "" {
724+
return "[]", nil
725+
}
726+
var items []json.RawMessage
727+
if errUnmarshal := json.Unmarshal([]byte(rawArray), &items); errUnmarshal != nil {
728+
return "", errUnmarshal
729+
}
730+
731+
lastIndexByID := make(map[string]int, len(items))
732+
for i, item := range items {
733+
if len(item) == 0 {
734+
continue
735+
}
736+
itemID := strings.TrimSpace(gjson.GetBytes(item, "id").String())
737+
if itemID != "" {
738+
lastIndexByID[itemID] = i
739+
}
740+
}
741+
742+
filtered := make([]json.RawMessage, 0, len(items))
743+
for i, item := range items {
744+
if len(item) == 0 {
745+
continue
746+
}
747+
itemID := strings.TrimSpace(gjson.GetBytes(item, "id").String())
748+
if itemID != "" {
749+
if lastIndexByID[itemID] != i {
750+
continue
751+
}
752+
}
753+
filtered = append(filtered, item)
754+
}
755+
756+
out, errMarshal := json.Marshal(filtered)
757+
if errMarshal != nil {
758+
return "", errMarshal
759+
}
760+
return string(out), nil
761+
}
762+
700763
func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, metadata map[string]any) bool {
701764
if len(attributes) > 0 {
702765
if raw := strings.TrimSpace(attributes["websockets"]); raw != "" {

sdk/api/handlers/openai/openai_responses_websocket_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1603,6 +1603,30 @@ func TestNormalizeResponsesWebsocketRequestDropsDuplicateFunctionCallsByCallID(t
16031603
}
16041604
}
16051605

1606+
func TestNormalizeResponsesWebsocketRequestDropsDuplicateInputItemsByID(t *testing.T) {
1607+
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1","role":"user"}]}`)
1608+
lastResponseOutput := []byte(`[
1609+
{"type":"function_call","id":"fc-1","call_id":"call-1","name":"tool"}
1610+
]`)
1611+
raw := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call","id":"fc-1","call_id":"call-2","name":"tool"},{"type":"function_call_output","id":"tool-out-1","call_id":"call-2"}]}`)
1612+
1613+
normalized, _, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, false, true)
1614+
if errMsg != nil {
1615+
t.Fatalf("unexpected error: %v", errMsg.Error)
1616+
}
1617+
1618+
items := gjson.GetBytes(normalized, "input").Array()
1619+
if len(items) != 3 {
1620+
t.Fatalf("merged input len = %d, want 3: %s", len(items), normalized)
1621+
}
1622+
if items[0].Get("id").String() != "msg-1" ||
1623+
items[1].Get("id").String() != "fc-1" ||
1624+
items[1].Get("call_id").String() != "call-2" ||
1625+
items[2].Get("id").String() != "tool-out-1" {
1626+
t.Fatalf("unexpected merged input order: %s", normalized)
1627+
}
1628+
}
1629+
16061630
func TestNormalizeResponsesWebsocketRequestTreatsCustomToolTranscriptReplacementAsReset(t *testing.T) {
16071631
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"},{"type":"custom_tool_call","id":"ctc-1","call_id":"call-1","name":"apply_patch"},{"type":"custom_tool_call_output","id":"tool-out-1","call_id":"call-1"},{"type":"message","id":"assistant-1","role":"assistant"}]}`)
16081632
lastResponseOutput := []byte(`[
@@ -1654,6 +1678,22 @@ func TestNormalizeResponsesWebsocketRequestDropsDuplicateCustomToolCallsByCallID
16541678
}
16551679
}
16561680

1681+
func TestDedupeResponsesWebsocketInputItemsByIDAfterRepair(t *testing.T) {
1682+
payload := []byte(`{"input":[{"type":"custom_tool_call","id":"ctc-1","call_id":"call-1","name":"tool"},{"type":"custom_tool_call","id":"ctc-1","call_id":"call-2","name":"tool"},{"type":"custom_tool_call_output","id":"tool-out-1","call_id":"call-2"}]}`)
1683+
1684+
deduped := dedupeResponsesWebsocketInputItemsByID(payload)
1685+
1686+
items := gjson.GetBytes(deduped, "input").Array()
1687+
if len(items) != 2 {
1688+
t.Fatalf("deduped input len = %d, want 2: %s", len(items), deduped)
1689+
}
1690+
if items[0].Get("id").String() != "ctc-1" ||
1691+
items[0].Get("call_id").String() != "call-2" ||
1692+
items[1].Get("id").String() != "tool-out-1" {
1693+
t.Fatalf("unexpected deduped input: %s", deduped)
1694+
}
1695+
}
1696+
16571697
func TestResponsesWebsocketCompactionResetsTurnStateOnCustomToolTranscriptReplacement(t *testing.T) {
16581698
gin.SetMode(gin.TestMode)
16591699

0 commit comments

Comments
 (0)