Skip to content

Commit fc6d59b

Browse files
feat(wsrelay): reuse idle codex websocket sessions
1 parent fb21a17 commit fc6d59b

5 files changed

Lines changed: 73 additions & 9 deletions

File tree

proxy/wsrelay/executor.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ func (e *Executor) ExecuteRequestViaWebsocket(
8686
}
8787

8888
// 准备请求头
89-
headers := e.prepareWebsocketHeaders(accessToken, accountIDStr, apiKey, deviceCfg, ginHeaders)
89+
headers := e.prepareWebsocketHeaders(accessToken, accountIDStr, sessionID, apiKey, deviceCfg, ginHeaders)
9090

9191
// 获取或创建连接
9292
wc, err := e.manager.AcquireConnection(ctx, account, wsURL, headers, proxyOverride)
@@ -162,7 +162,7 @@ func (e *Executor) prepareWebsocketBody(body []byte, sessionID string) []byte {
162162
}
163163

164164
// prepareWebsocketHeaders 准备 WebSocket 请求头
165-
func (e *Executor) prepareWebsocketHeaders(accessToken, accountID, apiKey string, deviceCfg *proxy.DeviceProfileConfig, ginHeaders http.Header) http.Header {
165+
func (e *Executor) prepareWebsocketHeaders(accessToken, accountID, sessionID, apiKey string, deviceCfg *proxy.DeviceProfileConfig, ginHeaders http.Header) http.Header {
166166
headers := http.Header{}
167167

168168
// 认证头
@@ -207,6 +207,9 @@ func (e *Executor) prepareWebsocketHeaders(accessToken, accountID, apiKey string
207207
if accountID != "" {
208208
headers.Set("Chatgpt-Account-Id", accountID)
209209
}
210+
if sessionID = strings.TrimSpace(sessionID); sessionID != "" {
211+
headers.Set("Conversation_id", sessionID)
212+
}
210213

211214
return headers
212215
}

proxy/wsrelay/executor_test.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ func TestPrepareWebsocketHeadersUsesConfiguredDefaultsAndBetaFeatures(t *testing
2222
"Originator": []string{"custom-originator"},
2323
}
2424

25-
headers := exec.prepareWebsocketHeaders("token-123", "42", "api-key-1", cfg, ginHeaders)
25+
headers := exec.prepareWebsocketHeaders("token-123", "42", "session-123", "api-key-1", cfg, ginHeaders)
2626

2727
if got := headers.Get("Authorization"); got != "Bearer token-123" {
2828
t.Fatalf("Authorization = %q", got)
@@ -45,4 +45,7 @@ func TestPrepareWebsocketHeadersUsesConfiguredDefaultsAndBetaFeatures(t *testing
4545
if got := headers.Get("Chatgpt-Account-Id"); got != "42" {
4646
t.Fatalf("Chatgpt-Account-Id = %q", got)
4747
}
48+
if got := headers.Get("Conversation_id"); got != "session-123" {
49+
t.Fatalf("Conversation_id = %q", got)
50+
}
4851
}

proxy/wsrelay/manager.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,10 @@ func (wc *WsConnection) Close() error {
9090
if wc.onDisconnected != nil && wc.session != nil {
9191
wc.onDisconnected(wc.session.AccountID)
9292
}
93-
return wc.conn.Close()
93+
if wc.conn != nil {
94+
return wc.conn.Close()
95+
}
96+
return nil
9497
}
9598
return nil
9699
}
@@ -274,9 +277,13 @@ func (m *Manager) AcquireConnection(
274277
) (*WsConnection, error) {
275278
key := m.poolKey(account.ID(), wsURL)
276279

277-
// 清理可能存在的旧连接(避免泄漏)
278-
if v, ok := m.connections.LoadAndDelete(key); ok {
280+
if v, ok := m.connections.Load(key); ok {
279281
wc := v.(*WsConnection)
282+
if wc.IsConnected() && !wc.IsExpired() && wc.session != nil && wc.session.PendingCount() == 0 {
283+
wc.Touch()
284+
return wc, nil
285+
}
286+
m.connections.Delete(key)
280287
wc.Close()
281288
}
282289

@@ -473,4 +480,4 @@ func ShutdownManager() {
473480
if globalManager != nil {
474481
globalManager.Stop()
475482
}
476-
}
483+
}

proxy/wsrelay/manager_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
package wsrelay
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"testing"
7+
8+
"github.com/codex2api/auth"
9+
)
10+
11+
func TestAcquireConnectionReusesIdleConnectedConnection(t *testing.T) {
12+
manager := NewManager()
13+
t.Cleanup(manager.Stop)
14+
15+
account := &auth.Account{DBID: 42}
16+
wsURL := "wss://example.test/responses"
17+
key := manager.poolKey(account.ID(), wsURL)
18+
19+
session := NewSession(account.ID(), manager)
20+
session.SetConnected(true)
21+
conn := &WsConnection{
22+
session: session,
23+
URL: wsURL,
24+
httpResp: &http.Response{StatusCode: http.StatusSwitchingProtocols},
25+
}
26+
conn.SetState(StateConnected)
27+
conn.Touch()
28+
manager.connections.Store(key, conn)
29+
manager.sessions.Store(key, session)
30+
31+
got, err := manager.AcquireConnection(context.Background(), account, wsURL, http.Header{}, "")
32+
if err != nil {
33+
t.Fatalf("AcquireConnection() error = %v", err)
34+
}
35+
if got != conn {
36+
t.Fatal("expected existing connection to be reused")
37+
}
38+
}

proxy/wsrelay/session.go

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ type PendingRequest struct {
5656
Cancel context.CancelFunc
5757

5858
// 关闭标志,防止重复关闭
59-
closed bool
59+
closed bool
6060
closeMu sync.Mutex
6161
}
6262

@@ -190,6 +190,19 @@ func (s *Session) RemovePendingRequest(requestID string) {
190190
}
191191
}
192192

193+
// PendingCount returns the number of in-flight requests bound to this session.
194+
func (s *Session) PendingCount() int {
195+
if s == nil {
196+
return 0
197+
}
198+
count := 0
199+
s.pending.Range(func(_, _ any) bool {
200+
count++
201+
return true
202+
})
203+
return count
204+
}
205+
193206
// DeliverResponse 投递响应到等待请求
194207
func (s *Session) DeliverResponse(msg *Message) bool {
195208
if pr, ok := s.GetPendingRequest(msg.RequestID); ok {
@@ -323,4 +336,4 @@ func (s *Session) ClearPendingRequests() {
323336
s.pending.Delete(key)
324337
return true
325338
})
326-
}
339+
}

0 commit comments

Comments
 (0)