Skip to content

Commit 2eb06ec

Browse files
authored
Merge pull request #247 from huangye123/worktree-fix-ws-relay-stability-pr2
fix: 修复 WebSocket relay 稳定性问题
2 parents ae63848 + ec2a21b commit 2eb06ec

7 files changed

Lines changed: 279 additions & 7 deletions

File tree

proxy/wsrelay/error_passthrough_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,59 @@ func TestBuildErrorEvent_UpstreamErrorBecomesFailedEvent(t *testing.T) {
2626
}
2727
}
2828

29+
// TestBuildErrorEvent_UpstreamStatusBecomesResponseStatusCode verifies an upstream
30+
// top-level status is copied to response.status_code even when the preserved
31+
// upstream error object does not contain status_code itself.
32+
func TestBuildErrorEvent_UpstreamStatusBecomesResponseStatusCode(t *testing.T) {
33+
r := &WsResponse{}
34+
upstream := []byte(`{"type":"error","status":429,"error":{"type":"rate_limit_error","message":"too many requests","param":"input"}}`)
35+
36+
event, isErr := r.buildErrorEvent(upstream)
37+
if !isErr {
38+
t.Fatal("expected error frame to be detected")
39+
}
40+
if statusCode := gjson.GetBytes(event, "response.status_code").Int(); statusCode != 429 {
41+
t.Fatalf("response.status_code = %d, want 429; event=%s", statusCode, event)
42+
}
43+
if gjson.GetBytes(event, "response.error.status_code").Exists() {
44+
t.Fatalf("response.error should preserve upstream error object without injecting status_code: %s", event)
45+
}
46+
if msg := gjson.GetBytes(event, "response.error.message").String(); msg != "too many requests" {
47+
t.Fatalf("response.error.message = %q, want preserved upstream message", msg)
48+
}
49+
if typ := gjson.GetBytes(event, "response.error.type").String(); typ != "rate_limit_error" {
50+
t.Fatalf("response.error.type = %q, want preserved upstream type", typ)
51+
}
52+
if param := gjson.GetBytes(event, "response.error.param").String(); param != "input" {
53+
t.Fatalf("response.error.param = %q, want preserved upstream param", param)
54+
}
55+
}
56+
57+
// TestBuildErrorEvent_UpstreamStatusCodeBecomesResponseStatusCode verifies an
58+
// upstream top-level status_code is copied to response.status_code while the
59+
// original response.error payload remains unchanged.
60+
func TestBuildErrorEvent_UpstreamStatusCodeBecomesResponseStatusCode(t *testing.T) {
61+
r := &WsResponse{}
62+
upstream := []byte(`{"type":"error","status_code":503,"error":{"message":"service unavailable","code":"server_error"}}`)
63+
64+
event, isErr := r.buildErrorEvent(upstream)
65+
if !isErr {
66+
t.Fatal("expected error frame to be detected")
67+
}
68+
if statusCode := gjson.GetBytes(event, "response.status_code").Int(); statusCode != 503 {
69+
t.Fatalf("response.status_code = %d, want 503; event=%s", statusCode, event)
70+
}
71+
if gjson.GetBytes(event, "response.error.status_code").Exists() {
72+
t.Fatalf("response.error should preserve upstream error object without injecting status_code: %s", event)
73+
}
74+
if msg := gjson.GetBytes(event, "response.error.message").String(); msg != "service unavailable" {
75+
t.Fatalf("response.error.message = %q, want preserved upstream message", msg)
76+
}
77+
if code := gjson.GetBytes(event, "response.error.code").String(); code != "server_error" {
78+
t.Fatalf("response.error.code = %q, want preserved upstream code", code)
79+
}
80+
}
81+
2982
// TestBuildErrorEvent_NonErrorPassthrough 验证非错误帧不被识别为错误。
3083
func TestBuildErrorEvent_NonErrorPassthrough(t *testing.T) {
3184
r := &WsResponse{}

proxy/wsrelay/executor.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ func (e *Executor) ExecuteRequestViaWebsocket(
121121
// 获取或创建连接。无显式会话的请求(stateless 连接 ID)在确定性 cache key
122122
// 的槽位池内复用连接,避免持续高 RPM 下逐请求握手触发上游限流。
123123
poolSessionID := sessionID
124+
effectiveProxy := effectiveProxyURL(account, proxyOverride)
124125
var wc *WsConnection
125126
var pr *PendingRequest
126127
var err2 error
@@ -137,7 +138,7 @@ func (e *Executor) ExecuteRequestViaWebsocket(
137138
sendErr := e.sendRequest(wc, wsBody, pr.RequestID)
138139
for retries := 0; sendErr != nil && retries < 2; retries++ {
139140
wc.session.RemovePendingRequest(pr.RequestID)
140-
e.manager.RemoveConnection(account.ID(), wsURL, poolSessionID, proxyOverride)
141+
e.manager.RemoveConnection(account.ID(), wsURL, poolSessionID, effectiveProxy)
141142

142143
// 短暂退避,避免瞬间重连风暴
143144
select {
@@ -154,7 +155,7 @@ func (e *Executor) ExecuteRequestViaWebsocket(
154155
}
155156
if sendErr != nil {
156157
wc.session.RemovePendingRequest(pr.RequestID)
157-
e.manager.ReleaseConnection(wc)
158+
e.manager.RemoveConnection(account.ID(), wsURL, poolSessionID, effectiveProxy)
158159
return nil, fmt.Errorf("发送 WebSocket 请求失败: %w", sendErr)
159160
}
160161

@@ -393,6 +394,9 @@ func (r *WsResponse) buildErrorEvent(payload []byte) ([]byte, bool) {
393394
errObj = fmt.Sprintf(`{"message":%q,"code":%d}`, errMsg, status)
394395
}
395396
event := fmt.Sprintf(`{"type":"response.failed","response":{"status":"failed","error":%s}}`, errObj)
397+
if status > 0 {
398+
event = fmt.Sprintf(`{"type":"response.failed","response":{"status":"failed","status_code":%d,"error":%s}}`, status, errObj)
399+
}
396400
return []byte(event), true
397401
}
398402

proxy/wsrelay/executor_test.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
package wsrelay
22

33
import (
4+
"bufio"
45
"context"
6+
"crypto/sha1"
7+
"encoding/base64"
8+
"fmt"
59
"io"
10+
"net"
611
"net/http"
712
"net/http/httptest"
13+
"net/url"
814
"strings"
915
"testing"
1016
"time"
1117

18+
"github.com/codex2api/auth"
1219
"github.com/codex2api/proxy"
1320
"github.com/gorilla/websocket"
1421
"github.com/tidwall/gjson"
@@ -226,6 +233,81 @@ func TestWebsocketResponseToHTTPClosesBodyOnContextCancel(t *testing.T) {
226233
}
227234
}
228235

236+
func newClosedTestWebsocketConn(t *testing.T) *websocket.Conn {
237+
t.Helper()
238+
clientConn, serverConn := net.Pipe()
239+
handshakeDone := make(chan struct{})
240+
go func() {
241+
defer close(handshakeDone)
242+
defer serverConn.Close()
243+
req, err := http.ReadRequest(bufio.NewReader(serverConn))
244+
if err != nil {
245+
return
246+
}
247+
acceptHash := sha1.Sum([]byte(req.Header.Get("Sec-Websocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"))
248+
_, _ = fmt.Fprintf(serverConn, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: %s\r\n\r\n", base64.StdEncoding.EncodeToString(acceptHash[:]))
249+
}()
250+
251+
wsURL, err := url.Parse("ws://example.test/responses")
252+
if err != nil {
253+
t.Fatalf("parse websocket URL: %v", err)
254+
}
255+
conn, _, err := websocket.NewClient(clientConn, wsURL, nil, 1024, 1024)
256+
if err != nil {
257+
t.Fatalf("create test websocket client: %v", err)
258+
}
259+
<-handshakeDone
260+
return conn
261+
}
262+
263+
func TestExecuteRequestViaWebsocketSendFailureRemovesEffectiveProxyConnection(t *testing.T) {
264+
manager := NewManager()
265+
t.Cleanup(manager.Stop)
266+
267+
account := &auth.Account{
268+
DBID: 42,
269+
AccessToken: "token-123",
270+
ProxyURL: "http://account-proxy.test:8080",
271+
}
272+
sessionID := "session-1"
273+
wsURL, err := buildWebsocketURL(proxy.CodexBaseURL + CodexWsEndpoint)
274+
if err != nil {
275+
t.Fatalf("buildWebsocketURL: %v", err)
276+
}
277+
effectiveProxy := effectiveProxyURL(account, "")
278+
key := manager.poolKey(account.ID(), wsURL, sessionID, effectiveProxy)
279+
session := NewSession(account.ID(), manager)
280+
session.SetConnected(true)
281+
conn := &WsConnection{
282+
conn: newClosedTestWebsocketConn(t),
283+
session: session,
284+
URL: wsURL,
285+
PoolKey: key,
286+
}
287+
conn.SetState(StateConnected)
288+
conn.Touch()
289+
manager.connections.Store(key, conn)
290+
manager.sessions.Store(key, session)
291+
manager.probeFunc = func(wc *WsConnection) bool { return true }
292+
293+
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
294+
defer cancel()
295+
exec := NewExecutorWithManager(manager)
296+
_, err = exec.ExecuteRequestViaWebsocket(ctx, account, []byte(`{"model":"gpt-5.4","input":"hi"}`), sessionID, "", "", nil, http.Header{})
297+
if err == nil {
298+
t.Fatal("expected final send failure")
299+
}
300+
if _, ok := manager.connections.Load(key); ok {
301+
t.Fatal("expected failed connection keyed by effective account proxy to be removed")
302+
}
303+
if _, ok := manager.sessions.Load(key); ok {
304+
t.Fatal("expected failed session keyed by effective account proxy to be removed")
305+
}
306+
if conn.IsConnected() {
307+
t.Fatal("expected failed connection to be closed")
308+
}
309+
}
310+
229311
func TestSendRequestWritesResponseCreatePayloadDirectly(t *testing.T) {
230312
received := make(chan []byte, 1)
231313
upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}

proxy/wsrelay/manager.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ type Manager struct {
169169
// 清理定时器
170170
cleanupTicker *time.Ticker
171171
stopCleanup chan struct{}
172+
stopOnce sync.Once
172173

173174
// 连接回调
174175
onConnected func(accountID int64, session *Session)
@@ -252,8 +253,10 @@ func (m *Manager) evictExpired() {
252253

253254
// Stop 停止管理器
254255
func (m *Manager) Stop() {
255-
close(m.stopCleanup)
256-
m.closeAll()
256+
m.stopOnce.Do(func() {
257+
close(m.stopCleanup)
258+
m.closeAll()
259+
})
257260
}
258261

259262
// closeAll 关闭所有连接

proxy/wsrelay/manager_test.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,100 @@ package wsrelay
33
import (
44
"context"
55
"net/http"
6+
"sync"
67
"testing"
78
"time"
89

910
"github.com/codex2api/auth"
1011
)
1112

13+
func TestManagerStopIdempotent(t *testing.T) {
14+
manager := NewManager()
15+
16+
for i := 0; i < 3; i++ {
17+
func() {
18+
defer func() {
19+
if r := recover(); r != nil {
20+
t.Fatalf("Stop call %d panicked: %v", i+1, r)
21+
}
22+
}()
23+
manager.Stop()
24+
}()
25+
}
26+
}
27+
28+
func TestManagerStopConcurrent(t *testing.T) {
29+
manager := NewManager()
30+
31+
const callers = 32
32+
start := make(chan struct{})
33+
panicCh := make(chan any, callers)
34+
done := make(chan struct{})
35+
var wg sync.WaitGroup
36+
wg.Add(callers)
37+
38+
for i := 0; i < callers; i++ {
39+
go func() {
40+
defer wg.Done()
41+
defer func() {
42+
if r := recover(); r != nil {
43+
panicCh <- r
44+
}
45+
}()
46+
<-start
47+
manager.Stop()
48+
}()
49+
}
50+
51+
close(start)
52+
go func() {
53+
wg.Wait()
54+
close(done)
55+
}()
56+
57+
select {
58+
case <-done:
59+
case <-time.After(2 * time.Second):
60+
t.Fatal("concurrent Stop calls timed out")
61+
}
62+
close(panicCh)
63+
64+
for r := range panicCh {
65+
t.Fatalf("concurrent Stop panicked: %v", r)
66+
}
67+
}
68+
69+
func TestRemoveConnectionUsesEffectiveProxyKey(t *testing.T) {
70+
manager := NewManager()
71+
t.Cleanup(manager.Stop)
72+
73+
account := &auth.Account{DBID: 42, ProxyURL: " http://proxy-a.example:8080 "}
74+
wsURL := "wss://example.test/responses"
75+
sessionKey := "session-1"
76+
proxyURL := effectiveProxyURL(account, "")
77+
key := manager.poolKey(account.ID(), wsURL, sessionKey, proxyURL)
78+
79+
session := NewSession(account.ID(), manager)
80+
session.SetConnected(true)
81+
conn := &WsConnection{session: session, URL: wsURL, PoolKey: key}
82+
conn.SetState(StateConnected)
83+
conn.Touch()
84+
manager.connections.Store(key, conn)
85+
manager.sessions.Store(key, session)
86+
87+
manager.RemoveConnection(account.ID(), wsURL, sessionKey, effectiveProxyURL(account, ""))
88+
89+
if _, ok := manager.connections.Load(key); ok {
90+
t.Fatal("expected connection stored under effective proxy key to be removed")
91+
}
92+
if _, ok := manager.sessions.Load(key); ok {
93+
t.Fatal("expected session stored under effective proxy key to be removed")
94+
}
95+
if conn.IsConnected() {
96+
t.Fatal("expected removed connection to be closed")
97+
}
98+
}
99+
12100
func TestAcquireConnectionReusesIdleConnectedConnection(t *testing.T) {
13101
manager := NewManager()
14102
t.Cleanup(manager.Stop)

proxy/wsrelay/session.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package wsrelay
22

33
import (
44
"context"
5+
"log"
56
"sync"
67
"time"
78

@@ -34,7 +35,7 @@ const (
3435
// 复用同一 session 的连接时,等待其空闲的轮询退避参数。
3536
AcquireInitialBackoff = 10 * time.Millisecond // 初始退避
3637
AcquireMaxBackoff = 200 * time.Millisecond // 退避封顶
37-
AcquireMaxWait = 30 * time.Second // 最大累计等待,超时返回错误
38+
AcquireMaxWait = 30 * time.Second // 最大累计等待,超时返回错误
3839
)
3940

4041
// ==================== Pending 请求管理 ====================
@@ -244,7 +245,8 @@ func (s *Session) DeliverStreamChunk(msg *Message) bool {
244245
pr.closeMu.Unlock()
245246
return true
246247
default:
247-
// 通道已满,丢弃旧数据
248+
// 通道已满,丢弃当前流式数据块并返回 false
249+
log.Printf("DeliverStreamChunk stream channel full: account=%d session=%s requestID=%s capacity=%d length=%d; dropping current chunk", s.AccountID, s.ID, msg.RequestID, cap(pr.StreamChan), len(pr.StreamChan))
248250
pr.closeMu.Unlock()
249251
return false
250252
}

0 commit comments

Comments
 (0)