Skip to content

Commit cc27cf1

Browse files
fix(wsrelay): reserve pending requests before reuse
1 parent b643971 commit cc27cf1

3 files changed

Lines changed: 97 additions & 29 deletions

File tree

proxy/wsrelay/executor.go

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,26 +89,22 @@ func (e *Executor) ExecuteRequestViaWebsocket(
8989
headers := e.prepareWebsocketHeaders(accessToken, accountIDStr, sessionID, apiKey, deviceCfg, ginHeaders)
9090

9191
// 获取或创建连接
92-
wc, err := e.manager.AcquireConnection(ctx, account, wsURL, sessionID, headers, proxyOverride)
92+
wc, pr, err := e.manager.AcquireConnection(ctx, account, wsURL, sessionID, headers, proxyOverride)
9393
if err != nil {
9494
return nil, err
9595
}
9696

97-
// 创建请求
98-
pr := wc.session.AddPendingRequest(sessionID)
99-
10097
// 发送请求
10198
if err := e.sendRequest(wc, wsBody, pr.RequestID); err != nil {
10299
// 发送失败,尝试重连一次
103100
wc.session.RemovePendingRequest(pr.RequestID)
104101
e.manager.RemoveConnection(account.ID(), wsURL, sessionID, proxyOverride)
105102

106-
wc, err = e.manager.AcquireConnection(ctx, account, wsURL, sessionID, headers, proxyOverride)
103+
wc, pr, err = e.manager.AcquireConnection(ctx, account, wsURL, sessionID, headers, proxyOverride)
107104
if err != nil {
108105
return nil, err
109106
}
110107

111-
pr = wc.session.AddPendingRequest(sessionID)
112108
if err := e.sendRequest(wc, wsBody, pr.RequestID); err != nil {
113109
wc.session.RemovePendingRequest(pr.RequestID)
114110
e.manager.ReleaseConnection(wc)

proxy/wsrelay/manager.go

Lines changed: 54 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,9 @@ type Manager struct {
175175

176176
// 读写锁保护回调设置
177177
mu sync.RWMutex
178+
179+
// pool key 级别串行化,避免同一逻辑 session 在 acquire 阶段竞争同一条连接
180+
keyLocks sync.Map
178181
}
179182

180183
// NewManager 创建连接池管理器
@@ -283,6 +286,17 @@ func (m *Manager) getOnConnected() func(accountID int64, session *Session) {
283286
return m.onConnected
284287
}
285288

289+
func (m *Manager) keyLock(key string) *sync.Mutex {
290+
if v, ok := m.keyLocks.Load(key); ok {
291+
return v.(*sync.Mutex)
292+
}
293+
mu := &sync.Mutex{}
294+
if actual, loaded := m.keyLocks.LoadOrStore(key, mu); loaded {
295+
return actual.(*sync.Mutex)
296+
}
297+
return mu
298+
}
299+
286300
// AcquireConnection 获取或创建连接
287301
// 仅在同一逻辑 session 且连接空闲时复用,避免不同会话共用一条已握手连接。
288302
func (m *Manager) AcquireConnection(
@@ -292,34 +306,52 @@ func (m *Manager) AcquireConnection(
292306
sessionKey string,
293307
headers http.Header,
294308
proxyOverride string,
295-
) (*WsConnection, error) {
309+
) (*WsConnection, *PendingRequest, error) {
296310
key := m.poolKey(account.ID(), wsURL, sessionKey, effectiveProxyURL(account, proxyOverride))
311+
lock := m.keyLock(key)
312+
wait := 10 * time.Millisecond
297313

298-
if v, ok := m.connections.Load(key); ok {
299-
wc := v.(*WsConnection)
300-
if canReuseConnection(wc) {
301-
wc.Touch()
302-
return wc, nil
314+
for {
315+
lock.Lock()
316+
if v, ok := m.connections.Load(key); ok {
317+
wc := v.(*WsConnection)
318+
if canReuseConnection(wc) {
319+
pr := wc.session.AddPendingRequest(sessionKey)
320+
wc.Touch()
321+
lock.Unlock()
322+
return wc, pr, nil
323+
}
324+
if wc.IsConnected() && !wc.IsExpired() && wc.session != nil {
325+
lock.Unlock()
326+
select {
327+
case <-ctx.Done():
328+
return nil, nil, ctx.Err()
329+
case <-time.After(wait):
330+
}
331+
continue
332+
}
333+
m.connections.Delete(key)
334+
m.sessions.Delete(key)
335+
wc.Close()
303336
}
304-
m.connections.Delete(key)
305-
wc.Close()
306-
}
307337

308-
// 始终创建新连接,避免多个请求复用同一个 websocket.Conn 导致并发读取
309-
wc, err := m.createConnection(ctx, account, wsURL, sessionKey, headers, proxyOverride)
310-
if err != nil {
311-
return nil, err
312-
}
338+
wc, err := m.createConnection(ctx, account, wsURL, sessionKey, headers, proxyOverride)
339+
if err != nil {
340+
lock.Unlock()
341+
return nil, nil, err
342+
}
313343

314-
// 存储新连接
315-
m.connections.Store(key, wc)
344+
// 存储新连接并立即占位 pending request,避免返回后才记账产生竞态
345+
m.connections.Store(key, wc)
346+
pr := wc.session.AddPendingRequest(sessionKey)
347+
lock.Unlock()
316348

317-
// 调用连接回调
318-
if fn := m.getOnConnected(); fn != nil {
319-
fn(account.ID(), wc.session)
320-
}
349+
if fn := m.getOnConnected(); fn != nil {
350+
fn(account.ID(), wc.session)
351+
}
321352

322-
return wc, nil
353+
return wc, pr, nil
354+
}
323355
}
324356

325357
func canReuseConnection(wc *WsConnection) bool {
@@ -459,7 +491,7 @@ func (m *Manager) ReplaceConnection(
459491
sessionKey string,
460492
headers http.Header,
461493
proxyOverride string,
462-
) (*WsConnection, error) {
494+
) (*WsConnection, *PendingRequest, error) {
463495
// 先移除旧连接
464496
m.RemoveConnection(account.ID(), wsURL, sessionKey, effectiveProxyURL(account, proxyOverride))
465497

proxy/wsrelay/manager_test.go

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,20 @@ func TestAcquireConnectionReusesIdleConnectedConnection(t *testing.T) {
3030
manager.connections.Store(key, conn)
3131
manager.sessions.Store(key, session)
3232

33-
got, err := manager.AcquireConnection(context.Background(), account, wsURL, "session-1", http.Header{}, "")
33+
got, pr, err := manager.AcquireConnection(context.Background(), account, wsURL, "session-1", http.Header{}, "")
3434
if err != nil {
3535
t.Fatalf("AcquireConnection() error = %v", err)
3636
}
3737
if got != conn {
3838
t.Fatal("expected existing connection to be reused")
3939
}
40+
if pr == nil {
41+
t.Fatal("expected pending request reservation")
42+
}
43+
if session.PendingCount() != 1 {
44+
t.Fatalf("PendingCount = %d, want %d", session.PendingCount(), 1)
45+
}
46+
session.RemovePendingRequest(pr.RequestID)
4047
}
4148

4249
func TestPoolKeyIncludesSessionKey(t *testing.T) {
@@ -115,3 +122,36 @@ func TestCanReuseConnection(t *testing.T) {
115122
}
116123
})
117124
}
125+
126+
func TestAcquireConnectionWaitsWhileSessionHasPendingRequest(t *testing.T) {
127+
manager := NewManager()
128+
t.Cleanup(manager.Stop)
129+
130+
account := &auth.Account{DBID: 42}
131+
wsURL := "wss://example.test/responses"
132+
key := manager.poolKey(account.ID(), wsURL, "session-1", "")
133+
134+
session := NewSession(account.ID(), manager)
135+
session.SetConnected(true)
136+
blocking := session.AddPendingRequest("session-1")
137+
t.Cleanup(func() { session.RemovePendingRequest(blocking.RequestID) })
138+
139+
conn := &WsConnection{
140+
session: session,
141+
URL: wsURL,
142+
PoolKey: key,
143+
httpResp: &http.Response{StatusCode: http.StatusSwitchingProtocols},
144+
}
145+
conn.SetState(StateConnected)
146+
conn.Touch()
147+
manager.connections.Store(key, conn)
148+
manager.sessions.Store(key, session)
149+
150+
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
151+
defer cancel()
152+
153+
_, _, err := manager.AcquireConnection(ctx, account, wsURL, "session-1", http.Header{}, "")
154+
if err == nil {
155+
t.Fatal("expected acquire to stop when session stays busy until context timeout")
156+
}
157+
}

0 commit comments

Comments
 (0)