Skip to content

Commit f353979

Browse files
committed
feat(watcher, redisqueue): add usage refresh notification support
- Introduced `NotifyUsageRefresh` in `redisqueue` to notify subscribers of usage refresh events. - Enhanced `Watcher` logic to trigger usage refresh notifications on client changes (add/update/remove). - Updated tests to validate proper broadcast of usage refresh messages to subscribers. - Added support for initial `support_refresh` payload upon subscription initialization.
1 parent 959067e commit f353979

5 files changed

Lines changed: 223 additions & 0 deletions

File tree

internal/api/redis_queue_protocol_integration_test.go

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,68 @@ func readRESPArrayOfBulkStrings(r *bufio.Reader) ([][]byte, error) {
159159
return out, nil
160160
}
161161

162+
func readTestRESPPubSubSubscribe(r *bufio.Reader) (string, int, error) {
163+
prefix, errRead := r.ReadByte()
164+
if errRead != nil {
165+
return "", 0, errRead
166+
}
167+
if prefix != '*' {
168+
return "", 0, fmt.Errorf("expected array prefix '*', got %q", prefix)
169+
}
170+
line, errLine := readTestRESPLine(r)
171+
if errLine != nil {
172+
return "", 0, errLine
173+
}
174+
count, errParse := strconv.Atoi(line)
175+
if errParse != nil {
176+
return "", 0, fmt.Errorf("invalid array length %q: %v", line, errParse)
177+
}
178+
if count != 3 {
179+
return "", 0, fmt.Errorf("subscribe ack length = %d, want 3", count)
180+
}
181+
kind, errKind := readTestRESPBulkString(r)
182+
if errKind != nil {
183+
return "", 0, errKind
184+
}
185+
if string(kind) != "subscribe" {
186+
return "", 0, fmt.Errorf("subscribe ack kind = %q", string(kind))
187+
}
188+
channel, errChannel := readTestRESPBulkString(r)
189+
if errChannel != nil {
190+
return "", 0, errChannel
191+
}
192+
prefix, errRead = r.ReadByte()
193+
if errRead != nil {
194+
return "", 0, errRead
195+
}
196+
if prefix != ':' {
197+
return "", 0, fmt.Errorf("expected integer prefix ':', got %q", prefix)
198+
}
199+
line, errLine = readTestRESPLine(r)
200+
if errLine != nil {
201+
return "", 0, errLine
202+
}
203+
subscriptions, errParse := strconv.Atoi(line)
204+
if errParse != nil {
205+
return "", 0, fmt.Errorf("invalid subscription count %q: %v", line, errParse)
206+
}
207+
return string(channel), subscriptions, nil
208+
}
209+
210+
func readTestRESPPubSubMessage(r *bufio.Reader) (string, []byte, error) {
211+
items, errItems := readRESPArrayOfBulkStrings(r)
212+
if errItems != nil {
213+
return "", nil, errItems
214+
}
215+
if len(items) != 3 {
216+
return "", nil, fmt.Errorf("pubsub message length = %d, want 3", len(items))
217+
}
218+
if string(items[0]) != "message" {
219+
return "", nil, fmt.Errorf("pubsub message kind = %q", string(items[0]))
220+
}
221+
return string(items[1]), items[2], nil
222+
}
223+
162224
func TestRedisProtocol_ManagementDisabled_RejectsConnection(t *testing.T) {
163225
t.Setenv("MANAGEMENT_PASSWORD", "")
164226
redisqueue.SetEnabled(false)
@@ -235,6 +297,68 @@ func TestRedisProtocol_HomeEnabled_DisablesConnection(t *testing.T) {
235297
}
236298
}
237299

