Skip to content

Commit 8e6ef3f

Browse files
committed
fix(websocket): ensure state consistency on auth errors in streaming
- Added logic to reset `pinnedAuthID` and replay transcript on unauthorized, forbidden, or throttling errors. - Enhanced error handling in `forwardResponsesWebsocket` with detailed status inspection. - Introduced `shouldReleaseResponsesWebsocketPinnedAuth` to determine auth reset conditions. - Updated state management to preserve prior request and response data during forced replay. Fixed: #2230
1 parent a1487b0 commit 8e6ef3f

2 files changed

Lines changed: 229 additions & 11 deletions

File tree

sdk/api/handlers/openai/openai_responses_websocket.go

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
7979
var lastRequest []byte
8080
lastResponseOutput := []byte("[]")
8181
pinnedAuthID := ""
82+
forceTranscriptReplayNextRequest := false
8283

8384
for {
8485
msgType, payload, errReadMessage := conn.ReadMessage()
@@ -115,6 +116,9 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
115116
}
116117
allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName)
117118
}
119+
if forceTranscriptReplayNextRequest {
120+
allowIncrementalInputWithPreviousResponseID = false
121+
}
118122

119123
allowCompactionReplayBypass := false
120124
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
@@ -179,7 +183,13 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
179183

180184
requestJSON = repairResponsesWebsocketToolCalls(downstreamSessionKey, requestJSON)
181185
updatedLastRequest = bytes.Clone(requestJSON)
186+
previousLastRequest := bytes.Clone(lastRequest)
187+
previousLastResponseOutput := bytes.Clone(lastResponseOutput)
188+
forcedTranscriptReplay := forceTranscriptReplayNextRequest
182189
lastRequest = updatedLastRequest
190+
if forcedTranscriptReplay {
191+
forceTranscriptReplayNextRequest = false
192+
}
183193

184194
modelName := gjson.GetBytes(requestJSON, "model").String()
185195
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
@@ -204,12 +214,19 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
204214
}
205215
dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
206216

207-
completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsTimelineLog, passthroughSessionID)
217+
completedOutput, forwardErrMsg, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsTimelineLog, passthroughSessionID)
208218
if errForward != nil {
209219
wsTerminateErr = errForward
210220
log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward)
211221
return
212222
}
223+
if shouldReleaseResponsesWebsocketPinnedAuth(forwardErrMsg) {
224+
pinnedAuthID = ""
225+
forceTranscriptReplayNextRequest = true
226+
lastRequest = previousLastRequest
227+
lastResponseOutput = previousLastResponseOutput
228+
continue
229+
}
213230
lastResponseOutput = completedOutput
214231
}
215232
}
@@ -810,7 +827,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
810827
errs <-chan *interfaces.ErrorMessage,
811828
wsTimelineLog *strings.Builder,
812829
sessionID string,
813-
) ([]byte, error) {
830+
) ([]byte, *interfaces.ErrorMessage, error) {
814831
completed := false
815832
completedOutput := []byte("[]")
816833
downstreamSessionKey := ""
@@ -822,7 +839,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
822839
select {
823840
case <-c.Request.Context().Done():
824841
cancel(c.Request.Context().Err())
825-
return completedOutput, c.Request.Context().Err()
842+
return completedOutput, nil, c.Request.Context().Err()
826843
case errMsg, ok := <-errs:
827844
if !ok {
828845
errs = nil
@@ -847,15 +864,15 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
847864
// errWrite,
848865
// )
849866
cancel(errMsg.Error)
850-
return completedOutput, errWrite
867+
return completedOutput, errMsg, errWrite
851868
}
852869
}
853870
if errMsg != nil {
854871
cancel(errMsg.Error)
855872
} else {
856873
cancel(nil)
857874
}
858-
return completedOutput, nil
875+
return completedOutput, errMsg, nil
859876
case chunk, ok := <-data:
860877
if !ok {
861878
if !completed {
@@ -881,13 +898,13 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
881898
errWrite,
882899
)
883900
cancel(errMsg.Error)
884-
return completedOutput, errWrite
901+
return completedOutput, errMsg, errWrite
885902
}
886903
cancel(errMsg.Error)
887-
return completedOutput, nil
904+
return completedOutput, errMsg, nil
888905
}
889906
cancel(nil)
890-
return completedOutput, nil
907+
return completedOutput, nil, nil
891908
}
892909

