77 "fmt"
88 "io"
99 "net/http"
10+ "sort"
1011 "strconv"
1112 "strings"
1213 "time"
@@ -267,6 +268,8 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
267268
268269 var lastRequest []byte
269270 lastResponseOutput := []byte ("[]" )
271+ lastResponseID := ""
272+ var lastResponsePendingToolCallIDs []string
270273 pinnedAuthID := ""
271274 sessionAuthByID := func (authID string ) (* coreauth.Auth , bool ) {
272275 if h == nil || h .AuthManager == nil {
@@ -335,10 +338,12 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
335338 var requestJSON []byte
336339 var updatedLastRequest []byte
337340 var errMsg * interfaces.ErrorMessage
338- requestJSON , updatedLastRequest , errMsg = normalizeResponsesWebsocketRequestWithMode (
341+ requestJSON , updatedLastRequest , errMsg = normalizeResponsesWebsocketRequestWithIncrementalState (
339342 payload ,
340343 lastRequest ,
341344 lastResponseOutput ,
345+ lastResponseID ,
346+ lastResponsePendingToolCallIDs ,
342347 allowIncrementalInputWithPreviousResponseID ,
343348 allowCompactionReplayBypass ,
344349 )
@@ -373,6 +378,8 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
373378 }
374379 lastRequest = updatedLastRequest
375380 lastResponseOutput = []byte ("[]" )
381+ lastResponseID = ""
382+ lastResponsePendingToolCallIDs = nil
376383 if errWrite := writeResponsesWebsocketSyntheticPrewarm (c , conn , requestJSON , wsTimelineLog , passthroughSessionID ); errWrite != nil {
377384 wsTerminateErr = errWrite
378385 return
@@ -385,6 +392,8 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
385392 updatedLastRequest = bytes .Clone (requestJSON )
386393 previousLastRequest := bytes .Clone (lastRequest )
387394 previousLastResponseOutput := bytes .Clone (lastResponseOutput )
395+ previousLastResponseID := lastResponseID
396+ previousLastResponsePendingToolCallIDs := append ([]string (nil ), lastResponsePendingToolCallIDs ... )
388397 forcedTranscriptReplay := forceTranscriptReplayNextRequest
389398 lastRequest = updatedLastRequest
390399 if forcedTranscriptReplay {
@@ -414,7 +423,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
414423 }
415424 dataChan , _ , errChan := h .ExecuteStreamWithAuthManager (cliCtx , h .HandlerType (), modelName , requestJSON , "" )
416425
417- completedOutput , forwardErrMsg , errForward := h .forwardResponsesWebsocket (c , conn , cliCancel , dataChan , errChan , wsTimelineLog , passthroughSessionID )
426+ completedOutput , completedResponseID , completedPendingToolCallIDs , forwardErrMsg , errForward := h .forwardResponsesWebsocket (c , conn , cliCancel , dataChan , errChan , wsTimelineLog , passthroughSessionID )
418427 if errForward != nil {
419428 wsTerminateErr = errForward
420429 log .Warnf ("responses websocket: forward failed id=%s error=%v" , passthroughSessionID , errForward )
@@ -425,9 +434,13 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
425434 forceTranscriptReplayNextRequest = true
426435 lastRequest = previousLastRequest
427436 lastResponseOutput = previousLastResponseOutput
437+ lastResponseID = previousLastResponseID
438+ lastResponsePendingToolCallIDs = previousLastResponsePendingToolCallIDs
428439 continue
429440 }
430441 lastResponseOutput = completedOutput
442+ lastResponseID = strings .TrimSpace (completedResponseID )
443+ lastResponsePendingToolCallIDs = append ([]string (nil ), completedPendingToolCallIDs ... )
431444 }
432445}
433446
@@ -457,17 +470,25 @@ func normalizeResponsesWebsocketRequest(rawJSON []byte, lastRequest []byte, last
457470}
458471
459472func normalizeResponsesWebsocketRequestWithMode (rawJSON []byte , lastRequest []byte , lastResponseOutput []byte , allowIncrementalInputWithPreviousResponseID bool , allowCompactionReplayBypass bool ) ([]byte , []byte , * interfaces.ErrorMessage ) {
473+ return normalizeResponsesWebsocketRequestWithLastResponseID (rawJSON , lastRequest , lastResponseOutput , "" , allowIncrementalInputWithPreviousResponseID , allowCompactionReplayBypass )
474+ }
475+
476+ func normalizeResponsesWebsocketRequestWithLastResponseID (rawJSON []byte , lastRequest []byte , lastResponseOutput []byte , lastResponseID string , allowIncrementalInputWithPreviousResponseID bool , allowCompactionReplayBypass bool ) ([]byte , []byte , * interfaces.ErrorMessage ) {
477+ return normalizeResponsesWebsocketRequestWithIncrementalState (rawJSON , lastRequest , lastResponseOutput , lastResponseID , nil , allowIncrementalInputWithPreviousResponseID , allowCompactionReplayBypass )
478+ }
479+
480+ func normalizeResponsesWebsocketRequestWithIncrementalState (rawJSON []byte , lastRequest []byte , lastResponseOutput []byte , lastResponseID string , lastResponsePendingToolCallIDs []string , allowIncrementalInputWithPreviousResponseID bool , allowCompactionReplayBypass bool ) ([]byte , []byte , * interfaces.ErrorMessage ) {
460481 requestType := strings .TrimSpace (gjson .GetBytes (rawJSON , "type" ).String ())
461482 switch requestType {
462483 case wsRequestTypeCreate :
463484 // log.Infof("responses websocket: response.create request")
464485 if len (lastRequest ) == 0 {
465486 return normalizeResponseCreateRequest (rawJSON )
466487 }
467- return normalizeResponseSubsequentRequest (rawJSON , lastRequest , lastResponseOutput , allowIncrementalInputWithPreviousResponseID , allowCompactionReplayBypass )
488+ return normalizeResponseSubsequentRequest (rawJSON , lastRequest , lastResponseOutput , lastResponseID , lastResponsePendingToolCallIDs , allowIncrementalInputWithPreviousResponseID , allowCompactionReplayBypass )
468489 case wsRequestTypeAppend :
469490 // log.Infof("responses websocket: response.append request")
470- return normalizeResponseSubsequentRequest (rawJSON , lastRequest , lastResponseOutput , allowIncrementalInputWithPreviousResponseID , allowCompactionReplayBypass )
491+ return normalizeResponseSubsequentRequest (rawJSON , lastRequest , lastResponseOutput , lastResponseID , lastResponsePendingToolCallIDs , allowIncrementalInputWithPreviousResponseID , allowCompactionReplayBypass )
471492 default :
472493 return nil , lastRequest , & interfaces.ErrorMessage {
473494 StatusCode : http .StatusBadRequest ,
@@ -496,7 +517,7 @@ func normalizeResponseCreateRequest(rawJSON []byte) ([]byte, []byte, *interfaces
496517 return normalized , bytes .Clone (normalized ), nil
497518}
498519
499- func normalizeResponseSubsequentRequest (rawJSON []byte , lastRequest []byte , lastResponseOutput []byte , allowIncrementalInputWithPreviousResponseID bool , allowCompactionReplayBypass bool ) ([]byte , []byte , * interfaces.ErrorMessage ) {
520+ func normalizeResponseSubsequentRequest (rawJSON []byte , lastRequest []byte , lastResponseOutput []byte , lastResponseID string , lastResponsePendingToolCallIDs [] string , allowIncrementalInputWithPreviousResponseID bool , allowCompactionReplayBypass bool ) ([]byte , []byte , * interfaces.ErrorMessage ) {
500521 if len (lastRequest ) == 0 {
501522 return nil , lastRequest , & interfaces.ErrorMessage {
502523 StatusCode : http .StatusBadRequest ,
@@ -524,11 +545,20 @@ func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, last
524545 // Websocket v2 mode uses response.create with previous_response_id + incremental input.
525546 // Do not expand it into a full input transcript; upstream expects the incremental payload.
526547 if allowIncrementalInputWithPreviousResponseID {
527- if prev := strings .TrimSpace (gjson .GetBytes (rawJSON , "previous_response_id" ).String ()); prev != "" {
548+ prev := strings .TrimSpace (gjson .GetBytes (rawJSON , "previous_response_id" ).String ())
549+ if prev == "" {
550+ if ! inputSatisfiesPendingToolCalls (nextInput , lastResponsePendingToolCallIDs ) {
551+ normalized := normalizeResponseTranscriptReplacement (rawJSON , lastRequest )
552+ return normalized , bytes .Clone (normalized ), nil
553+ }
554+ prev = strings .TrimSpace (lastResponseID )
555+ }
556+ if prev != "" {
528557 normalized , errDelete := sjson .DeleteBytes (rawJSON , "type" )
529558 if errDelete != nil {
530559 normalized = bytes .Clone (rawJSON )
531560 }
561+ normalized , _ = sjson .SetBytes (normalized , "previous_response_id" , prev )
532562 if ! gjson .GetBytes (normalized , "model" ).Exists () {
533563 modelName := strings .TrimSpace (gjson .GetBytes (lastRequest , "model" ).String ())
534564 if modelName != "" {
@@ -644,6 +674,35 @@ func shouldReplaceWebsocketTranscript(rawJSON []byte, nextInput gjson.Result) bo
644674 return false
645675}
646676
677+ func inputSatisfiesPendingToolCalls (input gjson.Result , pendingCallIDs []string ) bool {
678+ if len (pendingCallIDs ) == 0 {
679+ return true
680+ }
681+ if ! input .IsArray () {
682+ return false
683+ }
684+ outputs := make (map [string ]struct {}, len (pendingCallIDs ))
685+ for _ , item := range input .Array () {
686+ switch strings .TrimSpace (item .Get ("type" ).String ()) {
687+ case "function_call_output" , "custom_tool_call_output" :
688+ callID := strings .TrimSpace (item .Get ("call_id" ).String ())
689+ if callID != "" {
690+ outputs [callID ] = struct {}{}
691+ }
692+ }
693+ }
694+ for _ , callID := range pendingCallIDs {
695+ callID = strings .TrimSpace (callID )
696+ if callID == "" {
697+ continue
698+ }
699+ if _ , ok := outputs [callID ]; ! ok {
700+ return false
701+ }
702+ }
703+ return true
704+ }
705+
647706func normalizeResponseTranscriptReplacement (rawJSON []byte , lastRequest []byte ) []byte {
648707 normalized , errDelete := sjson .DeleteBytes (rawJSON , "type" )
649708 if errDelete != nil {
@@ -1138,9 +1197,11 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
11381197 errs <- chan * interfaces.ErrorMessage ,
11391198 wsTimelineLog websocketTimelineAppender ,
11401199 sessionID string ,
1141- ) ([]byte , * interfaces.ErrorMessage , error ) {
1200+ ) ([]byte , string , [] string , * interfaces.ErrorMessage , error ) {
11421201 completed := false
11431202 completedOutput := []byte ("[]" )
1203+ completedResponseID := ""
1204+ pendingToolCallIDs := make (map [string ]struct {})
11441205 downstreamSessionKey := ""
11451206 if c != nil && c .Request != nil {
11461207 downstreamSessionKey = websocketDownstreamSessionKey (c .Request )
@@ -1150,7 +1211,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
11501211 select {
11511212 case <- c .Request .Context ().Done ():
11521213 cancel (c .Request .Context ().Err ())
1153- return completedOutput , nil , c .Request .Context ().Err ()
1214+ return completedOutput , completedResponseID , sortedStringSet ( pendingToolCallIDs ), nil , c .Request .Context ().Err ()
11541215 case errMsg , ok := <- errs :
11551216 if ! ok {
11561217 errs = nil
@@ -1175,15 +1236,15 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
11751236 // errWrite,
11761237 // )
11771238 cancel (errMsg .Error )
1178- return completedOutput , errMsg , errWrite
1239+ return completedOutput , completedResponseID , sortedStringSet ( pendingToolCallIDs ), errMsg , errWrite
11791240 }
11801241 }
11811242 if errMsg != nil {
11821243 cancel (errMsg .Error )
11831244 } else {
11841245 cancel (nil )
11851246 }
1186- return completedOutput , errMsg , nil
1247+ return completedOutput , completedResponseID , sortedStringSet ( pendingToolCallIDs ), errMsg , nil
11871248 case chunk , ok := <- data :
11881249 if ! ok {
11891250 if ! completed {
@@ -1209,22 +1270,24 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
12091270 errWrite ,
12101271 )
12111272 cancel (errMsg .Error )
1212- return completedOutput , errMsg , errWrite
1273+ return completedOutput , completedResponseID , sortedStringSet ( pendingToolCallIDs ), errMsg , errWrite
12131274 }
12141275 cancel (errMsg .Error )
1215- return completedOutput , errMsg , nil
1276+ return completedOutput , completedResponseID , sortedStringSet ( pendingToolCallIDs ), errMsg , nil
12161277 }
12171278 cancel (nil )
1218- return completedOutput , nil , nil
1279+ return completedOutput , completedResponseID , sortedStringSet ( pendingToolCallIDs ), nil , nil
12191280 }
12201281
12211282 payloads := websocketJSONPayloadsFromChunk (chunk )
12221283 for i := range payloads {
12231284 recordResponsesWebsocketToolCallsFromPayload (downstreamSessionKey , payloads [i ])
1285+ recordPendingToolCallIDsFromPayload (pendingToolCallIDs , payloads [i ])
12241286 eventType := gjson .GetBytes (payloads [i ], "type" ).String ()
12251287 if eventType == wsEventTypeCompleted {
12261288 completed = true
12271289 completedOutput = responseCompletedOutputFromPayload (payloads [i ])
1290+ completedResponseID = responseCompletedIDFromPayload (payloads [i ])
12281291 }
12291292 markAPIResponseTimestamp (c )
12301293 // log.Infof(
@@ -1242,7 +1305,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
12421305 errWrite ,
12431306 )
12441307 cancel (errWrite )
1245- return completedOutput , nil , errWrite
1308+ return completedOutput , completedResponseID , sortedStringSet ( pendingToolCallIDs ), nil , errWrite
12461309 }
12471310 }
12481311 }
@@ -1275,6 +1338,56 @@ func responseCompletedOutputFromPayload(payload []byte) []byte {
12751338 return []byte ("[]" )
12761339}
12771340
1341+ func responseCompletedIDFromPayload (payload []byte ) string {
1342+ return strings .TrimSpace (gjson .GetBytes (payload , "response.id" ).String ())
1343+ }
1344+
1345+ func recordPendingToolCallIDsFromPayload (pending map [string ]struct {}, payload []byte ) {
1346+ if pending == nil || len (payload ) == 0 {
1347+ return
1348+ }
1349+ updatePendingToolCallIDsFromItem (pending , gjson .GetBytes (payload , "item" ))
1350+ output := gjson .GetBytes (payload , "response.output" )
1351+ if output .IsArray () {
1352+ for _ , item := range output .Array () {
1353+ updatePendingToolCallIDsFromItem (pending , item )
1354+ }
1355+ }
1356+ }
1357+
1358+ func updatePendingToolCallIDsFromItem (pending map [string ]struct {}, item gjson.Result ) {
1359+ if pending == nil || ! item .Exists () {
1360+ return
1361+ }
1362+ switch strings .TrimSpace (item .Get ("type" ).String ()) {
1363+ case "function_call" , "custom_tool_call" :
1364+ callID := strings .TrimSpace (item .Get ("call_id" ).String ())
1365+ if callID != "" {
1366+ pending [callID ] = struct {}{}
1367+ }
1368+ case "function_call_output" , "custom_tool_call_output" :
1369+ callID := strings .TrimSpace (item .Get ("call_id" ).String ())
1370+ if callID != "" {
1371+ delete (pending , callID )
1372+ }
1373+ }
1374+ }
1375+
// sortedStringSet returns the members of values as an ascending-sorted slice.
//
// Keys are trimmed of surrounding whitespace. Keys that are blank after
// trimming are dropped, and keys that become equal after trimming are emitted
// only once, so the result is always a proper set. The result is nil whenever
// nothing remains, matching the return for an empty or nil map.
func sortedStringSet(values map[string]struct{}) []string {
	if len(values) == 0 {
		return nil
	}
	out := make([]string, 0, len(values))
	for value := range values {
		if value = strings.TrimSpace(value); value != "" {
			out = append(out, value)
		}
	}
	if len(out) == 0 {
		return nil
	}
	sort.Strings(out)

	// Trimming can make distinct map keys collide; after sorting, duplicates
	// are adjacent, so compact them in place (writes never pass the read index).
	unique := out[:1]
	for _, value := range out[1:] {
		if value != unique[len(unique)-1] {
			unique = append(unique, value)
		}
	}
	return unique
}
1390+
12781391func websocketJSONPayloadsFromChunk (chunk []byte ) [][]byte {
12791392 payloads := make ([][]byte , 0 , 2 )
12801393 lines := bytes .Split (chunk , []byte ("\n " ))
0 commit comments