Skip to content

Commit ed52c61

Browse files
committed
test(websocket, api): add unit tests for response ID injection and handling of pending tool calls
- Introduced test scenarios to validate `previous_response_id` injection during incremental and non-incremental requests. - Verified behavior for pending tool calls, including proper inclusion or exclusion in websocket requests. - Updated websocket handling logic to track `lastResponseID` and `pendingToolCallIDs`. - Added utility functions for pending tool call validation and cleanup.
1 parent 538e341 commit ed52c61

2 files changed

Lines changed: 373 additions & 16 deletions

File tree

sdk/api/handlers/openai/openai_responses_websocket.go

Lines changed: 127 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"fmt"
88
"io"
99
"net/http"
10+
"sort"
1011
"strconv"
1112
"strings"
1213
"time"
@@ -267,6 +268,8 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
267268

268269
var lastRequest []byte
269270
lastResponseOutput := []byte("[]")
271+
lastResponseID := ""
272+
var lastResponsePendingToolCallIDs []string
270273
pinnedAuthID := ""
271274
sessionAuthByID := func(authID string) (*coreauth.Auth, bool) {
272275
if h == nil || h.AuthManager == nil {
@@ -335,10 +338,12 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
335338
var requestJSON []byte
336339
var updatedLastRequest []byte
337340
var errMsg *interfaces.ErrorMessage
338-
requestJSON, updatedLastRequest, errMsg = normalizeResponsesWebsocketRequestWithMode(
341+
requestJSON, updatedLastRequest, errMsg = normalizeResponsesWebsocketRequestWithIncrementalState(
339342
payload,
340343
lastRequest,
341344
lastResponseOutput,
345+
lastResponseID,
346+
lastResponsePendingToolCallIDs,
342347
allowIncrementalInputWithPreviousResponseID,
343348
allowCompactionReplayBypass,
344349
)
@@ -373,6 +378,8 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
373378
}
374379
lastRequest = updatedLastRequest
375380
lastResponseOutput = []byte("[]")
381+
lastResponseID = ""
382+
lastResponsePendingToolCallIDs = nil
376383
if errWrite := writeResponsesWebsocketSyntheticPrewarm(c, conn, requestJSON, wsTimelineLog, passthroughSessionID); errWrite != nil {
377384
wsTerminateErr = errWrite
378385
return
@@ -385,6 +392,8 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
385392
updatedLastRequest = bytes.Clone(requestJSON)
386393
previousLastRequest := bytes.Clone(lastRequest)
387394
previousLastResponseOutput := bytes.Clone(lastResponseOutput)
395+
previousLastResponseID := lastResponseID
396+
previousLastResponsePendingToolCallIDs := append([]string(nil), lastResponsePendingToolCallIDs...)
388397
forcedTranscriptReplay := forceTranscriptReplayNextRequest
389398
lastRequest = updatedLastRequest
390399
if forcedTranscriptReplay {
@@ -414,7 +423,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
414423
}
415424
dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
416425

417-
completedOutput, forwardErrMsg, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, wsTimelineLog, passthroughSessionID)
426+
completedOutput, completedResponseID, completedPendingToolCallIDs, forwardErrMsg, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, wsTimelineLog, passthroughSessionID)
418427
if errForward != nil {
419428
wsTerminateErr = errForward
420429
log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward)
@@ -425,9 +434,13 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
425434
forceTranscriptReplayNextRequest = true
426435
lastRequest = previousLastRequest
427436
lastResponseOutput = previousLastResponseOutput
437+
lastResponseID = previousLastResponseID
438+
lastResponsePendingToolCallIDs = previousLastResponsePendingToolCallIDs
428439
continue
429440
}
430441
lastResponseOutput = completedOutput
442+
lastResponseID = strings.TrimSpace(completedResponseID)
443+
lastResponsePendingToolCallIDs = append([]string(nil), completedPendingToolCallIDs...)
431444
}
432445
}
433446

