Skip to content

Commit 6a1d85e

Browse files
feat(wsrelay): scope codex websocket reuse by session
1 parent c1d46d2 commit 6a1d85e

3 files changed

Lines changed: 58 additions & 21 deletions

File tree

proxy/wsrelay/executor.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ func (e *Executor) ExecuteRequestViaWebsocket(
8989
headers := e.prepareWebsocketHeaders(accessToken, accountIDStr, sessionID, apiKey, deviceCfg, ginHeaders)
9090

9191
// 获取或创建连接
92-
wc, err := e.manager.AcquireConnection(ctx, account, wsURL, headers, proxyOverride)
92+
wc, err := e.manager.AcquireConnection(ctx, account, wsURL, sessionID, headers, proxyOverride)
9393
if err != nil {
9494
return nil, err
9595
}
@@ -101,9 +101,9 @@ func (e *Executor) ExecuteRequestViaWebsocket(
101101
if err := e.sendRequest(wc, wsBody, pr.RequestID); err != nil {
102102
// 发送失败,尝试重连一次
103103
wc.session.RemovePendingRequest(pr.RequestID)
104-
e.manager.RemoveConnection(account.ID(), wsURL)
104+
e.manager.RemoveConnection(account.ID(), wsURL, sessionID)
105105

106-
wc, err = e.manager.AcquireConnection(ctx, account, wsURL, headers, proxyOverride)
106+
wc, err = e.manager.AcquireConnection(ctx, account, wsURL, sessionID, headers, proxyOverride)
107107
if err != nil {
108108
return nil, err
109109
}

proxy/wsrelay/manager.go

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"net"
88
"net/http"
99
"net/url"
10+
"strings"
1011
"sync"
1112
"sync/atomic"
1213
"time"
@@ -38,6 +39,9 @@ type WsConnection struct {
3839
// 连接 URL
3940
URL string
4041

42+
// 连接池键
43+
PoolKey string
44+
4145
// 连接状态
4246
state atomic.Int32
4347

@@ -267,15 +271,16 @@ func (m *Manager) getOnConnected() func(accountID int64, session *Session) {
267271
}
268272

269273
// AcquireConnection 获取或创建连接
270-
// 注意:WebSocket 连接不支持并发读取,因此始终创建新连接而非复用
274+
// 仅在同一逻辑 session 且连接空闲时复用,避免不同会话共用一条已握手连接。
271275
func (m *Manager) AcquireConnection(
272276
ctx context.Context,
273277
account *auth.Account,
274278
wsURL string,
279+
sessionKey string,
275280
headers http.Header,
276281
proxyOverride string,
277282
) (*WsConnection, error) {
278-
key := m.poolKey(account.ID(), wsURL)
283+
key := m.poolKey(account.ID(), wsURL, sessionKey)
279284

280285
if v, ok := m.connections.Load(key); ok {
281286
wc := v.(*WsConnection)
@@ -288,7 +293,7 @@ func (m *Manager) AcquireConnection(
288293
}
289294

290295
// 始终创建新连接,避免多个请求复用同一个 websocket.Conn 导致并发读取
291-
wc, err := m.createConnection(ctx, account, wsURL, headers, proxyOverride)
296+
wc, err := m.createConnection(ctx, account, wsURL, sessionKey, headers, proxyOverride)
292297
if err != nil {
293298
return nil, err
294299
}
@@ -309,6 +314,7 @@ func (m *Manager) createConnection(
309314
ctx context.Context,
310315
account *auth.Account,
311316
wsURL string,
317+
sessionKey string,
312318
headers http.Header,
313319
proxyOverride string,
314320
) (*WsConnection, error) {
@@ -338,24 +344,28 @@ func (m *Manager) createConnection(
338344
}
339345

340346
// 创建会话(先关闭旧 session 避免泄漏)
341-
sessionKey := m.poolKey(account.ID(), wsURL)
342-
if oldSessionVal, ok := m.sessions.Load(sessionKey); ok {
347+
poolKey := m.poolKey(account.ID(), wsURL, sessionKey)
348+
if oldSessionVal, ok := m.sessions.Load(poolKey); ok {
343349
oldSession := oldSessionVal.(*Session)
344350
oldSession.Close()
345351
}
346352
session := NewSession(account.ID(), m)
347-
m.sessions.Store(sessionKey, session)
353+
if trimmed := strings.TrimSpace(sessionKey); trimmed != "" {
354+
session.ID = trimmed
355+
}
356+
m.sessions.Store(poolKey, session)
348357

349358
// 拨号连接
350359
conn, resp, err := dialer.DialContext(ctx, wsURL, headers)
351360
if err != nil {
352-
m.sessions.Delete(sessionKey)
361+
m.sessions.Delete(poolKey)
353362
session.Close()
354363
return nil, fmt.Errorf("websocket handshake failed: %w", err)
355364
}
356365

357366
// 创建连接包装
358367
wc := NewWsConnection(conn, session, wsURL)
368+
wc.PoolKey = poolKey
359369
wc.httpResp = resp
360370
wc.onDisconnected = m.getOnDisconnected()
361371
session.SetConnected(true)
@@ -379,8 +389,8 @@ func (m *Manager) ReleaseConnection(wc *WsConnection) {
379389
}
380390

381391
// RemoveConnection 移除连接
382-
func (m *Manager) RemoveConnection(accountID int64, wsURL string) {
383-
key := m.poolKey(accountID, wsURL)
392+
func (m *Manager) RemoveConnection(accountID int64, wsURL string, sessionKey string) {
393+
key := m.poolKey(accountID, wsURL, sessionKey)
384394
if v, ok := m.connections.LoadAndDelete(key); ok {
385395
wc := v.(*WsConnection)
386396
wc.Close()
@@ -389,13 +399,13 @@ func (m *Manager) RemoveConnection(accountID int64, wsURL string) {
389399
}
390400

391401
// poolKey 生成连接池键
392-
func (m *Manager) poolKey(accountID int64, wsURL string) string {
393-
return fmt.Sprintf("%d|%s", accountID, wsURL)
402+
func (m *Manager) poolKey(accountID int64, wsURL string, sessionKey string) string {
403+
return fmt.Sprintf("%d|%s|%s", accountID, wsURL, strings.TrimSpace(sessionKey))
394404
}
395405

396406
// GetSession 获取会话
397-
func (m *Manager) GetSession(accountID int64, wsURL string) (*Session, bool) {
398-
if v, ok := m.sessions.Load(m.poolKey(accountID, wsURL)); ok {
407+
func (m *Manager) GetSession(accountID int64, wsURL string, sessionKey string) (*Session, bool) {
408+
if v, ok := m.sessions.Load(m.poolKey(accountID, wsURL, sessionKey)); ok {
399409
return v.(*Session), true
400410
}
401411
return nil, false
@@ -426,14 +436,15 @@ func (m *Manager) ReplaceConnection(
426436
ctx context.Context,
427437
account *auth.Account,
428438
wsURL string,
439+
sessionKey string,
429440
headers http.Header,
430441
proxyOverride string,
431442
) (*WsConnection, error) {
432443
// 先移除旧连接
433-
m.RemoveConnection(account.ID(), wsURL)
444+
m.RemoveConnection(account.ID(), wsURL, sessionKey)
434445

435446
// 创建新连接
436-
return m.AcquireConnection(ctx, account, wsURL, headers, proxyOverride)
447+
return m.AcquireConnection(ctx, account, wsURL, sessionKey, headers, proxyOverride)
437448
}
438449

439450
// SendHeartbeat 发送心跳 Ping
@@ -450,7 +461,10 @@ func (m *Manager) SendHeartbeat(wc *WsConnection) error {
450461
if err != nil {
451462
log.Printf("WebSocket Ping 失败 (account %d): %v", wc.session.AccountID, err)
452463
wc.Close()
453-
m.connections.Delete(m.poolKey(wc.session.AccountID, wc.URL))
464+
if wc.PoolKey != "" {
465+
m.connections.Delete(wc.PoolKey)
466+
m.sessions.Delete(wc.PoolKey)
467+
}
454468
return err
455469
}
456470
return nil

proxy/wsrelay/manager_test.go

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,48 @@ func TestAcquireConnectionReusesIdleConnectedConnection(t *testing.T) {
1414

1515
account := &auth.Account{DBID: 42}
1616
wsURL := "wss://example.test/responses"
17-
key := manager.poolKey(account.ID(), wsURL)
17+
key := manager.poolKey(account.ID(), wsURL, "session-1")
1818

1919
session := NewSession(account.ID(), manager)
2020
session.SetConnected(true)
2121
conn := &WsConnection{
2222
session: session,
2323
URL: wsURL,
24+
PoolKey: key,
2425
httpResp: &http.Response{StatusCode: http.StatusSwitchingProtocols},
2526
}
2627
conn.SetState(StateConnected)
2728
conn.Touch()
2829
manager.connections.Store(key, conn)
2930
manager.sessions.Store(key, session)
3031

31-
got, err := manager.AcquireConnection(context.Background(), account, wsURL, http.Header{}, "")
32+
got, err := manager.AcquireConnection(context.Background(), account, wsURL, "session-1", http.Header{}, "")
3233
if err != nil {
3334
t.Fatalf("AcquireConnection() error = %v", err)
3435
}
3536
if got != conn {
3637
t.Fatal("expected existing connection to be reused")
3738
}
3839
}
40+
41+
func TestPoolKeyIncludesSessionKey(t *testing.T) {
42+
manager := NewManager()
43+
t.Cleanup(manager.Stop)
44+
45+
keyA := manager.poolKey(42, "wss://example.test/responses", "session-a")
46+
keyB := manager.poolKey(42, "wss://example.test/responses", "session-b")
47+
if keyA == keyB {
48+
t.Fatal("expected different session keys to produce different pool keys")
49+
}
50+
}
51+
52+
func TestPoolKeyKeepsSameSessionStable(t *testing.T) {
53+
manager := NewManager()
54+
t.Cleanup(manager.Stop)
55+
56+
keyA := manager.poolKey(42, "wss://example.test/responses", "session-a")
57+
keyB := manager.poolKey(42, "wss://example.test/responses", "session-a")
58+
if keyA != keyB {
59+
t.Fatal("expected identical session keys to produce the same pool key")
60+
}
61+
}

0 commit comments

Comments
 (0)