300+
func TestRedisProtocol_SUBSCRIBE_UsageSendsSupportRefresh(t *testing.T) {
301+
const managementPassword = "test-management-password"
302+
303+
t.Setenv("MANAGEMENT_PASSWORD", managementPassword)
304+
redisqueue.SetEnabled(false)
305+
t.Cleanup(func() { redisqueue.SetEnabled(false) })
306+
307+
server := newTestServer(t)
308+
if !server.managementRoutesEnabled.Load() {
309+
t.Fatalf("expected managementRoutesEnabled to be true")
310+
}
311+
312+
addr, stop := startRedisMuxListener(t, server)
313+
t.Cleanup(stop)
314+
315+
conn, errDial := net.DialTimeout("tcp", addr, time.Second)
316+
if errDial != nil {
317+
t.Fatalf("failed to dial redis listener: %v", errDial)
318+
}
319+
t.Cleanup(func() { _ = conn.Close() })
320+
321+
reader := bufio.NewReader(conn)
322+
_ = conn.SetDeadline(time.Now().Add(5 * time.Second))
323+
324+
if errWrite := writeTestRESPCommand(conn, "AUTH", managementPassword); errWrite != nil {
325+
t.Fatalf("failed to write AUTH command: %v", errWrite)
326+
}
327+
if msg, errRead := readTestRESPSimpleString(reader); errRead != nil {
328+
t.Fatalf("failed to read AUTH response: %v", errRead)
329+
} else if msg != "OK" {
330+
t.Fatalf("unexpected AUTH response: %q", msg)
331+
}
332+
333+
if errWrite := writeTestRESPCommand(conn, "SUBSCRIBE", "usage"); errWrite != nil {
334+
t.Fatalf("failed to write SUBSCRIBE command: %v", errWrite)
335+
}
336+
channel, subscriptions, errSubscribe := readTestRESPPubSubSubscribe(reader)
337+
if errSubscribe != nil {
338+
t.Fatalf("failed to read subscribe response: %v", errSubscribe)
339+
}
340+
if channel != "usage" || subscriptions != 1 {
341+
t.Fatalf("unexpected subscribe response channel=%q subscriptions=%d", channel, subscriptions)
342+
}
343+
344+
channel, payload, errMessage := readTestRESPPubSubMessage(reader)
345+
if errMessage != nil {
346+
t.Fatalf("failed to read support refresh message: %v", errMessage)
347+
}
348+
if channel != "usage" || string(payload) != `{"support_refresh":true}` {
349+
t.Fatalf("unexpected support refresh message channel=%q payload=%q", channel, string(payload))
350+
}
351+
352+
redisqueue.Enqueue([]byte(`{"id":1}`))
353+
channel, payload, errMessage = readTestRESPPubSubMessage(reader)
354+
if errMessage != nil {
355+
t.Fatalf("failed to read usage message: %v", errMessage)
356+
}
357+
if channel != "usage" || string(payload) != `{"id":1}` {
358+
t.Fatalf("unexpected usage message channel=%q payload=%q", channel, string(payload))
359+
}
360+
}
361+
238362
func TestRedisProtocol_AUTH_And_PopContracts(t *testing.T) {
239363
const managementPassword = "test-management-password"
240364

internal/redisqueue/queue.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ const (
1010
defaultRetentionSeconds int64 = 60
1111
maxRetentionSeconds int64 = 3600
1212
usageSubscriberBuffer = 256
13+
14+
usageSupportRefreshPayload = `{"support_refresh":true}`
15+
usageRefreshPayload = `{"refresh":true}`
1316
)
1417

1518
type queueItem struct {
@@ -83,6 +86,10 @@ func SubscribeUsage() (<-chan []byte, func()) {
8386
return global.subscribeUsage()
8487
}
8588

89+
func NotifyUsageRefresh() {
90+
global.publishToSubscribers([]byte(usageRefreshPayload))
91+
}
92+
8693
func (q *queue) clear() {
8794
q.mu.Lock()
8895

@@ -137,6 +144,7 @@ func (q *queue) publishToSubscribers(payload []byte) bool {
137144

138145
func (q *queue) subscribeUsage() (<-chan []byte, func()) {
139146
subscriber := make(chan []byte, usageSubscriberBuffer)
147+
subscriber <- []byte(usageSupportRefreshPayload)
140148

141149
q.mu.Lock()
142150
if q.subscribers == nil {

internal/redisqueue/queue_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ func TestEnqueueBroadcastsToUsageSubscribersAndSkipsQueue(t *testing.T) {
1212
second, unsubscribeSecond := SubscribeUsage()
1313
defer unsubscribeSecond()
1414

15+
requireUsageSubscriberPayload(t, first, usageSupportRefreshPayload)
16+
requireUsageSubscriberPayload(t, second, usageSupportRefreshPayload)
17+
1518
Enqueue([]byte("usage-record"))
1619

1720
requireUsageSubscriberPayload(t, first, "usage-record")
@@ -37,6 +40,8 @@ func TestSetEnabledFalseClosesUsageSubscribers(t *testing.T) {
3740
subscriber, unsubscribe := SubscribeUsage()
3841
defer unsubscribe()
3942

43+
requireUsageSubscriberPayload(t, subscriber, usageSupportRefreshPayload)
44+
4045
SetEnabled(false)
4146

4247
select {
@@ -50,6 +55,24 @@ func TestSetEnabledFalseClosesUsageSubscribers(t *testing.T) {
5055
})
5156
}
5257

58+
func TestNotifyUsageRefreshBroadcastsOnlyToUsageSubscribers(t *testing.T) {
59+
withEnabledQueue(t, func() {
60+
subscriber, unsubscribe := SubscribeUsage()
61+
defer unsubscribe()
62+
63+
requireUsageSubscriberPayload(t, subscriber, usageSupportRefreshPayload)
64+
65+
NotifyUsageRefresh()
66+
requireUsageSubscriberPayload(t, subscriber, usageRefreshPayload)
67+
68+
unsubscribe()
69+
NotifyUsageRefresh()
70+
if items := PopOldest(1); len(items) != 0 {
71+
t.Fatalf("PopOldest() items = %q, want empty after refresh notification without subscribers", items)
72+
}
73+
})
74+
}
75+
5376
func requireUsageSubscriberPayload(t *testing.T, subscriber <-chan []byte, want string) {
5477
t.Helper()
5578

internal/watcher/clients.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"time"
1515

1616
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
17+
"github.com/router-for-me/CLIProxyAPI/v7/internal/redisqueue"
1718
"github.com/router-for-me/CLIProxyAPI/v7/internal/util"
1819
"github.com/router-for-me/CLIProxyAPI/v7/internal/watcher/diff"
1920
"github.com/router-for-me/CLIProxyAPI/v7/internal/watcher/synthesizer"
@@ -134,6 +135,7 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
134135
}
135136

136137
w.refreshAuthState(forceAuthRefresh)
138+
redisqueue.NotifyUsageRefresh()
137139

138140
log.Infof("full client load complete - %d clients (%d auth files + %d Gemini API keys + %d Vertex API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)",
139141
totalNewClients,
@@ -233,6 +235,7 @@ func (w *Watcher) addOrUpdateClient(path string) {
233235

234236
w.persistAuthAsync(fmt.Sprintf("Sync auth %s", filepath.Base(path)), path)
235237
w.dispatchAuthUpdates(updates)
238+
redisqueue.NotifyUsageRefresh()
236239
}
237240

238241
func (w *Watcher) removeClient(path string) {
@@ -251,6 +254,7 @@ func (w *Watcher) removeClient(path string) {
251254

252255
w.persistAuthAsync(fmt.Sprintf("Remove auth %s", filepath.Base(path)), path)
253256
w.dispatchAuthUpdates(updates)
257+
redisqueue.NotifyUsageRefresh()
254258
}
255259

256260
func (w *Watcher) computePerPathUpdatesLocked(oldByID, newByID map[string]*coreauth.Auth) []AuthUpdate {

internal/watcher/watcher_test.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515

1616
"github.com/fsnotify/fsnotify"
1717
"github.com/router-for-me/CLIProxyAPI/v7/internal/config"
18+
"github.com/router-for-me/CLIProxyAPI/v7/internal/redisqueue"
1819
"github.com/router-for-me/CLIProxyAPI/v7/internal/watcher/diff"
1920
"github.com/router-for-me/CLIProxyAPI/v7/internal/watcher/synthesizer"
2021
sdkAuth "github.com/router-for-me/CLIProxyAPI/v7/sdk/auth"
@@ -441,6 +442,34 @@ func TestRemoveClientRemovesHash(t *testing.T) {
441442
}
442443
}
443444

445+
func TestAuthFileClientChangesNotifyUsageSubscribersToRefresh(t *testing.T) {
446+
tmpDir := t.TempDir()
447+
authFile := filepath.Join(tmpDir, "sample.json")
448+
if err := os.WriteFile(authFile, []byte(`{"type":"demo","api_key":"k"}`), 0o644); err != nil {
449+
t.Fatalf("failed to create auth file: %v", err)
450+
}
451+
452+
redisqueue.SetEnabled(false)
453+
redisqueue.SetEnabled(true)
454+
t.Cleanup(func() { redisqueue.SetEnabled(false) })
455+
456+
subscriber, unsubscribe := redisqueue.SubscribeUsage()
457+
defer unsubscribe()
458+
requireWatcherUsagePayload(t, subscriber, `{"support_refresh":true}`)
459+
460+
w := &Watcher{
461+
authDir: tmpDir,
462+
lastAuthHashes: make(map[string]string),
463+
}
464+
w.SetConfig(&config.Config{AuthDir: tmpDir})
465+
466+
w.addOrUpdateClient(authFile)
467+
requireWatcherUsagePayload(t, subscriber, `{"refresh":true}`)
468+
469+
w.removeClient(authFile)
470+
requireWatcherUsagePayload(t, subscriber, `{"refresh":true}`)
471+
}
472+
444473
func TestAuthFileEventsDoNotInvokeSnapshotCoreAuths(t *testing.T) {
445474
tmpDir := t.TempDir()
446475
authFile := filepath.Join(tmpDir, "sample.json")
@@ -699,6 +728,25 @@ func TestReloadClientsHandlesNilConfig(t *testing.T) {
699728
w.reloadClients(true, nil, false)
700729
}
701730

731+
func TestReloadClientsNotifiesUsageSubscribersToRefresh(t *testing.T) {
732+
tmp := t.TempDir()
733+
redisqueue.SetEnabled(false)
734+
redisqueue.SetEnabled(true)
735+
t.Cleanup(func() { redisqueue.SetEnabled(false) })
736+
737+
subscriber, unsubscribe := redisqueue.SubscribeUsage()
738+
defer unsubscribe()
739+
requireWatcherUsagePayload(t, subscriber, `{"support_refresh":true}`)
740+
741+
w := &Watcher{
742+
authDir: tmp,
743+
config: &config.Config{AuthDir: tmp},
744+
}
745+
w.reloadClients(false, nil, false)
746+
747+
requireWatcherUsagePayload(t, subscriber, `{"refresh":true}`)
748+
}
749+
702750
func TestReloadClientsFiltersProvidersWithNilCurrentAuths(t *testing.T) {
703751
tmp := t.TempDir()
704752
w := &Watcher{
@@ -711,6 +759,22 @@ func TestReloadClientsFiltersProvidersWithNilCurrentAuths(t *testing.T) {
711759
}
712760
}
713761

762+
func requireWatcherUsagePayload(t *testing.T, subscriber <-chan []byte, want string) {
763+
t.Helper()
764+
765+
select {
766+
case got, ok := <-subscriber:
767+
if !ok {
768+
t.Fatalf("subscriber closed before receiving %q", want)
769+
}
770+
if string(got) != want {
771+
t.Fatalf("subscriber payload = %q, want %q", string(got), want)
772+
}
773+
case <-time.After(time.Second):
774+
t.Fatalf("timeout waiting for subscriber payload %q", want)
775+
}
776+
}
777+
714778
func TestSetAuthUpdateQueueNilResetsDispatch(t *testing.T) {
715779
w := &Watcher{}
716780
queue := make(chan AuthUpdate, 1)

0 commit comments

Comments
 (0)