|
1 | 1 | package wsrelay |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "bufio" |
4 | 5 | "context" |
| 6 | + "crypto/sha1" |
| 7 | + "encoding/base64" |
| 8 | + "fmt" |
5 | 9 | "io" |
| 10 | + "net" |
6 | 11 | "net/http" |
7 | 12 | "net/http/httptest" |
| 13 | + "net/url" |
8 | 14 | "strings" |
9 | 15 | "testing" |
10 | 16 | "time" |
11 | 17 |
|
| 18 | + "github.com/codex2api/auth" |
12 | 19 | "github.com/codex2api/proxy" |
13 | 20 | "github.com/gorilla/websocket" |
14 | 21 | "github.com/tidwall/gjson" |
@@ -226,6 +233,81 @@ func TestWebsocketResponseToHTTPClosesBodyOnContextCancel(t *testing.T) { |
226 | 233 | } |
227 | 234 | } |
228 | 235 |
|
| 236 | +func newClosedTestWebsocketConn(t *testing.T) *websocket.Conn { |
| 237 | + t.Helper() |
| 238 | + clientConn, serverConn := net.Pipe() |
| 239 | + handshakeDone := make(chan struct{}) |
| 240 | + go func() { |
| 241 | + defer close(handshakeDone) |
| 242 | + defer serverConn.Close() |
| 243 | + req, err := http.ReadRequest(bufio.NewReader(serverConn)) |
| 244 | + if err != nil { |
| 245 | + return |
| 246 | + } |
| 247 | + acceptHash := sha1.Sum([]byte(req.Header.Get("Sec-Websocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")) |
| 248 | + _, _ = fmt.Fprintf(serverConn, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: %s\r\n\r\n", base64.StdEncoding.EncodeToString(acceptHash[:])) |
| 249 | + }() |
| 250 | + |
| 251 | + wsURL, err := url.Parse("ws://example.test/responses") |
| 252 | + if err != nil { |
| 253 | + t.Fatalf("parse websocket URL: %v", err) |
| 254 | + } |
| 255 | + conn, _, err := websocket.NewClient(clientConn, wsURL, nil, 1024, 1024) |
| 256 | + if err != nil { |
| 257 | + t.Fatalf("create test websocket client: %v", err) |
| 258 | + } |
| 259 | + <-handshakeDone |
| 260 | + return conn |
| 261 | +} |
| 262 | + |
| 263 | +func TestExecuteRequestViaWebsocketSendFailureRemovesEffectiveProxyConnection(t *testing.T) { |
| 264 | + manager := NewManager() |
| 265 | + t.Cleanup(manager.Stop) |
| 266 | + |
| 267 | + account := &auth.Account{ |
| 268 | + DBID: 42, |
| 269 | + AccessToken: "token-123", |
| 270 | + ProxyURL: "http://account-proxy.test:8080", |
| 271 | + } |
| 272 | + sessionID := "session-1" |
| 273 | + wsURL, err := buildWebsocketURL(proxy.CodexBaseURL + CodexWsEndpoint) |
| 274 | + if err != nil { |
| 275 | + t.Fatalf("buildWebsocketURL: %v", err) |
| 276 | + } |
| 277 | + effectiveProxy := effectiveProxyURL(account, "") |
| 278 | + key := manager.poolKey(account.ID(), wsURL, sessionID, effectiveProxy) |
| 279 | + session := NewSession(account.ID(), manager) |
| 280 | + session.SetConnected(true) |
| 281 | + conn := &WsConnection{ |
| 282 | + conn: newClosedTestWebsocketConn(t), |
| 283 | + session: session, |
| 284 | + URL: wsURL, |
| 285 | + PoolKey: key, |
| 286 | + } |
| 287 | + conn.SetState(StateConnected) |
| 288 | + conn.Touch() |
| 289 | + manager.connections.Store(key, conn) |
| 290 | + manager.sessions.Store(key, session) |
| 291 | + manager.probeFunc = func(wc *WsConnection) bool { return true } |
| 292 | + |
| 293 | + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) |
| 294 | + defer cancel() |
| 295 | + exec := NewExecutorWithManager(manager) |
| 296 | + _, err = exec.ExecuteRequestViaWebsocket(ctx, account, []byte(`{"model":"gpt-5.4","input":"hi"}`), sessionID, "", "", nil, http.Header{}) |
| 297 | + if err == nil { |
| 298 | + t.Fatal("expected final send failure") |
| 299 | + } |
| 300 | + if _, ok := manager.connections.Load(key); ok { |
| 301 | + t.Fatal("expected failed connection keyed by effective account proxy to be removed") |
| 302 | + } |
| 303 | + if _, ok := manager.sessions.Load(key); ok { |
| 304 | + t.Fatal("expected failed session keyed by effective account proxy to be removed") |
| 305 | + } |
| 306 | + if conn.IsConnected() { |
| 307 | + t.Fatal("expected failed connection to be closed") |
| 308 | + } |
| 309 | +} |
| 310 | + |
229 | 311 | func TestSendRequestWritesResponseCreatePayloadDirectly(t *testing.T) { |
230 | 312 | received := make(chan []byte, 1) |
231 | 313 | upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} |
|
0 commit comments