@@ -69,6 +69,22 @@ type websocketAuthCaptureExecutor struct {
6969 authIDs []string
7070}
7171
72+ type websocketPinnedFailoverExecutor struct {
73+ mu sync.Mutex
74+ authIDs []string
75+ calls map [string ]int
76+ payloads map [string ][][]byte
77+ }
78+
79+ type websocketPinnedFailoverStatusError struct {
80+ status int
81+ msg string
82+ }
83+
84+ func (e websocketPinnedFailoverStatusError ) Error () string { return e .msg }
85+
86+ func (e websocketPinnedFailoverStatusError ) StatusCode () int { return e .status }
87+
7288func (e * websocketAuthCaptureExecutor ) Identifier () string { return "test-provider" }
7389
7490func (e * websocketAuthCaptureExecutor ) Execute (context.Context , * coreauth.Auth , coreexecutor.Request , coreexecutor.Options ) (coreexecutor.Response , error ) {
@@ -106,6 +122,76 @@ func (e *websocketAuthCaptureExecutor) AuthIDs() []string {
106122 return append ([]string (nil ), e .authIDs ... )
107123}
108124
125+ func (e * websocketPinnedFailoverExecutor ) Identifier () string { return "test-provider" }
126+
127+ func (e * websocketPinnedFailoverExecutor ) Execute (context.Context , * coreauth.Auth , coreexecutor.Request , coreexecutor.Options ) (coreexecutor.Response , error ) {
128+ return coreexecutor.Response {}, errors .New ("not implemented" )
129+ }
130+
131+ func (e * websocketPinnedFailoverExecutor ) ExecuteStream (_ context.Context , auth * coreauth.Auth , req coreexecutor.Request , _ coreexecutor.Options ) (* coreexecutor.StreamResult , error ) {
132+ authID := ""
133+ if auth != nil {
134+ authID = auth .ID
135+ }
136+
137+ e .mu .Lock ()
138+ if e .calls == nil {
139+ e .calls = make (map [string ]int )
140+ }
141+ if e .payloads == nil {
142+ e .payloads = make (map [string ][][]byte )
143+ }
144+ e .authIDs = append (e .authIDs , authID )
145+ e .calls [authID ]++
146+ call := e .calls [authID ]
147+ e .payloads [authID ] = append (e .payloads [authID ], bytes .Clone (req .Payload ))
148+ e .mu .Unlock ()
149+
150+ if authID == "auth-a" && call == 2 {
151+ chunks := make (chan coreexecutor.StreamChunk , 1 )
152+ chunks <- coreexecutor.StreamChunk {Err : websocketPinnedFailoverStatusError {
153+ status : http .StatusTooManyRequests ,
154+ msg : `{"error":{"message":"quota exhausted","type":"rate_limit_error","code":"rate_limit_exceeded"}}` ,
155+ }}
156+ close (chunks )
157+ return & coreexecutor.StreamResult {Chunks : chunks }, nil
158+ }
159+
160+ chunks := make (chan coreexecutor.StreamChunk , 1 )
161+ chunks <- coreexecutor.StreamChunk {Payload : []byte (fmt .Sprintf (`{"type":"response.completed","response":{"id":"resp-%s-%d","output":[{"type":"message","id":"out-%s-%d"}]}}` , authID , call , authID , call ))}
162+ close (chunks )
163+ return & coreexecutor.StreamResult {Chunks : chunks }, nil
164+ }
165+
166+ func (e * websocketPinnedFailoverExecutor ) Refresh (_ context.Context , auth * coreauth.Auth ) (* coreauth.Auth , error ) {
167+ return auth , nil
168+ }
169+
170+ func (e * websocketPinnedFailoverExecutor ) CountTokens (context.Context , * coreauth.Auth , coreexecutor.Request , coreexecutor.Options ) (coreexecutor.Response , error ) {
171+ return coreexecutor.Response {}, errors .New ("not implemented" )
172+ }
173+
174+ func (e * websocketPinnedFailoverExecutor ) HttpRequest (context.Context , * coreauth.Auth , * http.Request ) (* http.Response , error ) {
175+ return nil , errors .New ("not implemented" )
176+ }
177+
178+ func (e * websocketPinnedFailoverExecutor ) AuthIDs () []string {
179+ e .mu .Lock ()
180+ defer e .mu .Unlock ()
181+ return append ([]string (nil ), e .authIDs ... )
182+ }
183+
184+ func (e * websocketPinnedFailoverExecutor ) Payloads (authID string ) [][]byte {
185+ e .mu .Lock ()
186+ defer e .mu .Unlock ()
187+ src := e .payloads [authID ]
188+ out := make ([][]byte , len (src ))
189+ for i := range src {
190+ out [i ] = bytes .Clone (src [i ])
191+ }
192+ return out
193+ }
194+
109195func (e * websocketCaptureExecutor ) Identifier () string { return "test-provider" }
110196
111197func (e * websocketCaptureExecutor ) Execute (context.Context , * coreauth.Auth , coreexecutor.Request , coreexecutor.Options ) (coreexecutor.Response , error ) {
@@ -681,7 +767,7 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
681767 close (errCh )
682768
683769 var timelineLog strings.Builder
684- completedOutput , err := (* OpenAIResponsesAPIHandler )(nil ).forwardResponsesWebsocket (
770+ completedOutput , errMsg , err := (* OpenAIResponsesAPIHandler )(nil ).forwardResponsesWebsocket (
685771 ctx ,
686772 conn ,
687773 func (... interface {}) {},
@@ -694,6 +780,10 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
694780 serverErrCh <- err
695781 return
696782 }
783+ if errMsg != nil {
784+ serverErrCh <- fmt .Errorf ("unexpected websocket error message: %v" , errMsg .Error )
785+ return
786+ }
697787 if gjson .GetBytes (completedOutput , "0.id" ).String () != "out-1" {
698788 serverErrCh <- errors .New ("completed output not captured" )
699789 return
@@ -760,7 +850,7 @@ func TestForwardResponsesWebsocketLogsAttemptedResponseOnWriteFailure(t *testing
760850 return
761851 }
762852
763- _ , err = (* OpenAIResponsesAPIHandler )(nil ).forwardResponsesWebsocket (
853+ _ , _ , err = (* OpenAIResponsesAPIHandler )(nil ).forwardResponsesWebsocket (
764854 ctx ,
765855 conn ,
766856 func (... interface {}) {},
@@ -1113,6 +1203,99 @@ func TestResponsesWebsocketPinsOnlyWebsocketCapableAuth(t *testing.T) {
11131203 }
11141204}
11151205
1206+ func TestResponsesWebsocketReleasesPinnedAuthAfterQuotaError (t * testing.T ) {
1207+ gin .SetMode (gin .TestMode )
1208+
1209+ selector := & orderedWebsocketSelector {order : []string {"auth-a" , "auth-b" }}
1210+ executor := & websocketPinnedFailoverExecutor {}
1211+ manager := coreauth .NewManager (nil , selector , nil )
1212+ manager .RegisterExecutor (executor )
1213+
1214+ authA := & coreauth.Auth {
1215+ ID : "auth-a" ,
1216+ Provider : executor .Identifier (),
1217+ Status : coreauth .StatusActive ,
1218+ Attributes : map [string ]string {"websockets" : "true" },
1219+ }
1220+ if _ , err := manager .Register (context .Background (), authA ); err != nil {
1221+ t .Fatalf ("Register auth A: %v" , err )
1222+ }
1223+ authB := & coreauth.Auth {
1224+ ID : "auth-b" ,
1225+ Provider : executor .Identifier (),
1226+ Status : coreauth .StatusActive ,
1227+ Attributes : map [string ]string {"websockets" : "true" },
1228+ }
1229+ if _ , err := manager .Register (context .Background (), authB ); err != nil {
1230+ t .Fatalf ("Register auth B: %v" , err )
1231+ }
1232+
1233+ registry .GetGlobalRegistry ().RegisterClient (authA .ID , authA .Provider , []* registry.ModelInfo {{ID : "quota-model" }})
1234+ registry .GetGlobalRegistry ().RegisterClient (authB .ID , authB .Provider , []* registry.ModelInfo {{ID : "quota-model" }})
1235+ t .Cleanup (func () {
1236+ registry .GetGlobalRegistry ().UnregisterClient (authA .ID )
1237+ registry .GetGlobalRegistry ().UnregisterClient (authB .ID )
1238+ })
1239+
1240+ base := handlers .NewBaseAPIHandlers (& sdkconfig.SDKConfig {}, manager )
1241+ h := NewOpenAIResponsesAPIHandler (base )
1242+ router := gin .New ()
1243+ router .GET ("/v1/responses/ws" , h .ResponsesWebsocket )
1244+
1245+ server := httptest .NewServer (router )
1246+ defer server .Close ()
1247+
1248+ wsURL := "ws" + strings .TrimPrefix (server .URL , "http" ) + "/v1/responses/ws"
1249+ conn , _ , err := websocket .DefaultDialer .Dial (wsURL , nil )
1250+ if err != nil {
1251+ t .Fatalf ("dial websocket: %v" , err )
1252+ }
1253+ defer func () {
1254+ if errClose := conn .Close (); errClose != nil {
1255+ t .Fatalf ("close websocket: %v" , errClose )
1256+ }
1257+ }()
1258+
1259+ requests := []string {
1260+ `{"type":"response.create","model":"quota-model","input":[{"type":"message","id":"msg-1"}]}` ,
1261+ `{"type":"response.create","previous_response_id":"resp-auth-a-1","input":[{"type":"message","id":"msg-2"}]}` ,
1262+ `{"type":"response.create","previous_response_id":"resp-auth-a-1","input":[{"type":"message","id":"msg-3"}]}` ,
1263+ }
1264+ wantTypes := []string {wsEventTypeCompleted , wsEventTypeError , wsEventTypeCompleted }
1265+ for i := range requests {
1266+ if errWrite := conn .WriteMessage (websocket .TextMessage , []byte (requests [i ])); errWrite != nil {
1267+ t .Fatalf ("write websocket message %d: %v" , i + 1 , errWrite )
1268+ }
1269+ _ , payload , errReadMessage := conn .ReadMessage ()
1270+ if errReadMessage != nil {
1271+ t .Fatalf ("read websocket message %d: %v" , i + 1 , errReadMessage )
1272+ }
1273+ if got := gjson .GetBytes (payload , "type" ).String (); got != wantTypes [i ] {
1274+ t .Fatalf ("message %d payload type = %s, want %s: %s" , i + 1 , got , wantTypes [i ], payload )
1275+ }
1276+ if i == 1 && int (gjson .GetBytes (payload , "status" ).Int ()) != http .StatusTooManyRequests {
1277+ t .Fatalf ("quota payload status = %d, want %d: %s" , gjson .GetBytes (payload , "status" ).Int (), http .StatusTooManyRequests , payload )
1278+ }
1279+ }
1280+
1281+ if got := executor .AuthIDs (); len (got ) != 3 || got [0 ] != "auth-a" || got [1 ] != "auth-a" || got [2 ] != "auth-b" {
1282+ t .Fatalf ("selected auth IDs = %v, want [auth-a auth-a auth-b]" , got )
1283+ }
1284+
1285+ authBPayloads := executor .Payloads ("auth-b" )
1286+ if len (authBPayloads ) != 1 {
1287+ t .Fatalf ("auth-b payload count = %d, want 1" , len (authBPayloads ))
1288+ }
1289+ authBPayload := authBPayloads [0 ]
1290+ if gjson .GetBytes (authBPayload , "previous_response_id" ).Exists () {
1291+ t .Fatalf ("previous_response_id leaked after auth failover: %s" , authBPayload )
1292+ }
1293+ authBInput := gjson .GetBytes (authBPayload , "input" ).Raw
1294+ if ! strings .Contains (authBInput , `"id":"msg-1"` ) || ! strings .Contains (authBInput , `"id":"msg-3"` ) {
1295+ t .Fatalf ("auth-b replay input missing expected transcript items: %s" , authBInput )
1296+ }
1297+ }
1298+
11161299func TestNormalizeResponsesWebsocketRequestTreatsTranscriptReplacementAsReset (t * testing.T ) {
11171300 lastRequest := []byte (`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"},{"type":"function_call","id":"fc-1","call_id":"call-1"},{"type":"function_call_output","id":"tool-out-1","call_id":"call-1"},{"type":"message","id":"assistant-1","role":"assistant"}]}` )
11181301 lastResponseOutput := []byte (`[
0 commit comments