893910
payloads := websocketJSONPayloadsFromChunk(chunk)
@@ -914,13 +931,31 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
914931
errWrite,
915932
)
916933
cancel(errWrite)
917-
return completedOutput, errWrite
934+
return completedOutput, nil, errWrite
918935
}
919936
}
920937
}
921938
}
922939
}
923940

941+
func shouldReleaseResponsesWebsocketPinnedAuth(errMsg *interfaces.ErrorMessage) bool {
942+
if errMsg == nil {
943+
return false
944+
}
945+
status := errMsg.StatusCode
946+
if status <= 0 && errMsg.Error != nil {
947+
if se, ok := errMsg.Error.(interface{ StatusCode() int }); ok && se != nil {
948+
status = se.StatusCode()
949+
}
950+
}
951+
switch status {
952+
case http.StatusUnauthorized, http.StatusPaymentRequired, http.StatusForbidden, http.StatusTooManyRequests:
953+
return true
954+
default:
955+
return false
956+
}
957+
}
958+
924959
func responseCompletedOutputFromPayload(payload []byte) []byte {
925960
output := gjson.GetBytes(payload, "response.output")
926961
if output.Exists() && output.IsArray() {

sdk/api/handlers/openai/openai_responses_websocket_test.go

Lines changed: 185 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,22 @@ type websocketAuthCaptureExecutor struct {
6969
authIDs []string
7070
}
7171

72+
type websocketPinnedFailoverExecutor struct {
73+
mu sync.Mutex
74+
authIDs []string
75+
calls map[string]int
76+
payloads map[string][][]byte
77+
}
78+
79+
type websocketPinnedFailoverStatusError struct {
80+
status int
81+
msg string
82+
}
83+
84+
func (e websocketPinnedFailoverStatusError) Error() string { return e.msg }
85+
86+
func (e websocketPinnedFailoverStatusError) StatusCode() int { return e.status }
87+
7288
func (e *websocketAuthCaptureExecutor) Identifier() string { return "test-provider" }
7389

7490
func (e *websocketAuthCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
@@ -106,6 +122,76 @@ func (e *websocketAuthCaptureExecutor) AuthIDs() []string {
106122
return append([]string(nil), e.authIDs...)
107123
}
108124

125+
func (e *websocketPinnedFailoverExecutor) Identifier() string { return "test-provider" }
126+
127+
func (e *websocketPinnedFailoverExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
128+
return coreexecutor.Response{}, errors.New("not implemented")
129+
}
130+
131+
func (e *websocketPinnedFailoverExecutor) ExecuteStream(_ context.Context, auth *coreauth.Auth, req coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) {
132+
authID := ""
133+
if auth != nil {
134+
authID = auth.ID
135+
}
136+
137+
e.mu.Lock()
138+
if e.calls == nil {
139+
e.calls = make(map[string]int)
140+
}
141+
if e.payloads == nil {
142+
e.payloads = make(map[string][][]byte)
143+
}
144+
e.authIDs = append(e.authIDs, authID)
145+
e.calls[authID]++
146+
call := e.calls[authID]
147+
e.payloads[authID] = append(e.payloads[authID], bytes.Clone(req.Payload))
148+
e.mu.Unlock()
149+
150+
if authID == "auth-a" && call == 2 {
151+
chunks := make(chan coreexecutor.StreamChunk, 1)
152+
chunks <- coreexecutor.StreamChunk{Err: websocketPinnedFailoverStatusError{
153+
status: http.StatusTooManyRequests,
154+
msg: `{"error":{"message":"quota exhausted","type":"rate_limit_error","code":"rate_limit_exceeded"}}`,
155+
}}
156+
close(chunks)
157+
return &coreexecutor.StreamResult{Chunks: chunks}, nil
158+
}
159+
160+
chunks := make(chan coreexecutor.StreamChunk, 1)
161+
chunks <- coreexecutor.StreamChunk{Payload: []byte(fmt.Sprintf(`{"type":"response.completed","response":{"id":"resp-%s-%d","output":[{"type":"message","id":"out-%s-%d"}]}}`, authID, call, authID, call))}
162+
close(chunks)
163+
return &coreexecutor.StreamResult{Chunks: chunks}, nil
164+
}
165+
166+
func (e *websocketPinnedFailoverExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
167+
return auth, nil
168+
}
169+
170+
func (e *websocketPinnedFailoverExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
171+
return coreexecutor.Response{}, errors.New("not implemented")
172+
}
173+
174+
func (e *websocketPinnedFailoverExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) {
175+
return nil, errors.New("not implemented")
176+
}
177+
178+
func (e *websocketPinnedFailoverExecutor) AuthIDs() []string {
179+
e.mu.Lock()
180+
defer e.mu.Unlock()
181+
return append([]string(nil), e.authIDs...)
182+
}
183+
184+
func (e *websocketPinnedFailoverExecutor) Payloads(authID string) [][]byte {
185+
e.mu.Lock()
186+
defer e.mu.Unlock()
187+
src := e.payloads[authID]
188+
out := make([][]byte, len(src))
189+
for i := range src {
190+
out[i] = bytes.Clone(src[i])
191+
}
192+
return out
193+
}
194+
109195
func (e *websocketCaptureExecutor) Identifier() string { return "test-provider" }
110196

111197
func (e *websocketCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
@@ -681,7 +767,7 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
681767
close(errCh)
682768

683769
var timelineLog strings.Builder
684-
completedOutput, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket(
770+
completedOutput, errMsg, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket(
685771
ctx,
686772
conn,
687773
func(...interface{}) {},
@@ -694,6 +780,10 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
694780
serverErrCh <- err
695781
return
696782
}
783+
if errMsg != nil {
784+
serverErrCh <- fmt.Errorf("unexpected websocket error message: %v", errMsg.Error)
785+
return
786+
}
697787
if gjson.GetBytes(completedOutput, "0.id").String() != "out-1" {
698788
serverErrCh <- errors.New("completed output not captured")
699789
return
@@ -760,7 +850,7 @@ func TestForwardResponsesWebsocketLogsAttemptedResponseOnWriteFailure(t *testing
760850
return
761851
}
762852

763-
_, err = (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket(
853+
_, _, err = (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket(
764854
ctx,
765855
conn,
766856
func(...interface{}) {},
@@ -1113,6 +1203,99 @@ func TestResponsesWebsocketPinsOnlyWebsocketCapableAuth(t *testing.T) {
11131203
}
11141204
}
11151205

1206+
func TestResponsesWebsocketReleasesPinnedAuthAfterQuotaError(t *testing.T) {
1207+
gin.SetMode(gin.TestMode)
1208+
1209+
selector := &orderedWebsocketSelector{order: []string{"auth-a", "auth-b"}}
1210+
executor := &websocketPinnedFailoverExecutor{}
1211+
manager := coreauth.NewManager(nil, selector, nil)
1212+
manager.RegisterExecutor(executor)
1213+
1214+
authA := &coreauth.Auth{
1215+
ID: "auth-a",
1216+
Provider: executor.Identifier(),
1217+
Status: coreauth.StatusActive,
1218+
Attributes: map[string]string{"websockets": "true"},
1219+
}
1220+
if _, err := manager.Register(context.Background(), authA); err != nil {
1221+
t.Fatalf("Register auth A: %v", err)
1222+
}
1223+
authB := &coreauth.Auth{
1224+
ID: "auth-b",
1225+
Provider: executor.Identifier(),
1226+
Status: coreauth.StatusActive,
1227+
Attributes: map[string]string{"websockets": "true"},
1228+
}
1229+
if _, err := manager.Register(context.Background(), authB); err != nil {
1230+
t.Fatalf("Register auth B: %v", err)
1231+
}
1232+
1233+
registry.GetGlobalRegistry().RegisterClient(authA.ID, authA.Provider, []*registry.ModelInfo{{ID: "quota-model"}})
1234+
registry.GetGlobalRegistry().RegisterClient(authB.ID, authB.Provider, []*registry.ModelInfo{{ID: "quota-model"}})
1235+
t.Cleanup(func() {
1236+
registry.GetGlobalRegistry().UnregisterClient(authA.ID)
1237+
registry.GetGlobalRegistry().UnregisterClient(authB.ID)
1238+
})
1239+
1240+
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
1241+
h := NewOpenAIResponsesAPIHandler(base)
1242+
router := gin.New()
1243+
router.GET("/v1/responses/ws", h.ResponsesWebsocket)
1244+
1245+
server := httptest.NewServer(router)
1246+
defer server.Close()
1247+
1248+
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
1249+
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
1250+
if err != nil {
1251+
t.Fatalf("dial websocket: %v", err)
1252+
}
1253+
defer func() {
1254+
if errClose := conn.Close(); errClose != nil {
1255+
t.Fatalf("close websocket: %v", errClose)
1256+
}
1257+
}()
1258+
1259+
requests := []string{
1260+
`{"type":"response.create","model":"quota-model","input":[{"type":"message","id":"msg-1"}]}`,
1261+
`{"type":"response.create","previous_response_id":"resp-auth-a-1","input":[{"type":"message","id":"msg-2"}]}`,
1262+
`{"type":"response.create","previous_response_id":"resp-auth-a-1","input":[{"type":"message","id":"msg-3"}]}`,
1263+
}
1264+
wantTypes := []string{wsEventTypeCompleted, wsEventTypeError, wsEventTypeCompleted}
1265+
for i := range requests {
1266+
if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(requests[i])); errWrite != nil {
1267+
t.Fatalf("write websocket message %d: %v", i+1, errWrite)
1268+
}
1269+
_, payload, errReadMessage := conn.ReadMessage()
1270+
if errReadMessage != nil {
1271+
t.Fatalf("read websocket message %d: %v", i+1, errReadMessage)
1272+
}
1273+
if got := gjson.GetBytes(payload, "type").String(); got != wantTypes[i] {
1274+
t.Fatalf("message %d payload type = %s, want %s: %s", i+1, got, wantTypes[i], payload)
1275+
}
1276+
if i == 1 && int(gjson.GetBytes(payload, "status").Int()) != http.StatusTooManyRequests {
1277+
t.Fatalf("quota payload status = %d, want %d: %s", gjson.GetBytes(payload, "status").Int(), http.StatusTooManyRequests, payload)
1278+
}
1279+
}
1280+
1281+
if got := executor.AuthIDs(); len(got) != 3 || got[0] != "auth-a" || got[1] != "auth-a" || got[2] != "auth-b" {
1282+
t.Fatalf("selected auth IDs = %v, want [auth-a auth-a auth-b]", got)
1283+
}
1284+
1285+
authBPayloads := executor.Payloads("auth-b")
1286+
if len(authBPayloads) != 1 {
1287+
t.Fatalf("auth-b payload count = %d, want 1", len(authBPayloads))
1288+
}
1289+
authBPayload := authBPayloads[0]
1290+
if gjson.GetBytes(authBPayload, "previous_response_id").Exists() {
1291+
t.Fatalf("previous_response_id leaked after auth failover: %s", authBPayload)
1292+
}
1293+
authBInput := gjson.GetBytes(authBPayload, "input").Raw
1294+
if !strings.Contains(authBInput, `"id":"msg-1"`) || !strings.Contains(authBInput, `"id":"msg-3"`) {
1295+
t.Fatalf("auth-b replay input missing expected transcript items: %s", authBInput)
1296+
}
1297+
}
1298+
11161299
func TestNormalizeResponsesWebsocketRequestTreatsTranscriptReplacementAsReset(t *testing.T) {
11171300
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"},{"type":"function_call","id":"fc-1","call_id":"call-1"},{"type":"function_call_output","id":"tool-out-1","call_id":"call-1"},{"type":"message","id":"assistant-1","role":"assistant"}]}`)
11181301
lastResponseOutput := []byte(`[

0 commit comments

Comments
 (0)