Skip to content

Commit 430e679

Browse files
committed
fix(auth): strip "generate" from payload during WebSocket HTTP fallback
- Added `sanitizeDownstreamWebsocketFallbackRequest` to clean `generate` from payload for HTTP fallback requests. - Implemented tests to validate payload handling logic in WebSocket-to-HTTP transitions. Closes: router-for-me#3556
1 parent 55901f0 commit 430e679

2 files changed

Lines changed: 166 additions & 1 deletion

File tree

sdk/api/handlers/openai/openai_responses_websocket_test.go

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,12 @@ type websocketPinnedFailoverExecutor struct {
7777
payloads map[string][][]byte
7878
}
7979

80+
type websocketBootstrapFallbackExecutor struct {
81+
mu sync.Mutex
82+
authIDs []string
83+
payloads map[string][][]byte
84+
}
85+
8086
type websocketPinnedFailoverStatusError struct {
8187
status int
8288
msg string
@@ -86,6 +92,70 @@ func (e websocketPinnedFailoverStatusError) Error() string { return e.msg }
8692

8793
func (e websocketPinnedFailoverStatusError) StatusCode() int { return e.status }
8894

95+
func (e *websocketBootstrapFallbackExecutor) Identifier() string { return "test-provider" }
96+
97+
func (e *websocketBootstrapFallbackExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
98+
return coreexecutor.Response{}, errors.New("not implemented")
99+
}
100+
101+
func (e *websocketBootstrapFallbackExecutor) ExecuteStream(_ context.Context, auth *coreauth.Auth, req coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) {
102+
authID := ""
103+
if auth != nil {
104+
authID = auth.ID
105+
}
106+
107+
e.mu.Lock()
108+
if e.payloads == nil {
109+
e.payloads = make(map[string][][]byte)
110+
}
111+
e.authIDs = append(e.authIDs, authID)
112+
e.payloads[authID] = append(e.payloads[authID], bytes.Clone(req.Payload))
113+
e.mu.Unlock()
114+
115+
chunks := make(chan coreexecutor.StreamChunk, 1)
116+
if authID == "auth-ws" {
117+
chunks <- coreexecutor.StreamChunk{Err: websocketPinnedFailoverStatusError{
118+
status: http.StatusServiceUnavailable,
119+
msg: `{"error":{"message":"websocket bootstrap failed","type":"server_error","code":"ws_failed"}}`,
120+
}}
121+
close(chunks)
122+
return &coreexecutor.StreamResult{Chunks: chunks}, nil
123+
}
124+
125+
chunks <- coreexecutor.StreamChunk{Payload: []byte(`{"type":"response.completed","response":{"id":"resp-http","output":[{"type":"message","id":"out-http"}]}}`)}
126+
close(chunks)
127+
return &coreexecutor.StreamResult{Chunks: chunks}, nil
128+
}
129+
130+
func (e *websocketBootstrapFallbackExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
131+
return auth, nil
132+
}
133+
134+
func (e *websocketBootstrapFallbackExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
135+
return coreexecutor.Response{}, errors.New("not implemented")
136+
}
137+
138+
func (e *websocketBootstrapFallbackExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) {
139+
return nil, errors.New("not implemented")
140+
}
141+
142+
func (e *websocketBootstrapFallbackExecutor) AuthIDs() []string {
143+
e.mu.Lock()
144+
defer e.mu.Unlock()
145+
return append([]string(nil), e.authIDs...)
146+
}
147+
148+
func (e *websocketBootstrapFallbackExecutor) Payloads(authID string) [][]byte {
149+
e.mu.Lock()
150+
defer e.mu.Unlock()
151+
src := e.payloads[authID]
152+
out := make([][]byte, len(src))
153+
for i := range src {
154+
out[i] = bytes.Clone(src[i])
155+
}
156+
return out
157+
}
158+
89159
type websocketUpstreamDisconnectExecutor struct {
90160
mu sync.Mutex
91161
subscribed chan string
@@ -1340,6 +1410,87 @@ func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) {
13401410
}
13411411
}
13421412

