Skip to content

Commit a104a4e

Browse files
feat(proxy): add session-aware account affinity
1 parent a62927b commit a104a4e

4 files changed

Lines changed: 176 additions & 7 deletions

File tree

auth/session_affinity_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package auth
2+
3+
import "testing"
4+
5+
func TestNextForSessionPrefersBoundAccountAndProxy(t *testing.T) {
6+
store := &Store{
7+
accounts: []*Account{
8+
{DBID: 1, AccessToken: "tok-1"},
9+
{DBID: 2, AccessToken: "tok-2"},
10+
},
11+
maxConcurrency: 2,
12+
}
13+
store.bindSessionAffinity("session-1", store.accounts[1], "http://proxy-2")
14+
15+
acc, proxyURL := store.NextForSession("session-1", nil)
16+
if acc == nil {
17+
t.Fatal("expected account")
18+
}
19+
if acc.DBID != 2 {
20+
t.Fatalf("account DBID = %d, want %d", acc.DBID, 2)
21+
}
22+
if proxyURL != "http://proxy-2" {
23+
t.Fatalf("proxyURL = %q, want %q", proxyURL, "http://proxy-2")
24+
}
25+
}
26+
27+
func TestNextForSessionFallsBackWhenBoundAccountExcluded(t *testing.T) {
28+
store := &Store{
29+
accounts: []*Account{
30+
{DBID: 1, AccessToken: "tok-1"},
31+
{DBID: 2, AccessToken: "tok-2"},
32+
},
33+
maxConcurrency: 2,
34+
}
35+
store.bindSessionAffinity("session-1", store.accounts[1], "http://proxy-2")
36+
37+
acc, proxyURL := store.NextForSession("session-1", map[int64]bool{2: true})
38+
if acc == nil {
39+
t.Fatal("expected fallback account")
40+
}
41+
if acc.DBID != 1 {
42+
t.Fatalf("account DBID = %d, want %d", acc.DBID, 1)
43+
}
44+
if proxyURL != "" {
45+
t.Fatalf("proxyURL = %q, want empty fallback proxy", proxyURL)
46+
}
47+
}

auth/store.go

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -726,10 +726,20 @@ type Store struct {
726726
// 智能刷新调度器
727727
refreshScheduler atomic.Pointer[RefreshSchedulerIntegration]
728728

729-
allowRemoteMigration atomic.Bool // 是否允许远程迁移拉取账号
729+
allowRemoteMigration atomic.Bool // 是否允许远程迁移拉取账号
730730
modelMapping atomic.Value // 模型映射 JSON 字符串
731+
sessionMu sync.RWMutex
732+
sessionBindings map[string]sessionAffinity
731733
}
732734

