Skip to content

Commit 8ea79aa

Browse files
committed
fix: restore prompt cache and reuse connections in websocket mode
Since v2.2.7, requests without an explicit session identifier got a random `stateless-` connection ID per request on the WS upstream path, which leaked into both the `prompt_cache_key` body field and the `Session_id`/`Conversation_id` handshake headers. The upstream prompt cache therefore never hit in WS mode (HTTP mode was unaffected), and sustained load opened a new WS connection per request until the upstream throttled handshakes (bad handshake -> 503 at ~200 RPM).

- Inject a deterministic prompt cache key (derived from the downstream API key, falling back to the account) when the body has none, matching HTTP-path behavior; stateless IDs no longer overwrite it.
- Send the deterministic key in the WS handshake session headers; the stateless ID is only used for local connection pool isolation.
- Reuse WS connections for sessionless requests via 8 slots per (account, cache key), falling back to one-shot connections when all slots are busy, which eliminates per-request handshakes under high RPM.

Verified against the live upstream: cache hits cover ~86% of input tokens from the second request on (previously 0%), and 200 RPM yields 200/200 successes with zero handshake failures (previously 72 requests failed with 503).
1 parent 52d30a9 commit 8ea79aa

6 files changed

Lines changed: 247 additions & 20 deletions

File tree