1413+
func TestResponsesWebsocketStripsGenerateWhenWebsocketAttemptFallsBackToHTTP(t *testing.T) {
1414+
gin.SetMode(gin.TestMode)
1415+
1416+
selector := &orderedWebsocketSelector{order: []string{"auth-ws", "auth-http"}}
1417+
executor := &websocketBootstrapFallbackExecutor{}
1418+
manager := coreauth.NewManager(nil, selector, nil)
1419+
manager.RegisterExecutor(executor)
1420+
1421+
authWS := &coreauth.Auth{
1422+
ID: "auth-ws",
1423+
Provider: executor.Identifier(),
1424+
Status: coreauth.StatusActive,
1425+
Attributes: map[string]string{"websockets": "true"},
1426+
}
1427+
if _, err := manager.Register(context.Background(), authWS); err != nil {
1428+
t.Fatalf("Register websocket auth: %v", err)
1429+
}
1430+
authHTTP := &coreauth.Auth{ID: "auth-http", Provider: executor.Identifier(), Status: coreauth.StatusActive}
1431+
if _, err := manager.Register(context.Background(), authHTTP); err != nil {
1432+
t.Fatalf("Register HTTP auth: %v", err)
1433+
}
1434+
1435+
registry.GetGlobalRegistry().RegisterClient(authWS.ID, authWS.Provider, []*registry.ModelInfo{{ID: "test-model"}})
1436+
registry.GetGlobalRegistry().RegisterClient(authHTTP.ID, authHTTP.Provider, []*registry.ModelInfo{{ID: "test-model"}})
1437+
t.Cleanup(func() {
1438+
registry.GetGlobalRegistry().UnregisterClient(authWS.ID)
1439+
registry.GetGlobalRegistry().UnregisterClient(authHTTP.ID)
1440+
})
1441+
1442+
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
1443+
h := NewOpenAIResponsesAPIHandler(base)
1444+
router := gin.New()
1445+
router.GET("/v1/responses/ws", h.ResponsesWebsocket)
1446+
1447+
server := httptest.NewServer(router)
1448+
defer server.Close()
1449+
1450+
wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
1451+
conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
1452+
if err != nil {
1453+
t.Fatalf("dial websocket: %v", err)
1454+
}
1455+
defer func() {
1456+
if errClose := conn.Close(); errClose != nil {
1457+
t.Fatalf("close websocket: %v", errClose)
1458+
}
1459+
}()
1460+
1461+
request := `{"type":"response.create","model":"test-model","generate":false,"input":[{"type":"message","id":"msg-1"}]}`
1462+
if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(request)); errWrite != nil {
1463+
t.Fatalf("write websocket message: %v", errWrite)
1464+
}
1465+
_, payload, errReadMessage := conn.ReadMessage()
1466+
if errReadMessage != nil {
1467+
t.Fatalf("read websocket message: %v", errReadMessage)
1468+
}
1469+
if got := gjson.GetBytes(payload, "type").String(); got != wsEventTypeCompleted {
1470+
t.Fatalf("payload type = %s, want %s: %s", got, wsEventTypeCompleted, payload)
1471+
}
1472+
1473+
if got := executor.AuthIDs(); len(got) != 2 || got[0] != "auth-ws" || got[1] != "auth-http" {
1474+
t.Fatalf("selected auth IDs = %v, want [auth-ws auth-http]", got)
1475+
}
1476+
1477+
wsPayloads := executor.Payloads("auth-ws")
1478+
if len(wsPayloads) != 1 {
1479+
t.Fatalf("auth-ws payload count = %d, want 1", len(wsPayloads))
1480+
}
1481+
if !gjson.GetBytes(wsPayloads[0], "generate").Exists() {
1482+
t.Fatalf("websocket attempt payload unexpectedly stripped generate: %s", wsPayloads[0])
1483+
}
1484+
1485+
httpPayloads := executor.Payloads("auth-http")
1486+
if len(httpPayloads) != 1 {
1487+
t.Fatalf("auth-http payload count = %d, want 1", len(httpPayloads))
1488+
}
1489+
if gjson.GetBytes(httpPayloads[0], "generate").Exists() {
1490+
t.Fatalf("generate leaked after HTTP fallback: %s", httpPayloads[0])
1491+
}
1492+
}
1493+
13431494
func TestWebsocketClientAddressUsesGinClientIP(t *testing.T) {
13441495
gin.SetMode(gin.TestMode)
13451496

sdk/cliproxy/auth/conductor.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/executor"
2626
coreusage "github.com/router-for-me/CLIProxyAPI/v7/sdk/cliproxy/usage"
2727
log "github.com/sirupsen/logrus"
28+
"github.com/tidwall/sjson"
2829
)
2930

3031
// ProviderExecutor defines the contract required by Manager to execute provider calls.
@@ -1581,7 +1582,8 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
15811582
lastErr = errPrepare
15821583
continue
15831584
}
1584-
streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel, models, pooled)
1585+
execReq := sanitizeDownstreamWebsocketFallbackRequest(execCtx, auth, req)
1586+
streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, execReq, opts, routeModel, models, pooled)
15851587
if errStream != nil {
15861588
if errCtx := execCtx.Err(); errCtx != nil {
15871589
return nil, errCtx
@@ -1599,6 +1601,18 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
15991601
}
16001602
}
16011603

1604+
func sanitizeDownstreamWebsocketFallbackRequest(ctx context.Context, auth *Auth, req cliproxyexecutor.Request) cliproxyexecutor.Request {
1605+
if !cliproxyexecutor.DownstreamWebsocket(ctx) || authWebsocketsEnabled(auth) || len(req.Payload) == 0 {
1606+
return req
1607+
}
1608+
updated, errDelete := sjson.DeleteBytes(req.Payload, "generate")
1609+
if errDelete != nil {
1610+
return req
1611+
}
1612+
req.Payload = updated
1613+
return req
1614+
}
1615+
16021616
func ensureRequestedModelMetadata(opts cliproxyexecutor.Options, requestedModel string) cliproxyexecutor.Options {
16031617
requestedModel = strings.TrimSpace(requestedModel)
16041618
if requestedModel == "" {

0 commit comments

Comments
 (0)