Skip to content

Commit 8e52c40

Browse files
committed
feat(auth): deduplicate concurrent refresh token requests with singleflight
- Introduced `singleflight.Group` to prevent redundant token refresh calls across multiple auth implementations (`antigravity`, `kimi`, `xai`, `codex`). - Added tests to verify shared upstream calls during concurrent refresh requests. - Refactored token refresh logic to centralize and standardize deduplication mechanisms.
1 parent 8caf474 commit 8e52c40

8 files changed

Lines changed: 516 additions & 44 deletions

File tree

internal/auth/codex/openai_auth.go

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
1818
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
1919
log "github.com/sirupsen/logrus"
20+
"golang.org/x/sync/singleflight"
2021
)
2122

2223
// OAuth configuration constants for OpenAI Codex
@@ -34,6 +35,8 @@ type CodexAuth struct {
3435
httpClient *http.Client
3536
}
3637

38+
var codexRefreshGroup singleflight.Group
39+
3740
// NewCodexAuth creates a new CodexAuth service instance.
3841
// It initializes an HTTP client with proxy settings from the provided configuration.
3942
func NewCodexAuth(cfg *config.Config) *CodexAuth {
@@ -187,33 +190,52 @@ func (o *CodexAuth) RefreshTokens(ctx context.Context, refreshToken string) (*Co
187190
if refreshToken == "" {
188191
return nil, fmt.Errorf("refresh token is required")
189192
}
193+
if ctx == nil {
194+
ctx = context.Background()
195+
}
190196

197+
result, err, _ := codexRefreshGroup.Do(refreshToken, func() (interface{}, error) {
198+
return o.refreshTokensSingleFlight(context.WithoutCancel(ctx), refreshToken)
199+
})
200+
if err != nil {
201+
return nil, err
202+
}
203+
tokenData, ok := result.(*CodexTokenData)
204+
if !ok || tokenData == nil {
205+
return nil, fmt.Errorf("token refresh failed: invalid single-flight result")
206+
}
207+
return tokenData, nil
208+
}
209+
210+
func (o *CodexAuth) refreshTokensSingleFlight(ctx context.Context, refreshToken string) (*CodexTokenData, error) {
191211
data := url.Values{
192212
"client_id": {ClientID},
193213
"grant_type": {"refresh_token"},
194214
"refresh_token": {refreshToken},
195215
"scope": {"openid profile email"},
196216
}
197217

198-
req, err := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(data.Encode()))
199-
if err != nil {
200-
return nil, fmt.Errorf("failed to create refresh request: %w", err)
218+
req, errReq := http.NewRequestWithContext(ctx, "POST", TokenURL, strings.NewReader(data.Encode()))
219+
if errReq != nil {
220+
return nil, fmt.Errorf("failed to create refresh request: %w", errReq)
201221
}
202222

203223
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
204224
req.Header.Set("Accept", "application/json")
205225

206-
resp, err := o.httpClient.Do(req)
207-
if err != nil {
208-
return nil, fmt.Errorf("token refresh request failed: %w", err)
226+
resp, errDo := o.httpClient.Do(req)
227+
if errDo != nil {
228+
return nil, fmt.Errorf("token refresh request failed: %w", errDo)
209229
}
210230
defer func() {
211-
_ = resp.Body.Close()
231+
if errClose := resp.Body.Close(); errClose != nil {
232+
log.Errorf("token refresh response body close error: %v", errClose)
233+
}
212234
}()
213235

214-
body, err := io.ReadAll(resp.Body)
215-
if err != nil {
216-
return nil, fmt.Errorf("failed to read refresh response: %w", err)
236+
body, errRead := io.ReadAll(resp.Body)
237+
if errRead != nil {
238+
return nil, fmt.Errorf("failed to read refresh response: %w", errRead)
217239
}
218240

219241
if resp.StatusCode != http.StatusOK {
@@ -228,14 +250,14 @@ func (o *CodexAuth) RefreshTokens(ctx context.Context, refreshToken string) (*Co
228250
ExpiresIn int `json:"expires_in"`
229251
}
230252

231-
if err = json.Unmarshal(body, &tokenResp); err != nil {
232-
return nil, fmt.Errorf("failed to parse refresh response: %w", err)
253+
if errUnmarshal := json.Unmarshal(body, &tokenResp); errUnmarshal != nil {
254+
return nil, fmt.Errorf("failed to parse refresh response: %w", errUnmarshal)
233255
}
234256

235257
// Extract account ID from ID token
236-
claims, err := ParseJWTToken(tokenResp.IDToken)
237-
if err != nil {
238-
log.Warnf("Failed to parse refreshed ID token: %v", err)
258+
claims, errParseJWT := ParseJWTToken(tokenResp.IDToken)
259+
if errParseJWT != nil {
260+
log.Warnf("Failed to parse refreshed ID token: %v", errParseJWT)
239261
}
240262

241263
accountID := ""

internal/auth/codex/openai_auth_test.go

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@ import (
55
"io"
66
"net/http"
77
"strings"
8+
"sync"
89
"sync/atomic"
910
"testing"
11+
"time"
1012

1113
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
14+
"golang.org/x/sync/singleflight"
1215
)
1316

1417
type roundTripFunc func(*http.Request) (*http.Response, error)
@@ -17,6 +20,10 @@ func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
1720
return f(req)
1821
}
1922

23+
func resetCodexRefreshGroupForTest() {
24+
codexRefreshGroup = singleflight.Group{}
25+
}
26+
2027
func TestRefreshTokensWithRetry_NonRetryableOnlyAttemptsOnce(t *testing.T) {
2128
var calls int32
2229
auth := &CodexAuth{
@@ -45,6 +52,71 @@ func TestRefreshTokensWithRetry_NonRetryableOnlyAttemptsOnce(t *testing.T) {
4552
}
4653
}
4754

55+
func TestRefreshTokens_DeduplicatesConcurrentRefreshAcrossInstances(t *testing.T) {
56+
resetCodexRefreshGroupForTest()
57+
t.Cleanup(resetCodexRefreshGroupForTest)
58+
59+
var calls int32
60+
started := make(chan struct{})
61+
release := make(chan struct{})
62+
var once sync.Once
63+
64+
transport := roundTripFunc(func(req *http.Request) (*http.Response, error) {
65+
atomic.AddInt32(&calls, 1)
66+
once.Do(func() { close(started) })
67+
<-release
68+
return &http.Response{
69+
StatusCode: http.StatusOK,
70+
Body: io.NopCloser(strings.NewReader(`{
71+
"access_token":"new-access",
72+
"refresh_token":"new-refresh",
73+
"token_type":"Bearer",
74+
"expires_in":3600
75+
}`)),
76+
Header: make(http.Header),
77+
Request: req,
78+
}, nil
79+
})
80+
authA := &CodexAuth{httpClient: &http.Client{Transport: transport}}
81+
authB := &CodexAuth{httpClient: &http.Client{Transport: transport}}
82+
83+
results := make(chan *CodexTokenData, 2)
84+
errs := make(chan error, 2)
85+
runRefresh := func(auth *CodexAuth, launched chan<- struct{}) {
86+
if launched != nil {
87+
close(launched)
88+
}
89+
tokenData, errRefresh := auth.RefreshTokens(context.Background(), "shared-refresh-token")
90+
results <- tokenData
91+
errs <- errRefresh
92+
}
93+
94+
go runRefresh(authA, nil)
95+
<-started
96+
97+
secondLaunched := make(chan struct{})
98+
go runRefresh(authB, secondLaunched)
99+
<-secondLaunched
100+
time.Sleep(20 * time.Millisecond)
101+
if got := atomic.LoadInt32(&calls); got != 1 {
102+
t.Fatalf("expected concurrent refresh to share a single upstream call, got %d", got)
103+
}
104+
close(release)
105+
106+
for i := 0; i < 2; i++ {
107+
if errRefresh := <-errs; errRefresh != nil {
108+
t.Fatalf("expected refresh to succeed, got %v", errRefresh)
109+
}
110+
tokenData := <-results
111+
if tokenData == nil || tokenData.AccessToken != "new-access" || tokenData.RefreshToken != "new-refresh" {
112+
t.Fatalf("unexpected token data: %#v", tokenData)
113+
}
114+
}
115+
if got := atomic.LoadInt32(&calls); got != 1 {
116+
t.Fatalf("expected both refresh callers to share a single upstream call, got %d", got)
117+
}
118+
}
119+
48120
func TestNewCodexAuthWithProxyURL_OverrideDirectDisablesProxy(t *testing.T) {
49121
cfg := &config.Config{SDKConfig: config.SDKConfig{ProxyURL: "http://proxy.example.com:8080"}}
50122
auth := NewCodexAuthWithProxyURL(cfg, "direct")

internal/auth/kimi/kimi.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
1919
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
2020
log "github.com/sirupsen/logrus"
21+
"golang.org/x/sync/singleflight"
2122
)
2223

2324
const (
@@ -39,6 +40,8 @@ const (
3940
refreshThresholdSeconds = 300
4041
)
4142

43+
var kimiRefreshGroup singleflight.Group
44+
4245
// KimiAuth handles Kimi authentication flow.
4346
type KimiAuth struct {
4447
deviceClient *DeviceFlowClient
@@ -341,6 +344,28 @@ func (c *DeviceFlowClient) exchangeDeviceCode(ctx context.Context, deviceCode st
341344

342345
// RefreshToken exchanges a refresh token for a new access token.
343346
func (c *DeviceFlowClient) RefreshToken(ctx context.Context, refreshToken string) (*KimiTokenData, error) {
347+
if strings.TrimSpace(refreshToken) == "" {
348+
return nil, fmt.Errorf("kimi: refresh token is required")
349+
}
350+
if ctx == nil {
351+
ctx = context.Background()
352+
}
353+
refreshToken = strings.TrimSpace(refreshToken)
354+
355+
result, err, _ := kimiRefreshGroup.Do(refreshToken, func() (interface{}, error) {
356+
return c.refreshTokenSingleFlight(context.WithoutCancel(ctx), refreshToken)
357+
})
358+
if err != nil {
359+
return nil, err
360+
}
361+
tokenData, ok := result.(*KimiTokenData)
362+
if !ok || tokenData == nil {
363+
return nil, fmt.Errorf("kimi: refresh token failed: invalid single-flight result")
364+
}
365+
return tokenData, nil
366+
}
367+
368+
func (c *DeviceFlowClient) refreshTokenSingleFlight(ctx context.Context, refreshToken string) (*KimiTokenData, error) {
344369
data := url.Values{}
345370
data.Set("client_id", kimiClientID)
346371
data.Set("grant_type", "refresh_token")
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
package kimi
2+
3+
import (
4+
"context"
5+
"io"
6+
"net/http"
7+
"strings"
8+
"sync"
9+
"sync/atomic"
10+
"testing"
11+
"time"
12+
13+
"golang.org/x/sync/singleflight"
14+
)
15+
16+
type kimiRoundTripFunc func(*http.Request) (*http.Response, error)
17+
18+
func (f kimiRoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
19+
return f(req)
20+
}
21+
22+
func resetKimiRefreshGroupForTest() {
23+
kimiRefreshGroup = singleflight.Group{}
24+
}
25+
26+
func TestRefreshToken_DeduplicatesConcurrentRefreshAcrossInstances(t *testing.T) {
27+
resetKimiRefreshGroupForTest()
28+
t.Cleanup(resetKimiRefreshGroupForTest)
29+
30+
var calls int32
31+
started := make(chan struct{})
32+
release := make(chan struct{})
33+
var once sync.Once
34+
35+
transport := kimiRoundTripFunc(func(req *http.Request) (*http.Response, error) {
36+
atomic.AddInt32(&calls, 1)
37+
once.Do(func() { close(started) })
38+
<-release
39+
return &http.Response{
40+
StatusCode: http.StatusOK,
41+
Body: io.NopCloser(strings.NewReader(`{
42+
"access_token":"new-access",
43+
"refresh_token":"new-refresh",
44+
"token_type":"Bearer",
45+
"expires_in":3600
46+
}`)),
47+
Header: make(http.Header),
48+
Request: req,
49+
}, nil
50+
})
51+
clientA := &DeviceFlowClient{httpClient: &http.Client{Transport: transport}}
52+
clientB := &DeviceFlowClient{httpClient: &http.Client{Transport: transport}}
53+
54+
results := make(chan *KimiTokenData, 2)
55+
errs := make(chan error, 2)
56+
runRefresh := func(client *DeviceFlowClient, launched chan<- struct{}) {
57+
if launched != nil {
58+
close(launched)
59+
}
60+
tokenData, errRefresh := client.RefreshToken(context.Background(), "shared-refresh-token")
61+
results <- tokenData
62+
errs <- errRefresh
63+
}
64+
65+
go runRefresh(clientA, nil)
66+
<-started
67+
68+
secondLaunched := make(chan struct{})
69+
go runRefresh(clientB, secondLaunched)
70+
<-secondLaunched
71+
time.Sleep(20 * time.Millisecond)
72+
if got := atomic.LoadInt32(&calls); got != 1 {
73+
t.Fatalf("expected concurrent refresh to share a single upstream call, got %d", got)
74+
}
75+
close(release)
76+
77+
for i := 0; i < 2; i++ {
78+
if errRefresh := <-errs; errRefresh != nil {
79+
t.Fatalf("expected refresh to succeed, got %v", errRefresh)
80+
}
81+
tokenData := <-results
82+
if tokenData == nil || tokenData.AccessToken != "new-access" || tokenData.RefreshToken != "new-refresh" {
83+
t.Fatalf("unexpected token data: %#v", tokenData)
84+
}
85+
}
86+
if got := atomic.LoadInt32(&calls); got != 1 {
87+
t.Fatalf("expected both refresh callers to share a single upstream call, got %d", got)
88+
}
89+
}

internal/auth/xai/xai.go

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,16 @@ import (
1414
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
1515
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
1616
log "github.com/sirupsen/logrus"
17+
"golang.org/x/sync/singleflight"
1718
)
1819

1920
// XAIAuth performs xAI OAuth discovery, token exchange, and refresh.
2021
type XAIAuth struct {
2122
httpClient *http.Client
2223
}
2324

25+
var xaiRefreshGroup singleflight.Group
26+
2427
// NewXAIAuth creates an xAI OAuth helper using config proxy settings.
2528
func NewXAIAuth(cfg *config.Config) *XAIAuth {
2629
return NewXAIAuthWithProxyURL(cfg, "")
@@ -180,17 +183,37 @@ func (a *XAIAuth) RefreshTokens(ctx context.Context, refreshToken, tokenEndpoint
180183
if strings.TrimSpace(refreshToken) == "" {
181184
return nil, fmt.Errorf("xai token refresh: refresh token is required")
182185
}
186+
if ctx == nil {
187+
ctx = context.Background()
188+
}
189+
refreshToken = strings.TrimSpace(refreshToken)
183190
if strings.TrimSpace(tokenEndpoint) == "" {
184191
discovery, errDiscover := a.Discover(ctx)
185192
if errDiscover != nil {
186193
return nil, errDiscover
187194
}
188195
tokenEndpoint = discovery.TokenEndpoint
189196
}
197+
tokenEndpoint = strings.TrimSpace(tokenEndpoint)
198+
199+
result, err, _ := xaiRefreshGroup.Do(refreshToken, func() (interface{}, error) {
200+
return a.refreshTokensSingleFlight(context.WithoutCancel(ctx), refreshToken, tokenEndpoint)
201+
})
202+
if err != nil {
203+
return nil, err
204+
}
205+
tokenData, ok := result.(*TokenData)
206+
if !ok || tokenData == nil {
207+
return nil, fmt.Errorf("xai token refresh failed: invalid single-flight result")
208+
}
209+
return tokenData, nil
210+
}
211+
212+
func (a *XAIAuth) refreshTokensSingleFlight(ctx context.Context, refreshToken, tokenEndpoint string) (*TokenData, error) {
190213
form := url.Values{
191214
"grant_type": {"refresh_token"},
192215
"client_id": {ClientID},
193-
"refresh_token": {strings.TrimSpace(refreshToken)},
216+
"refresh_token": {refreshToken},
194217
}
195218
return a.postTokenForm(ctx, tokenEndpoint, form)
196219
}

0 commit comments

Comments
 (0)