@@ -691,7 +691,7 @@ func TestRepairResponsesWebsocketToolCallsInsertsCachedCallForOrphanOutput(t *te
691691 }
692692}
693693
694- func TestRepairResponsesWebsocketToolCallsInsertsCachedCallForPreviousResponseOutput (t * testing.T ) {
694+ func TestRepairResponsesWebsocketToolCallsKeepsPreviousResponseOutputIncremental (t * testing.T ) {
695695 outputCache := newWebsocketToolOutputCache (time .Minute , 10 )
696696 callCache := newWebsocketToolOutputCache (time .Minute , 10 )
697697 sessionKey := "session-1"
@@ -705,17 +705,39 @@ func TestRepairResponsesWebsocketToolCallsInsertsCachedCallForPreviousResponseOu
705705 t .Fatalf ("previous_response_id = %q, want resp-latest" , got )
706706 }
707707 input := gjson .GetBytes (repaired , "input" ).Array ()
708- if len (input ) != 3 {
709- t .Fatalf ("repaired input len = %d, want 3 : %s" , len (input ), repaired )
708+ if len (input ) != 2 {
709+ t .Fatalf ("repaired input len = %d, want 2 : %s" , len (input ), repaired )
710710 }
711- if input [0 ].Get ("type" ).String () != "function_call " || input [0 ].Get ("call_id" ).String () != "call-1" {
712- t .Fatalf ("missing inserted call : %s" , input [0 ].Raw )
711+ if input [0 ].Get ("type" ).String () != "function_call_output " || input [0 ].Get ("call_id" ).String () != "call-1" {
712+ t .Fatalf ("unexpected output item : %s" , input [0 ].Raw )
713713 }
714- if input [1 ].Get ("type" ).String () != "function_call_output " || input [1 ].Get ("call_id " ).String () != "call -1" {
715- t .Fatalf ("unexpected output item: %s" , input [1 ].Raw )
714+ if input [1 ].Get ("type" ).String () != "message " || input [1 ].Get ("id " ).String () != "msg -1" {
715+ t .Fatalf ("unexpected trailing item: %s" , input [1 ].Raw )
716716 }
717- if input [2 ].Get ("type" ).String () != "message" || input [2 ].Get ("id" ).String () != "msg-1" {
718- t .Fatalf ("unexpected trailing item: %s" , input [2 ].Raw )
717+ }
718+
719+ func TestRepairResponsesWebsocketToolCallsKeepsPreviousResponseCallIncremental (t * testing.T ) {
720+ outputCache := newWebsocketToolOutputCache (time .Minute , 10 )
721+ callCache := newWebsocketToolOutputCache (time .Minute , 10 )
722+ sessionKey := "session-1"
723+
724+ outputCache .record (sessionKey , "call-1" , []byte (`{"type":"function_call_output","call_id":"call-1","id":"tool-out-1","output":"ok"}` ))
725+
726+ raw := []byte (`{"previous_response_id":"resp-latest","input":[{"type":"function_call","id":"fc-1","call_id":"call-1","name":"tool"},{"type":"message","id":"msg-1"}]}` )
727+ repaired := repairResponsesWebsocketToolCallsWithCaches (outputCache , callCache , sessionKey , raw )
728+
729+ if got := gjson .GetBytes (repaired , "previous_response_id" ).String (); got != "resp-latest" {
730+ t .Fatalf ("previous_response_id = %q, want resp-latest" , got )
731+ }
732+ input := gjson .GetBytes (repaired , "input" ).Array ()
733+ if len (input ) != 2 {
734+ t .Fatalf ("repaired input len = %d, want 2: %s" , len (input ), repaired )
735+ }
736+ if input [0 ].Get ("type" ).String () != "function_call" || input [0 ].Get ("call_id" ).String () != "call-1" {
737+ t .Fatalf ("unexpected call item: %s" , input [0 ].Raw )
738+ }
739+ if input [1 ].Get ("type" ).String () != "message" || input [1 ].Get ("id" ).String () != "msg-1" {
740+ t .Fatalf ("unexpected trailing item: %s" , input [1 ].Raw )
719741 }
720742}
721743
@@ -805,6 +827,31 @@ func TestRepairResponsesWebsocketToolCallsInsertsCachedCustomToolCallForOrphanOu
805827 }
806828}
807829
830+ func TestRepairResponsesWebsocketToolCallsKeepsPreviousResponseCustomToolOutputIncremental (t * testing.T ) {
831+ outputCache := newWebsocketToolOutputCache (time .Minute , 10 )
832+ callCache := newWebsocketToolOutputCache (time .Minute , 10 )
833+ sessionKey := "session-1"
834+
835+ callCache .record (sessionKey , "call-1" , []byte (`{"type":"custom_tool_call","call_id":"call-1","name":"apply_patch"}` ))
836+
837+ raw := []byte (`{"previous_response_id":"resp-latest","input":[{"type":"custom_tool_call_output","call_id":"call-1","output":"ok"},{"type":"message","id":"msg-1"}]}` )
838+ repaired := repairResponsesWebsocketToolCallsWithCaches (outputCache , callCache , sessionKey , raw )
839+
840+ if got := gjson .GetBytes (repaired , "previous_response_id" ).String (); got != "resp-latest" {
841+ t .Fatalf ("previous_response_id = %q, want resp-latest" , got )
842+ }
843+ input := gjson .GetBytes (repaired , "input" ).Array ()
844+ if len (input ) != 2 {
845+ t .Fatalf ("repaired input len = %d, want 2: %s" , len (input ), repaired )
846+ }
847+ if input [0 ].Get ("type" ).String () != "custom_tool_call_output" || input [0 ].Get ("call_id" ).String () != "call-1" {
848+ t .Fatalf ("unexpected output item: %s" , input [0 ].Raw )
849+ }
850+ if input [1 ].Get ("type" ).String () != "message" || input [1 ].Get ("id" ).String () != "msg-1" {
851+ t .Fatalf ("unexpected trailing item: %s" , input [1 ].Raw )
852+ }
853+ }
854+
808855func TestRepairResponsesWebsocketToolCallsDropsOrphanCustomToolOutputWhenCallMissing (t * testing.T ) {
809856 outputCache := newWebsocketToolOutputCache (time .Minute , 10 )
810857 callCache := newWebsocketToolOutputCache (time .Minute , 10 )
0 commit comments