@@ -457,17 +470,25 @@ func normalizeResponsesWebsocketRequest(rawJSON []byte, lastRequest []byte, last
457470
}
458471

459472
func normalizeResponsesWebsocketRequestWithMode(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool, allowCompactionReplayBypass bool) ([]byte, []byte, *interfaces.ErrorMessage) {
473+
return normalizeResponsesWebsocketRequestWithLastResponseID(rawJSON, lastRequest, lastResponseOutput, "", allowIncrementalInputWithPreviousResponseID, allowCompactionReplayBypass)
474+
}
475+
476+
func normalizeResponsesWebsocketRequestWithLastResponseID(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, lastResponseID string, allowIncrementalInputWithPreviousResponseID bool, allowCompactionReplayBypass bool) ([]byte, []byte, *interfaces.ErrorMessage) {
477+
return normalizeResponsesWebsocketRequestWithIncrementalState(rawJSON, lastRequest, lastResponseOutput, lastResponseID, nil, allowIncrementalInputWithPreviousResponseID, allowCompactionReplayBypass)
478+
}
479+
480+
func normalizeResponsesWebsocketRequestWithIncrementalState(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, lastResponseID string, lastResponsePendingToolCallIDs []string, allowIncrementalInputWithPreviousResponseID bool, allowCompactionReplayBypass bool) ([]byte, []byte, *interfaces.ErrorMessage) {
460481
requestType := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String())
461482
switch requestType {
462483
case wsRequestTypeCreate:
463484
// log.Infof("responses websocket: response.create request")
464485
if len(lastRequest) == 0 {
465486
return normalizeResponseCreateRequest(rawJSON)
466487
}
467-
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID, allowCompactionReplayBypass)
488+
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, lastResponseID, lastResponsePendingToolCallIDs, allowIncrementalInputWithPreviousResponseID, allowCompactionReplayBypass)
468489
case wsRequestTypeAppend:
469490
// log.Infof("responses websocket: response.append request")
470-
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID, allowCompactionReplayBypass)
491+
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, lastResponseID, lastResponsePendingToolCallIDs, allowIncrementalInputWithPreviousResponseID, allowCompactionReplayBypass)
471492
default:
472493
return nil, lastRequest, &interfaces.ErrorMessage{
473494
StatusCode: http.StatusBadRequest,
@@ -496,7 +517,7 @@ func normalizeResponseCreateRequest(rawJSON []byte) ([]byte, []byte, *interfaces
496517
return normalized, bytes.Clone(normalized), nil
497518
}
498519

499-
func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool, allowCompactionReplayBypass bool) ([]byte, []byte, *interfaces.ErrorMessage) {
520+
func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, lastResponseID string, lastResponsePendingToolCallIDs []string, allowIncrementalInputWithPreviousResponseID bool, allowCompactionReplayBypass bool) ([]byte, []byte, *interfaces.ErrorMessage) {
500521
if len(lastRequest) == 0 {
501522
return nil, lastRequest, &interfaces.ErrorMessage{
502523
StatusCode: http.StatusBadRequest,
@@ -524,11 +545,20 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last
524545
// Websocket v2 mode uses response.create with previous_response_id + incremental input.
525546
// Do not expand it into a full input transcript; upstream expects the incremental payload.
526547
if allowIncrementalInputWithPreviousResponseID {
527-
if prev := strings.TrimSpace(gjson.GetBytes(rawJSON, "previous_response_id").String()); prev != "" {
548+
prev := strings.TrimSpace(gjson.GetBytes(rawJSON, "previous_response_id").String())
549+
if prev == "" {
550+
if !inputSatisfiesPendingToolCalls(nextInput, lastResponsePendingToolCallIDs) {
551+
normalized := normalizeResponseTranscriptReplacement(rawJSON, lastRequest)
552+
return normalized, bytes.Clone(normalized), nil
553+
}
554+
prev = strings.TrimSpace(lastResponseID)
555+
}
556+
if prev != "" {
528557
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
529558
if errDelete != nil {
530559
normalized = bytes.Clone(rawJSON)
531560
}
561+
normalized, _ = sjson.SetBytes(normalized, "previous_response_id", prev)
532562
if !gjson.GetBytes(normalized, "model").Exists() {
533563
modelName := strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
534564
if modelName != "" {
@@ -644,6 +674,35 @@ func shouldReplaceWebsocketTranscript(rawJSON []byte, nextInput gjson.Result) bo
644674
return false
645675
}
646676

677+
func inputSatisfiesPendingToolCalls(input gjson.Result, pendingCallIDs []string) bool {
678+
if len(pendingCallIDs) == 0 {
679+
return true
680+
}
681+
if !input.IsArray() {
682+
return false
683+
}
684+
outputs := make(map[string]struct{}, len(pendingCallIDs))
685+
for _, item := range input.Array() {
686+
switch strings.TrimSpace(item.Get("type").String()) {
687+
case "function_call_output", "custom_tool_call_output":
688+
callID := strings.TrimSpace(item.Get("call_id").String())
689+
if callID != "" {
690+
outputs[callID] = struct{}{}
691+
}
692+
}
693+
}
694+
for _, callID := range pendingCallIDs {
695+
callID = strings.TrimSpace(callID)
696+
if callID == "" {
697+
continue
698+
}
699+
if _, ok := outputs[callID]; !ok {
700+
return false
701+
}
702+
}
703+
return true
704+
}
705+
647706
func normalizeResponseTranscriptReplacement(rawJSON []byte, lastRequest []byte) []byte {
648707
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
649708
if errDelete != nil {
@@ -1138,9 +1197,11 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
11381197
errs <-chan *interfaces.ErrorMessage,
11391198
wsTimelineLog websocketTimelineAppender,
11401199
sessionID string,
1141-
) ([]byte, *interfaces.ErrorMessage, error) {
1200+
) ([]byte, string, []string, *interfaces.ErrorMessage, error) {
11421201
completed := false
11431202
completedOutput := []byte("[]")
1203+
completedResponseID := ""
1204+
pendingToolCallIDs := make(map[string]struct{})
11441205
downstreamSessionKey := ""
11451206
if c != nil && c.Request != nil {
11461207
downstreamSessionKey = websocketDownstreamSessionKey(c.Request)
@@ -1150,7 +1211,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
11501211
select {
11511212
case <-c.Request.Context().Done():
11521213
cancel(c.Request.Context().Err())
1153-
return completedOutput, nil, c.Request.Context().Err()
1214+
return completedOutput, completedResponseID, sortedStringSet(pendingToolCallIDs), nil, c.Request.Context().Err()
11541215
case errMsg, ok := <-errs:
11551216
if !ok {
11561217
errs = nil
@@ -1175,15 +1236,15 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
11751236
// errWrite,
11761237
// )
11771238
cancel(errMsg.Error)
1178-
return completedOutput, errMsg, errWrite
1239+
return completedOutput, completedResponseID, sortedStringSet(pendingToolCallIDs), errMsg, errWrite
11791240
}
11801241
}
11811242
if errMsg != nil {
11821243
cancel(errMsg.Error)
11831244
} else {
11841245
cancel(nil)
11851246
}
1186-
return completedOutput, errMsg, nil
1247+
return completedOutput, completedResponseID, sortedStringSet(pendingToolCallIDs), errMsg, nil
11871248
case chunk, ok := <-data:
11881249
if !ok {
11891250
if !completed {
@@ -1209,22 +1270,24 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
12091270
errWrite,
12101271
)
12111272
cancel(errMsg.Error)
1212-
return completedOutput, errMsg, errWrite
1273+
return completedOutput, completedResponseID, sortedStringSet(pendingToolCallIDs), errMsg, errWrite
12131274
}
12141275
cancel(errMsg.Error)
1215-
return completedOutput, errMsg, nil
1276+
return completedOutput, completedResponseID, sortedStringSet(pendingToolCallIDs), errMsg, nil
12161277
}
12171278
cancel(nil)
1218-
return completedOutput, nil, nil
1279+
return completedOutput, completedResponseID, sortedStringSet(pendingToolCallIDs), nil, nil
12191280
}
12201281

12211282
payloads := websocketJSONPayloadsFromChunk(chunk)
12221283
for i := range payloads {
12231284
recordResponsesWebsocketToolCallsFromPayload(downstreamSessionKey, payloads[i])
1285+
recordPendingToolCallIDsFromPayload(pendingToolCallIDs, payloads[i])
12241286
eventType := gjson.GetBytes(payloads[i], "type").String()
12251287
if eventType == wsEventTypeCompleted {
12261288
completed = true
12271289
completedOutput = responseCompletedOutputFromPayload(payloads[i])
1290+
completedResponseID = responseCompletedIDFromPayload(payloads[i])
12281291
}
12291292
markAPIResponseTimestamp(c)
12301293
// log.Infof(
@@ -1242,7 +1305,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
12421305
errWrite,
12431306
)
12441307
cancel(errWrite)
1245-
return completedOutput, nil, errWrite
1308+
return completedOutput, completedResponseID, sortedStringSet(pendingToolCallIDs), nil, errWrite
12461309
}
12471310
}
12481311
}
@@ -1275,6 +1338,56 @@ func responseCompletedOutputFromPayload(payload []byte) []byte {
12751338
return []byte("[]")
12761339
}
12771340

1341+
func responseCompletedIDFromPayload(payload []byte) string {
1342+
return strings.TrimSpace(gjson.GetBytes(payload, "response.id").String())
1343+
}
1344+
1345+
func recordPendingToolCallIDsFromPayload(pending map[string]struct{}, payload []byte) {
1346+
if pending == nil || len(payload) == 0 {
1347+
return
1348+
}
1349+
updatePendingToolCallIDsFromItem(pending, gjson.GetBytes(payload, "item"))
1350+
output := gjson.GetBytes(payload, "response.output")
1351+
if output.IsArray() {
1352+
for _, item := range output.Array() {
1353+
updatePendingToolCallIDsFromItem(pending, item)
1354+
}
1355+
}
1356+
}
1357+
1358+
func updatePendingToolCallIDsFromItem(pending map[string]struct{}, item gjson.Result) {
1359+
if pending == nil || !item.Exists() {
1360+
return
1361+
}
1362+
switch strings.TrimSpace(item.Get("type").String()) {
1363+
case "function_call", "custom_tool_call":
1364+
callID := strings.TrimSpace(item.Get("call_id").String())
1365+
if callID != "" {
1366+
pending[callID] = struct{}{}
1367+
}
1368+
case "function_call_output", "custom_tool_call_output":
1369+
callID := strings.TrimSpace(item.Get("call_id").String())
1370+
if callID != "" {
1371+
delete(pending, callID)
1372+
}
1373+
}
1374+
}
1375+
1376+
func sortedStringSet(values map[string]struct{}) []string {
1377+
if len(values) == 0 {
1378+
return nil
1379+
}
1380+
out := make([]string, 0, len(values))
1381+
for value := range values {
1382+
value = strings.TrimSpace(value)
1383+
if value != "" {
1384+
out = append(out, value)
1385+
}
1386+
}
1387+
sort.Strings(out)
1388+
return out
1389+
}
1390+
12781391
func websocketJSONPayloadsFromChunk(chunk []byte) [][]byte {
12791392
payloads := make([][]byte, 0, 2)
12801393
lines := bytes.Split(chunk, []byte("\n"))

0 commit comments

Comments
 (0)