proxy/executor.go

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,16 @@ func ExecuteRequest(ctx context.Context, account *auth.Account, requestBody []by
260260
if wantWebsocket {
261261
sessionID = strings.TrimSpace(sessionID)
262262
if sessionID == "" {
263+
// stateless 连接 ID 仅用于 WS 连接池隔离,保证同一 API Key 的并发请求
264+
// 不挤在一条连接上排队。prompt cache key 必须保持确定性:若请求体没有
265+
// 显式 prompt_cache_key,这里注入与 HTTP 路径同源的确定性 key,否则
266+
// 上游 prompt cache 每次请求都会 miss(v2.2.7 引入的回归)。
263267
sessionID = statelessWebsocketSessionID()
268+
if strings.TrimSpace(gjson.GetBytes(requestBody, "prompt_cache_key").String()) == "" {
269+
if cacheKey := deterministicPromptCacheKey(apiKey, account); cacheKey != "" {
270+
requestBody, _ = sjson.SetBytes(requestBody, "prompt_cache_key", cacheKey)
271+
}
272+
}
264273
}
265274
}
266275
if wantWebsocket && WebsocketExecuteFunc != nil {
@@ -688,8 +697,31 @@ func ResolveExplicitSessionID(headers http.Header, body []byte) string {
688697
return ""
689698
}
690699

700+
// statelessWebsocketSessionPrefix is the literal prefix carried by every
// session ID produced by statelessWebsocketSessionID.
const statelessWebsocketSessionPrefix = "stateless-"
701+
691702
// statelessWebsocketSessionID returns a fresh session ID made of the
// stateless prefix followed by a newly generated UUID string, so every call
// yields a different value.
func statelessWebsocketSessionID() string {
692-
return "stateless-" + uuid.NewString()
703+
return statelessWebsocketSessionPrefix + uuid.NewString()
704+
}
705+
706+
// IsStatelessWebsocketSessionID reports whether sessionID is a one-shot
// connection ID generated on the WS path.
707+
// Such IDs are only meant for connection-pool isolation and must not be sent
// upstream as the prompt cache key. The check is a plain prefix match against
// statelessWebsocketSessionPrefix.
708+
func IsStatelessWebsocketSessionID(sessionID string) bool {
709+
return strings.HasPrefix(sessionID, statelessWebsocketSessionPrefix)
710+
}
711+
712+
// deterministicPromptCacheKey builds a deterministic prompt cache key that,
// per the original author's note, shares its derivation with the fallback
// logic of ResolveSessionID.
713+
// The key is derived from the downstream API key when one is present, and
// from the account ID otherwise. It returns "" when the trimmed API key is
// empty and there is no account with a positive ID.
714+
func deterministicPromptCacheKey(apiKey string, account *auth.Account) string {
715+
apiKey = strings.TrimSpace(apiKey)
716+
if apiKey != "" {
717+
// Name-based (SHA-1) UUID: the same API key always maps to the same key.
return uuid.NewSHA1(uuid.NameSpaceOID, []byte("codex2api:prompt-cache:"+apiKey)).String()
718+
}
719+
if account != nil {
720+
if id := account.ID(); id > 0 {
721+
// The "auth:" segment keeps account-derived names distinct from API-key-derived ones.
return uuid.NewSHA1(uuid.NameSpaceOID, []byte(fmt.Sprintf("codex2api:prompt-cache:auth:%d", id))).String()
722+
}
723+
}
724+
return ""
693725
}
694726

695727
// ReadSSEStream 从上游 SSE 响应读取事件流

proxy/executor_test.go

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"testing"
1010

1111
"github.com/codex2api/auth"
12+
"github.com/tidwall/gjson"
1213
)
1314

1415
func TestReadSSEStream_MergesMultilineData(t *testing.T) {
@@ -517,22 +518,38 @@ func TestExecuteRequestForcedWebsocketUsesStatelessSessionWhenMissing(t *testing
517518

518519
previousWS := WebsocketExecuteFunc
519520
t.Cleanup(func() { WebsocketExecuteFunc = previousWS })
520-
var gotSessionID string
521+
var gotSessionIDs []string
522+
var gotCacheKeys []string
521523
WebsocketExecuteFunc = func(ctx context.Context, account *auth.Account, requestBody []byte, sessionID string, proxyOverride string, apiKey string, deviceCfg *DeviceProfileConfig, headers http.Header) (*http.Response, error) {
522-
gotSessionID = sessionID
524+
gotSessionIDs = append(gotSessionIDs, sessionID)
525+
gotCacheKeys = append(gotCacheKeys, gjson.GetBytes(requestBody, "prompt_cache_key").String())
523526
return &http.Response{
524527
StatusCode: http.StatusOK,
525528
Header: make(http.Header),
526529
Body: io.NopCloser(strings.NewReader(`{"id":"resp_test"}`)),
527530
}, nil
528531
}
529532

530-
resp, err := ExecuteRequest(context.Background(), &auth.Account{DBID: 1, AccessToken: "token"}, []byte(`{"model":"gpt-5.4"}`), "", "", "sk-local", nil, http.Header{})
531-
if err != nil {
532-
t.Fatalf("ExecuteRequest() error = %v", err)
533+
for i := 0; i < 2; i++ {
534+
resp, err := ExecuteRequest(context.Background(), &auth.Account{DBID: 1, AccessToken: "token"}, []byte(`{"model":"gpt-5.4"}`), "", "", "sk-local", nil, http.Header{})
535+
if err != nil {
536+
t.Fatalf("ExecuteRequest() error = %v", err)
537+
}
538+
resp.Body.Close()
539+
}
540+
for _, sessionID := range gotSessionIDs {
541+
if !strings.HasPrefix(sessionID, "stateless-") {
542+
t.Fatalf("sessionID = %q, want stateless-*", sessionID)
543+
}
544+
}
545+
if gotSessionIDs[0] == gotSessionIDs[1] {
546+
t.Fatalf("stateless sessionIDs should differ per request, both = %q", gotSessionIDs[0])
547+
}
548+
// prompt cache key 必须是确定性的:两次请求一致,且不等于一次性连接 ID
549+
if gotCacheKeys[0] == "" || gotCacheKeys[0] != gotCacheKeys[1] {
550+
t.Fatalf("prompt_cache_key = %q / %q, want identical deterministic key", gotCacheKeys[0], gotCacheKeys[1])
533551
}
534-
defer resp.Body.Close()
535-
if !strings.HasPrefix(gotSessionID, "stateless-") {
536-
t.Fatalf("sessionID = %q, want stateless-*", gotSessionID)
552+
if strings.HasPrefix(gotCacheKeys[0], "stateless-") {
553+
t.Fatalf("prompt_cache_key = %q, must not be a stateless connection ID", gotCacheKeys[0])
537554
}
538555
}

proxy/wsrelay/executor.go

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,16 @@ func (e *Executor) ExecuteRequestViaWebsocket(
8888
// 准备请求体
8989
wsBody := e.prepareWebsocketBody(requestBody, sessionID)
9090

91+
// 握手头中的 Session_id/Conversation_id 会影响上游 prompt cache 路由,必须与
92+
// 请求体的确定性 prompt_cache_key 一致;stateless 连接 ID 是每请求随机的,
93+
// 发给上游会导致 prompt cache 永远 miss,它只用于本地连接池隔离。
94+
headerSessionID := sessionID
95+
if proxy.IsStatelessWebsocketSessionID(sessionID) {
96+
if cacheKey := strings.TrimSpace(gjson.GetBytes(wsBody, "prompt_cache_key").String()); cacheKey != "" {
97+
headerSessionID = cacheKey
98+
}
99+
}
100+
91101
// 构建 WebSocket URL
92102
httpURL := proxy.CodexBaseURL + CodexWsEndpoint
93103
wsURL, err := buildWebsocketURL(httpURL)
@@ -101,24 +111,33 @@ func (e *Executor) ExecuteRequestViaWebsocket(
101111
}
102112

103113
// 准备请求头
104-
headers := e.prepareWebsocketHeaders(accessToken, accountIDStr, sessionID, apiKey, deviceCfg, ginHeaders)
114+
headers := e.prepareWebsocketHeaders(accessToken, accountIDStr, headerSessionID, apiKey, deviceCfg, ginHeaders)
105115

106116
// Resin 反代:注入账号身份头
107117
if proxy.IsResinEnabled() {
108118
headers.Set("X-Resin-Account", proxy.ResinAccountID(account))
109119
}
110120

111-
// 获取或创建连接
112-
wc, pr, err := e.manager.AcquireConnection(ctx, account, wsURL, sessionID, headers, proxyOverride)
113-
if err != nil {
114-
return nil, err
121+
// 获取或创建连接。无显式会话的请求(stateless 连接 ID)在确定性 cache key
122+
// 的槽位池内复用连接,避免持续高 RPM 下逐请求握手触发上游限流。
123+
poolSessionID := sessionID
124+
var wc *WsConnection
125+
var pr *PendingRequest
126+
var err2 error
127+
if proxy.IsStatelessWebsocketSessionID(sessionID) && headerSessionID != sessionID {
128+
wc, pr, poolSessionID, err2 = e.manager.AcquireReusableConnection(ctx, account, wsURL, headerSessionID, sessionID, StatelessConnectionSlots, headers, proxyOverride)
129+
} else {
130+
wc, pr, err2 = e.manager.AcquireConnection(ctx, account, wsURL, sessionID, headers, proxyOverride)
131+
}
132+
if err2 != nil {
133+
return nil, err2
115134
}
116135

117136
// 发送请求,失败时最多重试 2 次(重建连接)
118137
sendErr := e.sendRequest(wc, wsBody, pr.RequestID)
119138
for retries := 0; sendErr != nil && retries < 2; retries++ {
120139
wc.session.RemovePendingRequest(pr.RequestID)
121-
e.manager.RemoveConnection(account.ID(), wsURL, sessionID, proxyOverride)
140+
e.manager.RemoveConnection(account.ID(), wsURL, poolSessionID, proxyOverride)
122141

123142
// 短暂退避,避免瞬间重连风暴
124143
select {
@@ -127,9 +146,9 @@ func (e *Executor) ExecuteRequestViaWebsocket(
127146
case <-time.After(time.Duration(retries+1) * 200 * time.Millisecond):
128147
}
129148

130-
wc, pr, err = e.manager.AcquireConnection(ctx, account, wsURL, sessionID, headers, proxyOverride)
131-
if err != nil {
132-
return nil, err
149+
wc, pr, err2 = e.manager.AcquireConnection(ctx, account, wsURL, poolSessionID, headers, proxyOverride)
150+
if err2 != nil {
151+
return nil, err2
133152
}
134153
sendErr = e.sendRequest(wc, wsBody, pr.RequestID)
135154
}
@@ -145,7 +164,7 @@ func (e *Executor) ExecuteRequestViaWebsocket(
145164
return &WsResponse{
146165
conn: wc,
147166
pendingReq: pr,
148-
sessionID: sessionID,
167+
sessionID: poolSessionID,
149168
manager: e.manager,
150169
readErrChan: make(chan error, 1),
151170
}, nil
@@ -171,8 +190,11 @@ func (e *Executor) prepareWebsocketBody(body []byte, sessionID string) []byte {
171190
wsBody, _ = sjson.DeleteBytes(wsBody, "disable_response_storage")
172191

173192
// 3. 注入 prompt_cache_key
193+
// stateless sessionID 只是连接池隔离用的一次性随机 ID,注入它会让上游
194+
// prompt cache 每次请求都 miss;此时保留请求体中已有的确定性 cache key
195+
//(由 proxy.ExecuteRequest 注入或客户端自带)。
174196
existingCacheKey := strings.TrimSpace(gjson.GetBytes(wsBody, "prompt_cache_key").String())
175-
if sessionID != "" {
197+
if sessionID != "" && !proxy.IsStatelessWebsocketSessionID(sessionID) {
176198
wsBody, _ = sjson.SetBytes(wsBody, "prompt_cache_key", sessionID)
177199
} else if existingCacheKey != "" {
178200
wsBody, _ = sjson.SetBytes(wsBody, "prompt_cache_key", existingCacheKey)

proxy/wsrelay/executor_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,26 @@ func TestPrepareWebsocketBodyPreservesPreviousResponseID(t *testing.T) {
113113
}
114114
}
115115

116+
// TestPrepareWebsocketBodyKeepsCacheKeyForStatelessSession checks that
// prepareWebsocketBody leaves an existing prompt_cache_key in the body
// untouched when the session ID passed in is a "stateless-" connection ID.
func TestPrepareWebsocketBodyKeepsCacheKeyForStatelessSession(t *testing.T) {
117+
exec := NewExecutor()
118+
119+
got := exec.prepareWebsocketBody([]byte(`{"model":"gpt-5.4","prompt_cache_key":"deterministic-key","input":[]}`), "stateless-abc123")
120+
121+
if cacheKey := gjson.GetBytes(got, "prompt_cache_key").String(); cacheKey != "deterministic-key" {
122+
t.Fatalf("prompt_cache_key = %q, want deterministic-key (stateless sessionID must not overwrite); body=%s", cacheKey, got)
123+
}
124+
}
125+
126+
// TestPrepareWebsocketBodyStatelessSessionWithoutCacheKey checks that
// prepareWebsocketBody does not inject a prompt_cache_key when the body has
// none and the session ID passed in is a "stateless-" connection ID.
func TestPrepareWebsocketBodyStatelessSessionWithoutCacheKey(t *testing.T) {
127+
exec := NewExecutor()
128+
129+
got := exec.prepareWebsocketBody([]byte(`{"model":"gpt-5.4","input":[]}`), "stateless-abc123")
130+
131+
if cacheKey := gjson.GetBytes(got, "prompt_cache_key").String(); cacheKey != "" {
132+
t.Fatalf("prompt_cache_key = %q, want empty (stateless sessionID must not be injected); body=%s", cacheKey, got)
133+
}
134+
}
135+
116136
func TestNormalizeWebsocketHandshakeResponse(t *testing.T) {
117137
t.Run("switching protocols is successful websocket handshake", func(t *testing.T) {
118138
statusCode, _, failed := normalizeWebsocketHandshakeResponse(&http.Response{

proxy/wsrelay/manager.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,79 @@ func (m *Manager) AcquireConnection(
391391
}
392392
}
393393

394+
// StatelessConnectionSlots is the number of persistent connection slots that
// requests without an explicit session reuse per (account, cacheKey) pair.
395+
// An idle connection in a slot is reused directly so that not every request
// has to perform a new handshake:
396+
// under sustained high RPM, per-request handshakes trigger the upstream's WS
// handshake throttling (bad handshake → 503).
397+
const StatelessConnectionSlots = 8
398+
399+
// AcquireReusableConnection reuses or creates a connection within a fixed set
// of slots and returns the session key that was actually used.
400+
// The first pass only reuses connections that already exist and are idle; the
// second pass creates a persistent connection in an empty slot; when every
// slot is busy it falls back to
401+
// a one-shot connection under fallbackKey, which keeps the concurrency limit
// the same as the original stateless behavior (i.e. unbounded).
//
// Slot session keys have the form "<baseKey>#<i>" for i in [0, slots). The
// returned string is either such a slot key or fallbackKey; it is "" when
// creating a slot connection fails.
402+
func (m *Manager) AcquireReusableConnection(
403+
ctx context.Context,
404+
account *auth.Account,
405+
wsURL string,
406+
baseKey string,
407+
fallbackKey string,
408+
slots int,
409+
headers http.Header,
410+
proxyOverride string,
411+
) (*WsConnection, *PendingRequest, string, error) {
412+
proxyURL := effectiveProxyURL(account, proxyOverride)
413+
// First pass: reuse an idle connection. Connections that fail the probe or
// are already disconnected are cleaned up along the way so the second pass
// can fill their slot.
414+
for i := 0; i < slots; i++ {
415+
slotSession := fmt.Sprintf("%s#%d", baseKey, i)
416+
key := m.poolKey(account.ID(), wsURL, slotSession, proxyURL)
417+
lock := m.keyLock(key)
418+
lock.Lock()
419+
if v, ok := m.connections.Load(key); ok {
420+
wc := v.(*WsConnection)
421+
if canReuseConnection(wc) {
422+
if m.probe(wc) {
423+
// Register the pending request before releasing the slot lock.
pr := wc.session.AddPendingRequest(slotSession)
424+
wc.Touch()
425+
lock.Unlock()
426+
return wc, pr, slotSession, nil
427+
}
428+
// Probe failed: evict the connection and its session, then close it.
m.connections.Delete(key)
429+
m.sessions.Delete(key)
430+
wc.Close()
431+
} else if !wc.IsConnected() || wc.IsExpired() {
432+
// Not reusable and disconnected or expired: evict it. A connection that is
// not reusable but still connected and unexpired is left in place.
m.connections.Delete(key)
433+
m.sessions.Delete(key)
434+
wc.Close()
435+
}
436+
}
437+
lock.Unlock()
438+
}
439+
// Second pass: create a persistent connection in an empty slot.
440+
for i := 0; i < slots; i++ {
441+
slotSession := fmt.Sprintf("%s#%d", baseKey, i)
442+
key := m.poolKey(account.ID(), wsURL, slotSession, proxyURL)
443+
lock := m.keyLock(key)
444+
lock.Lock()
445+
if _, ok := m.connections.Load(key); ok {
446+
// Slot is still occupied; try the next one.
lock.Unlock()
447+
continue
448+
}
449+
// NOTE(review): createConnection runs while the slot lock is held.
wc, err := m.createConnection(ctx, account, wsURL, slotSession, headers, proxyOverride)
450+
if err != nil {
451+
// NOTE(review): a failure on the first empty slot aborts the whole call
// instead of trying another slot or the fallback — confirm this is intended.
lock.Unlock()
452+
return nil, nil, "", err
453+
}
454+
// NOTE(review): only m.connections is written here; presumably
// createConnection registers the session in m.sessions — TODO confirm.
m.connections.Store(key, wc)
455+
pr := wc.session.AddPendingRequest(slotSession)
456+
lock.Unlock()
457+
// The on-connected callback, if any, is invoked after the lock is released.
if fn := m.getOnConnected(); fn != nil {
458+
fn(account.ID(), wc.session)
459+
}
460+
return wc, pr, slotSession, nil
461+
}
462+
// All slots are busy: fall back to a one-shot connection.
463+
wc, pr, err := m.AcquireConnection(ctx, account, wsURL, fallbackKey, headers, proxyOverride)
464+
return wc, pr, fallbackKey, err
465+
}
466+
394467
func canReuseConnection(wc *WsConnection) bool {
395468
if wc == nil {
396469
return false

proxy/wsrelay/manager_test.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,66 @@ func TestAcquireConnectionWaitsWhileSessionHasPendingRequest(t *testing.T) {
158158
t.Fatal("expected acquire to stop when session stays busy until context timeout")
159159
}
160160
}
161+
162+
// newTestSlotConnection is a test helper that builds a WsConnection marked as
// connected, without a real socket, and registers it together with its Session
// in the manager under the pool key for slotSession (with an empty proxy URL).
// It returns the connection and the session.
func newTestSlotConnection(manager *Manager, account *auth.Account, wsURL, slotSession string) (*WsConnection, *Session) {
163+
key := manager.poolKey(account.ID(), wsURL, slotSession, "")
164+
session := NewSession(account.ID(), manager)
165+
session.SetConnected(true)
166+
conn := &WsConnection{
167+
session: session,
168+
URL: wsURL,
169+
PoolKey: key,
170+
httpResp: &http.Response{StatusCode: http.StatusSwitchingProtocols},
171+
}
172+
conn.SetState(StateConnected)
173+
conn.Touch()
174+
manager.connections.Store(key, conn)
175+
manager.sessions.Store(key, session)
176+
return conn, session
177+
}
178+
179+
// TestAcquireReusableConnectionReusesIdleSlot checks that an idle connection
// already registered in slot 0 is returned by AcquireReusableConnection and
// that the reported key is that slot's session key.
func TestAcquireReusableConnectionReusesIdleSlot(t *testing.T) {
180+
manager := NewManager()
181+
t.Cleanup(manager.Stop)
182+
// Stub the probe so the fake connection always counts as alive.
manager.probeFunc = func(wc *WsConnection) bool { return true }
183+
184+
account := &auth.Account{DBID: 42}
185+
wsURL := "wss://example.test/responses"
186+
conn, _ := newTestSlotConnection(manager, account, wsURL, "cache-key#0")
187+
188+
got, pr, usedKey, err := manager.AcquireReusableConnection(context.Background(), account, wsURL, "cache-key", "stateless-xyz", 4, http.Header{}, "")
189+
if err != nil {
190+
t.Fatalf("AcquireReusableConnection() error = %v", err)
191+
}
192+
if got != conn {
193+
t.Fatal("expected idle slot connection to be reused")
194+
}
195+
if usedKey != "cache-key#0" {
196+
t.Fatalf("usedKey = %q, want cache-key#0", usedKey)
197+
}
198+
// Remove the pending request that the acquire call registered.
got.session.RemovePendingRequest(pr.RequestID)
199+
}
200+
201+
// TestAcquireReusableConnectionSkipsBusySlot checks that a slot whose session
// already has a pending request is skipped and the next idle slot is reused.
func TestAcquireReusableConnectionSkipsBusySlot(t *testing.T) {
202+
manager := NewManager()
203+
t.Cleanup(manager.Stop)
204+
// Stub the probe so the fake connections always count as alive.
manager.probeFunc = func(wc *WsConnection) bool { return true }
205+
206+
account := &auth.Account{DBID: 42}
207+
wsURL := "wss://example.test/responses"
208+
_, busySession := newTestSlotConnection(manager, account, wsURL, "cache-key#0")
209+
busySession.AddPendingRequest("cache-key#0") // occupy slot 0
210+
idleConn, _ := newTestSlotConnection(manager, account, wsURL, "cache-key#1")
211+
212+
got, pr, usedKey, err := manager.AcquireReusableConnection(context.Background(), account, wsURL, "cache-key", "stateless-xyz", 4, http.Header{}, "")
213+
if err != nil {
214+
t.Fatalf("AcquireReusableConnection() error = %v", err)
215+
}
216+
if got != idleConn {
217+
t.Fatal("expected busy slot 0 to be skipped and idle slot 1 reused")
218+
}
219+
if usedKey != "cache-key#1" {
220+
t.Fatalf("usedKey = %q, want cache-key#1", usedKey)
221+
}
222+
// Remove the pending request that the acquire call registered.
got.session.RemovePendingRequest(pr.RequestID)
223+
}

0 commit comments

Comments
 (0)