@@ -77,6 +77,12 @@ type websocketPinnedFailoverExecutor struct {
7777 payloads map [string ][][]byte
7878}
7979
80+ type websocketBootstrapFallbackExecutor struct {
81+ mu sync.Mutex
82+ authIDs []string
83+ payloads map [string ][][]byte
84+ }
85+
8086type websocketPinnedFailoverStatusError struct {
8187 status int
8288 msg string
@@ -86,6 +92,70 @@ func (e websocketPinnedFailoverStatusError) Error() string { return e.msg }
8692
8793func (e websocketPinnedFailoverStatusError ) StatusCode () int { return e .status }
8894
95+ func (e * websocketBootstrapFallbackExecutor ) Identifier () string { return "test-provider" }
96+
97+ func (e * websocketBootstrapFallbackExecutor ) Execute (context.Context , * coreauth.Auth , coreexecutor.Request , coreexecutor.Options ) (coreexecutor.Response , error ) {
98+ return coreexecutor.Response {}, errors .New ("not implemented" )
99+ }
100+
101+ func (e * websocketBootstrapFallbackExecutor ) ExecuteStream (_ context.Context , auth * coreauth.Auth , req coreexecutor.Request , _ coreexecutor.Options ) (* coreexecutor.StreamResult , error ) {
102+ authID := ""
103+ if auth != nil {
104+ authID = auth .ID
105+ }
106+
107+ e .mu .Lock ()
108+ if e .payloads == nil {
109+ e .payloads = make (map [string ][][]byte )
110+ }
111+ e .authIDs = append (e .authIDs , authID )
112+ e .payloads [authID ] = append (e .payloads [authID ], bytes .Clone (req .Payload ))
113+ e .mu .Unlock ()
114+
115+ chunks := make (chan coreexecutor.StreamChunk , 1 )
116+ if authID == "auth-ws" {
117+ chunks <- coreexecutor.StreamChunk {Err : websocketPinnedFailoverStatusError {
118+ status : http .StatusServiceUnavailable ,
119+ msg : `{"error":{"message":"websocket bootstrap failed","type":"server_error","code":"ws_failed"}}` ,
120+ }}
121+ close (chunks )
122+ return & coreexecutor.StreamResult {Chunks : chunks }, nil
123+ }
124+
125+ chunks <- coreexecutor.StreamChunk {Payload : []byte (`{"type":"response.completed","response":{"id":"resp-http","output":[{"type":"message","id":"out-http"}]}}` )}
126+ close (chunks )
127+ return & coreexecutor.StreamResult {Chunks : chunks }, nil
128+ }
129+
130+ func (e * websocketBootstrapFallbackExecutor ) Refresh (_ context.Context , auth * coreauth.Auth ) (* coreauth.Auth , error ) {
131+ return auth , nil
132+ }
133+
134+ func (e * websocketBootstrapFallbackExecutor ) CountTokens (context.Context , * coreauth.Auth , coreexecutor.Request , coreexecutor.Options ) (coreexecutor.Response , error ) {
135+ return coreexecutor.Response {}, errors .New ("not implemented" )
136+ }
137+
138+ func (e * websocketBootstrapFallbackExecutor ) HttpRequest (context.Context , * coreauth.Auth , * http.Request ) (* http.Response , error ) {
139+ return nil , errors .New ("not implemented" )
140+ }
141+
142+ func (e * websocketBootstrapFallbackExecutor ) AuthIDs () []string {
143+ e .mu .Lock ()
144+ defer e .mu .Unlock ()
145+ return append ([]string (nil ), e .authIDs ... )
146+ }
147+
148+ func (e * websocketBootstrapFallbackExecutor ) Payloads (authID string ) [][]byte {
149+ e .mu .Lock ()
150+ defer e .mu .Unlock ()
151+ src := e .payloads [authID ]
152+ out := make ([][]byte , len (src ))
153+ for i := range src {
154+ out [i ] = bytes .Clone (src [i ])
155+ }
156+ return out
157+ }
158+
89159type websocketUpstreamDisconnectExecutor struct {
90160 mu sync.Mutex
91161 subscribed chan string
@@ -1340,6 +1410,87 @@ func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) {
13401410 }
13411411}
13421412
1413+ func TestResponsesWebsocketStripsGenerateWhenWebsocketAttemptFallsBackToHTTP (t * testing.T ) {
1414+ gin .SetMode (gin .TestMode )
1415+
1416+ selector := & orderedWebsocketSelector {order : []string {"auth-ws" , "auth-http" }}
1417+ executor := & websocketBootstrapFallbackExecutor {}
1418+ manager := coreauth .NewManager (nil , selector , nil )
1419+ manager .RegisterExecutor (executor )
1420+
1421+ authWS := & coreauth.Auth {
1422+ ID : "auth-ws" ,
1423+ Provider : executor .Identifier (),
1424+ Status : coreauth .StatusActive ,
1425+ Attributes : map [string ]string {"websockets" : "true" },
1426+ }
1427+ if _ , err := manager .Register (context .Background (), authWS ); err != nil {
1428+ t .Fatalf ("Register websocket auth: %v" , err )
1429+ }
1430+ authHTTP := & coreauth.Auth {ID : "auth-http" , Provider : executor .Identifier (), Status : coreauth .StatusActive }
1431+ if _ , err := manager .Register (context .Background (), authHTTP ); err != nil {
1432+ t .Fatalf ("Register HTTP auth: %v" , err )
1433+ }
1434+
1435+ registry .GetGlobalRegistry ().RegisterClient (authWS .ID , authWS .Provider , []* registry.ModelInfo {{ID : "test-model" }})
1436+ registry .GetGlobalRegistry ().RegisterClient (authHTTP .ID , authHTTP .Provider , []* registry.ModelInfo {{ID : "test-model" }})
1437+ t .Cleanup (func () {
1438+ registry .GetGlobalRegistry ().UnregisterClient (authWS .ID )
1439+ registry .GetGlobalRegistry ().UnregisterClient (authHTTP .ID )
1440+ })
1441+
1442+ base := handlers .NewBaseAPIHandlers (& sdkconfig.SDKConfig {}, manager )
1443+ h := NewOpenAIResponsesAPIHandler (base )
1444+ router := gin .New ()
1445+ router .GET ("/v1/responses/ws" , h .ResponsesWebsocket )
1446+
1447+ server := httptest .NewServer (router )
1448+ defer server .Close ()
1449+
1450+ wsURL := "ws" + strings .TrimPrefix (server .URL , "http" ) + "/v1/responses/ws"
1451+ conn , _ , err := websocket .DefaultDialer .Dial (wsURL , nil )
1452+ if err != nil {
1453+ t .Fatalf ("dial websocket: %v" , err )
1454+ }
1455+ defer func () {
1456+ if errClose := conn .Close (); errClose != nil {
1457+ t .Fatalf ("close websocket: %v" , errClose )
1458+ }
1459+ }()
1460+
1461+ request := `{"type":"response.create","model":"test-model","generate":false,"input":[{"type":"message","id":"msg-1"}]}`
1462+ if errWrite := conn .WriteMessage (websocket .TextMessage , []byte (request )); errWrite != nil {
1463+ t .Fatalf ("write websocket message: %v" , errWrite )
1464+ }
1465+ _ , payload , errReadMessage := conn .ReadMessage ()
1466+ if errReadMessage != nil {
1467+ t .Fatalf ("read websocket message: %v" , errReadMessage )
1468+ }
1469+ if got := gjson .GetBytes (payload , "type" ).String (); got != wsEventTypeCompleted {
1470+ t .Fatalf ("payload type = %s, want %s: %s" , got , wsEventTypeCompleted , payload )
1471+ }
1472+
1473+ if got := executor .AuthIDs (); len (got ) != 2 || got [0 ] != "auth-ws" || got [1 ] != "auth-http" {
1474+ t .Fatalf ("selected auth IDs = %v, want [auth-ws auth-http]" , got )
1475+ }
1476+
1477+ wsPayloads := executor .Payloads ("auth-ws" )
1478+ if len (wsPayloads ) != 1 {
1479+ t .Fatalf ("auth-ws payload count = %d, want 1" , len (wsPayloads ))
1480+ }
1481+ if ! gjson .GetBytes (wsPayloads [0 ], "generate" ).Exists () {
1482+ t .Fatalf ("websocket attempt payload unexpectedly stripped generate: %s" , wsPayloads [0 ])
1483+ }
1484+
1485+ httpPayloads := executor .Payloads ("auth-http" )
1486+ if len (httpPayloads ) != 1 {
1487+ t .Fatalf ("auth-http payload count = %d, want 1" , len (httpPayloads ))
1488+ }
1489+ if gjson .GetBytes (httpPayloads [0 ], "generate" ).Exists () {
1490+ t .Fatalf ("generate leaked after HTTP fallback: %s" , httpPayloads [0 ])
1491+ }
1492+ }
1493+
13431494func TestWebsocketClientAddressUsesGinClientIP (t * testing.T ) {
13441495 gin .SetMode (gin .TestMode )
13451496
0 commit comments