Skip to content

Commit 132ea2b

Browse files
committed
feat: add account enable scheduling toggle
1 parent ed6d300 commit 132ea2b

12 files changed

Lines changed: 303 additions & 6 deletions

File tree

admin/handler.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ func (h *Handler) RegisterRoutes(r *gin.Engine) {
110110
api.PATCH("/accounts/:id/scheduler", h.UpdateAccountScheduler)
111111
api.DELETE("/accounts/:id", h.DeleteAccount)
112112
api.POST("/accounts/:id/refresh", h.RefreshAccount)
113+
api.POST("/accounts/:id/enable", h.ToggleAccountEnabled)
113114
api.POST("/accounts/:id/lock", h.ToggleAccountLock)
114115
api.POST("/accounts/:id/reset-status", h.ResetAccountStatus)
115116
api.GET("/accounts/:id/test", h.TestConnection)
@@ -299,6 +300,7 @@ type accountResponse struct {
299300
LastRateLimitedAt string `json:"last_rate_limited_at,omitempty"`
300301
LastTimeoutAt string `json:"last_timeout_at,omitempty"`
301302
LastServerErrorAt string `json:"last_server_error_at,omitempty"`
303+
Enabled bool `json:"enabled"`
302304
Locked bool `json:"locked"`
303305
AllowedAPIKeyIDs []int64 `json:"allowed_api_key_ids"`
304306
// 图片配额信息
@@ -361,6 +363,7 @@ func (h *Handler) ListAccounts(c *gin.Context) {
361363
Status: row.Status,
362364
ATOnly: row.GetCredential("refresh_token") == "" && row.GetCredential("access_token") != "",
363365
ProxyURL: row.ProxyURL,
366+
Enabled: row.Enabled,
364367
Locked: row.Locked,
365368
AllowedAPIKeyIDs: row.GetCredentialInt64Slice("allowed_api_key_ids"),
366369
ScoreBiasOverride: nullableInt64Pointer(row.ScoreBiasOverride),
@@ -1673,6 +1676,40 @@ func (h *Handler) RefreshAccount(c *gin.Context) {
16731676
writeMessage(c, http.StatusOK, "账号刷新成功")
16741677
}
16751678

1679+
// ToggleAccountEnabled 切换账号是否参与调度选择
1680+
func (h *Handler) ToggleAccountEnabled(c *gin.Context) {
1681+
idStr := c.Param("id")
1682+
id, err := strconv.ParseInt(idStr, 10, 64)
1683+
if err != nil {
1684+
writeError(c, http.StatusBadRequest, "无效的账号 ID")
1685+
return
1686+
}
1687+
1688+
var req struct {
1689+
Enabled bool `json:"enabled"`
1690+
}
1691+
if err := c.ShouldBindJSON(&req); err != nil {
1692+
writeError(c, http.StatusBadRequest, "请求格式错误")
1693+
return
1694+
}
1695+
1696+
ctx, cancel := context.WithTimeout(c.Request.Context(), 3*time.Second)
1697+
defer cancel()
1698+
1699+
if err := h.db.SetAccountEnabled(ctx, id, req.Enabled); err != nil {
1700+
writeError(c, http.StatusInternalServerError, "更新启用状态失败: "+err.Error())
1701+
return
1702+
}
1703+
1704+
h.store.ApplyAccountEnabled(id, req.Enabled)
1705+
1706+
if req.Enabled {
1707+
writeMessage(c, http.StatusOK, "账号已启用")
1708+
} else {
1709+
writeMessage(c, http.StatusOK, "账号已禁用")
1710+
}
1711+
}
1712+
16761713
// ToggleAccountLock 切换账号的锁定状态
16771714
func (h *Handler) ToggleAccountLock(c *gin.Context) {
16781715
idStr := c.Param("id")

auth/fast_scheduler.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,9 @@ func (a *Account) fastSchedulerSnapshot(baseLimit int64, now time.Time) (Account
402402
}
403403

404404
available := a.Status != StatusError && tier != HealthTierBanned && a.AccessToken != ""
405+
if atomic.LoadInt32(&a.DispatchPaused) != 0 {
406+
available = false
407+
}
405408
if a.Status == StatusCooldown && now.Before(a.CooldownUtil) && !a.premium5hCooldownSuppressedLocked(now) {
406409
available = false
407410
}

auth/fast_scheduler_test.go

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package auth
22

33
import (
4+
"context"
45
"sync/atomic"
56
"testing"
67
"time"
@@ -37,6 +38,25 @@ func TestFastSchedulerAcquirePrefersHealthyTier(t *testing.T) {
3738
}
3839
}
3940

41+
func TestFastSchedulerSkipsDispatchPausedAccount(t *testing.T) {
42+
paused := newFastSchedulerTestAccount(1, HealthTierHealthy, 120, 2)
43+
atomic.StoreInt32(&paused.DispatchPaused, 1)
44+
fallback := newFastSchedulerTestAccount(2, HealthTierHealthy, 80, 2)
45+
46+
scheduler := NewFastScheduler(2)
47+
scheduler.Rebuild([]*Account{paused, fallback})
48+
49+
got := scheduler.Acquire()
50+
if got == nil {
51+
t.Fatal("Acquire() returned nil")
52+
}
53+
defer scheduler.Release(got)
54+
55+
if got.DBID != fallback.DBID {
56+
t.Fatalf("Acquire() picked dbID=%d, want %d", got.DBID, fallback.DBID)
57+
}
58+
}
59+
4060
func TestFastSchedulerRespectsConcurrencyLimit(t *testing.T) {
4161
acc := newFastSchedulerTestAccount(1, HealthTierHealthy, 100, 1)
4262

@@ -108,6 +128,94 @@ func TestStoreNextExcludingRespectsAPIKeyWhitelist(t *testing.T) {
108128
}
109129
}
110130

131+
func TestStoreNextSkipsDispatchPausedAccount(t *testing.T) {
132+
paused := newFastSchedulerTestAccount(1, HealthTierHealthy, 120, 1)
133+
atomic.StoreInt32(&paused.DispatchPaused, 1)
134+
fallback := newFastSchedulerTestAccount(2, HealthTierHealthy, 80, 1)
135+
136+
store := &Store{
137+
accounts: []*Account{paused, fallback},
138+
maxConcurrency: 1,
139+
}
140+
141+
got := store.Next()
142+
if got == nil {
143+
t.Fatal("Next() returned nil")
144+
}
145+
defer store.Release(got)
146+
147+
if got.DBID != fallback.DBID {
148+
t.Fatalf("Next() picked dbID=%d, want %d", got.DBID, fallback.DBID)
149+
}
150+
}
151+
152+
func TestDispatchPausedDoesNotBlockUsageProbe(t *testing.T) {
153+
paused := newFastSchedulerTestAccount(1, HealthTierHealthy, 120, 1)
154+
atomic.StoreInt32(&paused.DispatchPaused, 1)
155+
156+
store := &Store{
157+
accounts: []*Account{paused},
158+
}
159+
var probed int32
160+
store.SetUsageProbeFunc(func(_ context.Context, account *Account) error {
161+
if account.DBID != paused.DBID {
162+
t.Fatalf("usage probe account dbID=%d, want %d", account.DBID, paused.DBID)
163+
}
164+
atomic.AddInt32(&probed, 1)
165+
return nil
166+
})
167+
168+
store.parallelProbeUsage(context.Background())
169+
170+
if got := atomic.LoadInt32(&probed); got != 1 {
171+
t.Fatalf("usage probe calls = %d, want 1", got)
172+
}
173+
}
174+
175+
func TestDispatchPausedDoesNotBlockRecoveryProbe(t *testing.T) {
176+
paused := newFastSchedulerTestAccount(1, HealthTierBanned, 120, 1)
177+
paused.RefreshToken = "rt"
178+
paused.ExpiresAt = time.Now().Add(time.Hour)
179+
atomic.StoreInt32(&paused.DispatchPaused, 1)
180+
181+
store := &Store{
182+
accounts: []*Account{paused},
183+
}
184+
var probed int32
185+
store.SetUsageProbeFunc(func(_ context.Context, account *Account) error {
186+
if account.DBID != paused.DBID {
187+
t.Fatalf("recovery probe account dbID=%d, want %d", account.DBID, paused.DBID)
188+
}
189+
atomic.AddInt32(&probed, 1)
190+
return nil
191+
})
192+
193+
store.parallelRecoveryProbe(context.Background())
194+
195+
if got := atomic.LoadInt32(&probed); got != 1 {
196+
t.Fatalf("recovery probe calls = %d, want 1", got)
197+
}
198+
if atomic.LoadInt32(&paused.DispatchPaused) != 1 {
199+
t.Fatal("recovery probe cleared DispatchPaused; enable/disable must remain independent")
200+
}
201+
}
202+
203+
func TestDispatchPausedDoesNotBlockAutoClean(t *testing.T) {
204+
paused := newFastSchedulerTestAccount(1, HealthTierBanned, 120, 1)
205+
atomic.StoreInt32(&paused.DispatchPaused, 1)
206+
207+
store := &Store{
208+
accounts: []*Account{paused},
209+
}
210+
211+
if cleaned := store.CleanByRuntimeStatus(context.Background(), "unauthorized"); cleaned != 1 {
212+
t.Fatalf("CleanByRuntimeStatus cleaned %d accounts, want 1", cleaned)
213+
}
214+
if got := store.AccountCount(); got != 0 {
215+
t.Fatalf("AccountCount() = %d, want 0", got)
216+
}
217+
}
218+
111219
func TestStoreNextExcludingWithFilterRespectsPlanFilter(t *testing.T) {
112220
plus := newFastSchedulerTestAccount(1, HealthTierHealthy, 120, 1)
113221
plus.PlanType = "plus"

auth/store.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ type Account struct {
9797
Disabled int32 // 原子标志,1 = 立即不可调度(401 时瞬间置位,无需等锁)
9898
AddedAt int64 // 加入号池的时间(UnixNano),用于过期清理
9999
Locked int32 // 原子标志,1 = 锁定,自动清理跳过此账号
100+
DispatchPaused int32 // 原子标志,1 = 禁用调度选择,不影响刷新/探针/清理
100101

101102
// per-account 调度配置(nil = 跟随默认)
102103
ScoreBiasOverride *int64
@@ -561,6 +562,9 @@ func (a *Account) IsAvailable() bool {
561562
if atomic.LoadInt32(&a.Disabled) != 0 {
562563
return false
563564
}
565+
if atomic.LoadInt32(&a.DispatchPaused) != 0 {
566+
return false
567+
}
564568

565569
a.mu.RLock()
566570
defer a.mu.RUnlock()
@@ -1501,6 +1505,9 @@ func (s *Store) loadFromDB(ctx context.Context) error {
15011505
if row.Locked {
15021506
atomic.StoreInt32(&account.Locked, 1)
15031507
}
1508+
if !row.Enabled {
1509+
atomic.StoreInt32(&account.DispatchPaused, 1)
1510+
}
15041511
if row.Status == "error" {
15051512
account.Status = StatusError
15061513
account.ErrorMsg = row.ErrorMessage
@@ -2130,6 +2137,20 @@ func (s *Store) ApplyAccountAllowedAPIKeys(dbID int64, allowedAPIKeyIDs []int64)
21302137
return true
21312138
}
21322139

2140+
func (s *Store) ApplyAccountEnabled(dbID int64, enabled bool) bool {
2141+
acc := s.FindByID(dbID)
2142+
if acc == nil {
2143+
return false
2144+
}
2145+
if enabled {
2146+
atomic.StoreInt32(&acc.DispatchPaused, 0)
2147+
} else {
2148+
atomic.StoreInt32(&acc.DispatchPaused, 1)
2149+
}
2150+
s.fastSchedulerUpdate(acc)
2151+
return true
2152+
}
2153+
21332154
// MarkCooldown 标记账号进入冷却,并持久化到数据库
21342155
func (s *Store) MarkCooldown(acc *Account, duration time.Duration, reason string) {
21352156
if acc == nil {

database/postgres.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ type AccountRow struct {
2626
CooldownReason string
2727
CooldownUntil sql.NullTime
2828
ErrorMessage string
29+
Enabled bool
2930
Locked bool
3031
ScoreBiasOverride sql.NullInt64
3132
BaseConcurrencyOverride sql.NullInt64
@@ -267,6 +268,7 @@ func (db *DB) migrate(ctx context.Context) error {
267268
268269
ALTER TABLE accounts ADD COLUMN IF NOT EXISTS cooldown_reason VARCHAR(50) DEFAULT '';
269270
ALTER TABLE accounts ADD COLUMN IF NOT EXISTS cooldown_until TIMESTAMPTZ NULL;
271+
ALTER TABLE accounts ADD COLUMN IF NOT EXISTS enabled BOOLEAN DEFAULT TRUE;
270272
ALTER TABLE accounts ADD COLUMN IF NOT EXISTS locked BOOLEAN DEFAULT FALSE;
271273
ALTER TABLE accounts ADD COLUMN IF NOT EXISTS score_bias_override INT NULL;
272274
ALTER TABLE accounts ADD COLUMN IF NOT EXISTS base_concurrency_override INT NULL;
@@ -1849,7 +1851,7 @@ func (db *DB) GetAccountTimeRangeUsage(ctx context.Context, since time.Time) (ma
18491851
// ListActive 获取所有未删除账号。
18501852
func (db *DB) ListActive(ctx context.Context) ([]*AccountRow, error) {
18511853
query := `
1852-
SELECT id, name, platform, type, credentials, proxy_url, status, cooldown_reason, cooldown_until, error_message, COALESCE(locked, false), score_bias_override, base_concurrency_override, created_at, updated_at
1854+
SELECT id, name, platform, type, credentials, proxy_url, status, cooldown_reason, cooldown_until, error_message, COALESCE(enabled, true), COALESCE(locked, false), score_bias_override, base_concurrency_override, created_at, updated_at
18531855
FROM accounts
18541856
WHERE status <> 'deleted' AND COALESCE(error_message, '') <> 'deleted'
18551857
ORDER BY id
@@ -1878,6 +1880,7 @@ func (db *DB) ListActive(ctx context.Context) ([]*AccountRow, error) {
18781880
&a.CooldownReason,
18791881
&cooldownUntilRaw,
18801882
&a.ErrorMessage,
1883+
&a.Enabled,
18811884
&a.Locked,
18821885
&a.ScoreBiasOverride,
18831886
&a.BaseConcurrencyOverride,
@@ -1907,7 +1910,7 @@ func (db *DB) ListActive(ctx context.Context) ([]*AccountRow, error) {
19071910
// GetAccountByID 获取未删除账号的完整数据库行。
19081911
func (db *DB) GetAccountByID(ctx context.Context, id int64) (*AccountRow, error) {
19091912
query := `
1910-
SELECT id, name, platform, type, credentials, proxy_url, status, cooldown_reason, cooldown_until, error_message, COALESCE(locked, false), score_bias_override, base_concurrency_override, created_at, updated_at
1913+
SELECT id, name, platform, type, credentials, proxy_url, status, cooldown_reason, cooldown_until, error_message, COALESCE(enabled, true), COALESCE(locked, false), score_bias_override, base_concurrency_override, created_at, updated_at
19111914
FROM accounts
19121915
WHERE id = $1 AND status <> 'deleted' AND COALESCE(error_message, '') <> 'deleted'
19131916
LIMIT 1
@@ -1928,6 +1931,7 @@ func (db *DB) GetAccountByID(ctx context.Context, id int64) (*AccountRow, error)
19281931
&a.CooldownReason,
19291932
&cooldownUntilRaw,
19301933
&a.ErrorMessage,
1934+
&a.Enabled,
19311935
&a.Locked,
19321936
&a.ScoreBiasOverride,
19331937
&a.BaseConcurrencyOverride,
@@ -2022,6 +2026,12 @@ func nullableInt64Value(v sql.NullInt64) interface{} {
20222026
return v.Int64
20232027
}
20242028

2029+
// SetAccountEnabled 设置账号是否参与调度选择
2030+
func (db *DB) SetAccountEnabled(ctx context.Context, id int64, enabled bool) error {
2031+
_, err := db.conn.ExecContext(ctx, `UPDATE accounts SET enabled = $1 WHERE id = $2`, enabled, id)
2032+
return err
2033+
}
2034+
20252035
// SetAccountLocked 设置账号的锁定状态
20262036
func (db *DB) SetAccountLocked(ctx context.Context, id int64, locked bool) error {
20272037
_, err := db.conn.ExecContext(ctx, `UPDATE accounts SET locked = $1 WHERE id = $2`, locked, id)

database/sqlite.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,7 @@ func (db *DB) migrateSQLite(ctx context.Context) error {
267267
{"system_settings", "prompt_filter_sensitive_words", "TEXT DEFAULT ''"},
268268
{"system_settings", "prompt_filter_custom_patterns", "TEXT DEFAULT '[]'"},
269269
{"system_settings", "prompt_filter_disabled_patterns", "TEXT DEFAULT '[]'"},
270+
{"accounts", "enabled", "INTEGER DEFAULT 1"},
270271
{"accounts", "locked", "INTEGER DEFAULT 0"},
271272
{"accounts", "image_quota_remaining", "INTEGER NULL"},
272273
{"accounts", "image_quota_total", "INTEGER NULL"},

database/sqlite_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,44 @@ func TestNewSQLiteInitializesFreshDatabase(t *testing.T) {
2222
}
2323
}
2424

25+
func TestSQLiteAccountsEnabledDefaultsAndCanToggle(t *testing.T) {
26+
dbPath := filepath.Join(t.TempDir(), "codex2api.db")
27+
28+
db, err := New("sqlite", dbPath)
29+
if err != nil {
30+
t.Fatalf("New(sqlite) 返回错误: %v", err)
31+
}
32+
defer db.Close()
33+
34+
ctx := context.Background()
35+
id, err := db.InsertAccount(ctx, "test", "rt", "")
36+
if err != nil {
37+
t.Fatalf("InsertAccount 返回错误: %v", err)
38+
}
39+
40+
rows, err := db.ListActive(ctx)
41+
if err != nil {
42+
t.Fatalf("ListActive 返回错误: %v", err)
43+
}
44+
if len(rows) != 1 {
45+
t.Fatalf("ListActive 返回 %d 条,want 1", len(rows))
46+
}
47+
if !rows[0].Enabled {
48+
t.Fatal("new account Enabled = false, want true")
49+
}
50+
51+
if err := db.SetAccountEnabled(ctx, id, false); err != nil {
52+
t.Fatalf("SetAccountEnabled 返回错误: %v", err)
53+
}
54+
rows, err = db.ListActive(ctx)
55+
if err != nil {
56+
t.Fatalf("ListActive 返回错误: %v", err)
57+
}
58+
if rows[0].Enabled {
59+
t.Fatal("disabled account Enabled = true, want false")
60+
}
61+
}
62+
2563
func TestSQLiteUsageLogsHasAPIKeyColumns(t *testing.T) {
2664
dbPath := filepath.Join(t.TempDir(), "codex2api.db")
2765

frontend/src/api.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ export const api = {
143143
request<MessageResponse>(`/accounts/${id}/refresh`, { method: 'POST' }),
144144
updateAccountScheduler: (id: number, data: UpdateAccountSchedulerRequest) =>
145145
request<MessageResponse>(`/accounts/${id}/scheduler`, { method: 'PATCH', body: JSON.stringify(data) }),
146+
toggleAccountEnabled: (id: number, enabled: boolean) =>
147+
request<MessageResponse>(`/accounts/${id}/enable`, { method: 'POST', body: JSON.stringify({ enabled }) }),
146148
toggleAccountLock: (id: number, locked: boolean) =>
147149
request<MessageResponse>(`/accounts/${id}/lock`, { method: 'POST', body: JSON.stringify({ locked }) }),
148150
resetAccountStatus: (id: number) =>

0 commit comments

Comments
 (0)