735+
type sessionAffinity struct {
736+
accountID int64
737+
proxyURL string
738+
expiresAt time.Time
739+
}
740+
741+
const sessionAffinityTTL = 30 * time.Minute
742+
733743
func fastSchedulerEnabledFromEnv() bool {
734744
for _, key := range []string{"FAST_SCHEDULER_ENABLED", "CODEX_FAST_SCHEDULER"} {
735745
if truthyEnv(os.Getenv(key)) {
@@ -766,6 +776,7 @@ func NewStore(db *database.DB, tc cache.TokenCache, settings *database.SystemSet
766776
tokenCache: tc,
767777
stopCh: make(chan struct{}),
768778
proxyPoolEnabled: settings.ProxyPoolEnabled,
779+
sessionBindings: make(map[string]sessionAffinity),
769780
}
770781
s.testModel.Store(settings.TestModel)
771782
s.autoCleanUnauthorized.Store(settings.AutoCleanUnauthorized)
@@ -1276,6 +1287,95 @@ func (s *Store) NextExcluding(exclude map[int64]bool) *Account {
12761287
return best
12771288
}
12781289

1290+
// BindSessionAffinity 记录会话与账号/代理的亲和关系。
1291+
func (s *Store) BindSessionAffinity(key string, account *Account, proxyURL string) {
1292+
s.bindSessionAffinity(key, account, proxyURL)
1293+
}
1294+
1295+
func (s *Store) bindSessionAffinity(key string, account *Account, proxyURL string) {
1296+
if s == nil || account == nil {
1297+
return
1298+
}
1299+
key = strings.TrimSpace(key)
1300+
if key == "" {
1301+
return
1302+
}
1303+
1304+
s.sessionMu.Lock()
1305+
if s.sessionBindings == nil {
1306+
s.sessionBindings = make(map[string]sessionAffinity)
1307+
}
1308+
s.sessionBindings[key] = sessionAffinity{
1309+
accountID: account.DBID,
1310+
proxyURL: strings.TrimSpace(proxyURL),
1311+
expiresAt: time.Now().Add(sessionAffinityTTL),
1312+
}
1313+
s.sessionMu.Unlock()
1314+
}
1315+
1316+
// NextForSession 优先复用已绑定的账号和代理,失败时回退到普通选号。
1317+
func (s *Store) NextForSession(key string, exclude map[int64]bool) (*Account, string) {
1318+
if s == nil {
1319+
return nil, ""
1320+
}
1321+
key = strings.TrimSpace(key)
1322+
if key == "" {
1323+
return s.NextExcluding(exclude), ""
1324+
}
1325+
1326+
now := time.Now()
1327+
s.sessionMu.RLock()
1328+
binding, ok := s.sessionBindings[key]
1329+
s.sessionMu.RUnlock()
1330+
1331+
if ok {
1332+
if !binding.expiresAt.After(now) {
1333+
s.sessionMu.Lock()
1334+
if current, exists := s.sessionBindings[key]; exists && !current.expiresAt.After(now) {
1335+
delete(s.sessionBindings, key)
1336+
}
1337+
s.sessionMu.Unlock()
1338+
} else if acc := s.takeByIDExcluding(binding.accountID, exclude); acc != nil {
1339+
return acc, binding.proxyURL
1340+
}
1341+
}
1342+
1343+
return s.NextExcluding(exclude), ""
1344+
}
1345+
1346+
func (s *Store) takeByIDExcluding(id int64, exclude map[int64]bool) *Account {
1347+
if s == nil || id == 0 {
1348+
return nil
1349+
}
1350+
if exclude != nil && exclude[id] {
1351+
return nil
1352+
}
1353+
1354+
s.mu.RLock()
1355+
var target *Account
1356+
for _, acc := range s.accounts {
1357+
if acc != nil && acc.DBID == id {
1358+
target = acc
1359+
break
1360+
}
1361+
}
1362+
s.mu.RUnlock()
1363+
if target == nil || !target.IsAvailable() {
1364+
return nil
1365+
}
1366+
1367+
maxConcurrency := atomic.LoadInt64(&s.maxConcurrency)
1368+
now := time.Now()
1369+
_, _, limit, available := target.fastSchedulerSnapshot(maxConcurrency, now)
1370+
if !available || limit <= 0 {
1371+
return nil
1372+
}
1373+
if !tryAcquireAccount(target, limit) {
1374+
return nil
1375+
}
1376+
return target
1377+
}
1378+
12791379
// WaitForAvailable 等待可用账号(带超时的请求排队)
12801380
func (s *Store) WaitForAvailable(ctx context.Context, timeout time.Duration) *Account {
12811381
deadline := time.NewTimer(timeout)

proxy/handler.go

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,13 @@ type Handler struct {
3636
dbKeysUntil time.Time
3737
}
3838

39+
func (h *Handler) nextAccountForSession(sessionID string, exclude map[int64]bool) (*auth.Account, string) {
40+
if h == nil || h.store == nil {
41+
return nil, ""
42+
}
43+
return h.store.NextForSession(sessionID, exclude)
44+
}
45+
3946
type usageLimitDetails struct {
4047
message string
4148
planType string
@@ -444,7 +451,7 @@ func (h *Handler) Responses(c *gin.Context) {
444451
excludeAccounts := make(map[int64]bool) // 重试时排除已失败的账号
445452

446453
for attempt := 0; attempt <= maxRetries; attempt++ {
447-
account := h.store.NextExcluding(excludeAccounts)
454+
account, stickyProxyURL := h.nextAccountForSession(sessionID, excludeAccounts)
448455
if account == nil {
449456
// 排队等待可用账号(最多 30s)
450457
account = h.store.WaitForAvailable(c.Request.Context(), 30*time.Second)
@@ -461,7 +468,10 @@ func (h *Handler) Responses(c *gin.Context) {
461468
}
462469

463470
start := time.Now()
464-
proxyURL := h.store.NextProxy()
471+
proxyURL := stickyProxyURL
472+
if proxyURL == "" {
473+
proxyURL = h.store.NextProxy()
474+
}
465475
useWebsocket := h.cfg != nil && h.cfg.UseWebsocket
466476

467477
// 提取 API Key 用于设备指纹稳定化
@@ -669,6 +679,8 @@ func (h *Handler) Responses(c *gin.Context) {
669679
}
670680
continue
671681
}
682+
683+
h.store.BindSessionAffinity(sessionID, account, proxyURL)
672684
logStatusCode := outcome.logStatusCode
673685
if outcome.logStatusCode != http.StatusOK {
674686
log.Printf("流异常结束 (account %d, /v1/responses, status %d): %s,已转发约 %d 字符", account.ID(), outcome.logStatusCode, outcome.failureMessage, deltaCharCount)
@@ -808,7 +820,7 @@ func (h *Handler) ChatCompletions(c *gin.Context) {
808820
excludeAccounts := make(map[int64]bool) // 重试时排除已失败的账号
809821

810822
for attempt := 0; attempt <= maxRetries; attempt++ {
811-
account := h.store.NextExcluding(excludeAccounts)
823+
account, stickyProxyURL := h.nextAccountForSession(sessionID, excludeAccounts)
812824
if account == nil {
813825
// 排队等待可用账号(最多 30s)
814826
account = h.store.WaitForAvailable(c.Request.Context(), 30*time.Second)
@@ -825,7 +837,10 @@ func (h *Handler) ChatCompletions(c *gin.Context) {
825837
}
826838

827839
start := time.Now()
828-
proxyURL := h.store.NextProxy()
840+
proxyURL := stickyProxyURL
841+
if proxyURL == "" {
842+
proxyURL = h.store.NextProxy()
843+
}
829844
useWebsocket := h.cfg != nil && h.cfg.UseWebsocket
830845

831846
// 提取 API Key 用于设备指纹稳定化
@@ -1038,6 +1053,8 @@ func (h *Handler) ChatCompletions(c *gin.Context) {
10381053
}
10391054
continue
10401055
}
1056+
1057+
h.store.BindSessionAffinity(sessionID, account, proxyURL)
10411058
logStatusCode := outcome.logStatusCode
10421059
if outcome.logStatusCode != http.StatusOK {
10431060
log.Printf("流异常结束 (account %d, /v1/chat/completions, status %d): %s,已转发约 %d 字符", account.ID(), outcome.logStatusCode, outcome.failureMessage, deltaCharCount)

proxy/handler_anthropic.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ func (h *Handler) Messages(c *gin.Context) {
119119
excludeAccounts := make(map[int64]bool)
120120

121121
for attempt := 0; attempt <= maxRetries; attempt++ {
122-
account := h.store.NextExcluding(excludeAccounts)
122+
account, stickyProxyURL := h.nextAccountForSession(sessionID, excludeAccounts)
123123
if account == nil {
124124
account = h.store.WaitForAvailable(c.Request.Context(), 30*time.Second)
125125
if account == nil {
@@ -133,7 +133,10 @@ func (h *Handler) Messages(c *gin.Context) {
133133
}
134134

135135
start := time.Now()
136-
proxyURL := h.store.NextProxy()
136+
proxyURL := stickyProxyURL
137+
if proxyURL == "" {
138+
proxyURL = h.store.NextProxy()
139+
}
137140
useWebsocket := h.cfg != nil && h.cfg.UseWebsocket
138141

139142
apiKey := strings.TrimPrefix(c.GetHeader("Authorization"), "Bearer ")
@@ -362,6 +365,8 @@ func (h *Handler) Messages(c *gin.Context) {
362365
continue
363366
}
364367

368+
h.store.BindSessionAffinity(sessionID, account, proxyURL)
369+
365370
logStatusCode := outcome.logStatusCode
366371
if outcome.logStatusCode != http.StatusOK {
367372
log.Printf("流异常结束 (account %d, /v1/messages, status %d): %s,已转发约 %d 字符",

0 commit comments

Comments
 (0)