Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions admin/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ func (h *Handler) RegisterRoutes(r *gin.Engine) {
api.PATCH("/accounts/:id/scheduler", h.UpdateAccountScheduler)
api.DELETE("/accounts/:id", h.DeleteAccount)
api.POST("/accounts/:id/refresh", h.RefreshAccount)
api.POST("/accounts/:id/enable", h.ToggleAccountEnabled)
api.POST("/accounts/:id/lock", h.ToggleAccountLock)
api.POST("/accounts/:id/reset-status", h.ResetAccountStatus)
api.GET("/accounts/:id/test", h.TestConnection)
Expand Down Expand Up @@ -299,6 +300,7 @@ type accountResponse struct {
LastRateLimitedAt string `json:"last_rate_limited_at,omitempty"`
LastTimeoutAt string `json:"last_timeout_at,omitempty"`
LastServerErrorAt string `json:"last_server_error_at,omitempty"`
Enabled bool `json:"enabled"`
Locked bool `json:"locked"`
AllowedAPIKeyIDs []int64 `json:"allowed_api_key_ids"`
// 图片配额信息
Expand Down Expand Up @@ -361,6 +363,7 @@ func (h *Handler) ListAccounts(c *gin.Context) {
Status: row.Status,
ATOnly: row.GetCredential("refresh_token") == "" && row.GetCredential("access_token") != "",
ProxyURL: row.ProxyURL,
Enabled: row.Enabled,
Locked: row.Locked,
AllowedAPIKeyIDs: row.GetCredentialInt64Slice("allowed_api_key_ids"),
ScoreBiasOverride: nullableInt64Pointer(row.ScoreBiasOverride),
Expand Down Expand Up @@ -1673,6 +1676,44 @@ func (h *Handler) RefreshAccount(c *gin.Context) {
writeMessage(c, http.StatusOK, "账号刷新成功")
}

// ToggleAccountEnabled 切换账号是否参与调度选择
func (h *Handler) ToggleAccountEnabled(c *gin.Context) {
idStr := c.Param("id")
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
writeError(c, http.StatusBadRequest, "无效的账号 ID")
return
}

var req struct {
Enabled *bool `json:"enabled" binding:"required"`
}
if err := c.ShouldBindJSON(&req); err != nil || req.Enabled == nil {
writeError(c, http.StatusBadRequest, "请求格式错误")
return
}

ctx, cancel := context.WithTimeout(c.Request.Context(), 3*time.Second)
defer cancel()

if err := h.db.SetAccountEnabled(ctx, id, *req.Enabled); err != nil {
if err == sql.ErrNoRows {
writeError(c, http.StatusNotFound, "账号不存在")
return
}
writeError(c, http.StatusInternalServerError, "更新启用状态失败: "+err.Error())
return
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}

h.store.ApplyAccountEnabled(id, *req.Enabled)

if *req.Enabled {
writeMessage(c, http.StatusOK, "账号已启用")
} else {
writeMessage(c, http.StatusOK, "账号已禁用")
}
}

// ToggleAccountLock 切换账号的锁定状态
func (h *Handler) ToggleAccountLock(c *gin.Context) {
idStr := c.Param("id")
Expand Down
3 changes: 3 additions & 0 deletions auth/fast_scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,9 @@ func (a *Account) fastSchedulerSnapshot(baseLimit int64, now time.Time) (Account
}

available := a.Status != StatusError && tier != HealthTierBanned && a.AccessToken != ""
if atomic.LoadInt32(&a.DispatchPaused) != 0 {
available = false
}
if a.Status == StatusCooldown && now.Before(a.CooldownUtil) && !a.premium5hCooldownSuppressedLocked(now) {
available = false
}
Expand Down
108 changes: 108 additions & 0 deletions auth/fast_scheduler_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package auth

import (
"context"
"sync/atomic"
"testing"
"time"
Expand Down Expand Up @@ -37,6 +38,25 @@ func TestFastSchedulerAcquirePrefersHealthyTier(t *testing.T) {
}
}

func TestFastSchedulerSkipsDispatchPausedAccount(t *testing.T) {
paused := newFastSchedulerTestAccount(1, HealthTierHealthy, 120, 2)
atomic.StoreInt32(&paused.DispatchPaused, 1)
fallback := newFastSchedulerTestAccount(2, HealthTierHealthy, 80, 2)

scheduler := NewFastScheduler(2)
scheduler.Rebuild([]*Account{paused, fallback})

got := scheduler.Acquire()
if got == nil {
t.Fatal("Acquire() returned nil")
}
defer scheduler.Release(got)

if got.DBID != fallback.DBID {
t.Fatalf("Acquire() picked dbID=%d, want %d", got.DBID, fallback.DBID)
}
}

func TestFastSchedulerRespectsConcurrencyLimit(t *testing.T) {
acc := newFastSchedulerTestAccount(1, HealthTierHealthy, 100, 1)

Expand Down Expand Up @@ -108,6 +128,94 @@ func TestStoreNextExcludingRespectsAPIKeyWhitelist(t *testing.T) {
}
}

func TestStoreNextSkipsDispatchPausedAccount(t *testing.T) {
paused := newFastSchedulerTestAccount(1, HealthTierHealthy, 120, 1)
atomic.StoreInt32(&paused.DispatchPaused, 1)
fallback := newFastSchedulerTestAccount(2, HealthTierHealthy, 80, 1)

store := &Store{
accounts: []*Account{paused, fallback},
maxConcurrency: 1,
}

got := store.Next()
if got == nil {
t.Fatal("Next() returned nil")
}
defer store.Release(got)

if got.DBID != fallback.DBID {
t.Fatalf("Next() picked dbID=%d, want %d", got.DBID, fallback.DBID)
}
}

func TestDispatchPausedDoesNotBlockUsageProbe(t *testing.T) {
paused := newFastSchedulerTestAccount(1, HealthTierHealthy, 120, 1)
atomic.StoreInt32(&paused.DispatchPaused, 1)

store := &Store{
accounts: []*Account{paused},
}
var probed int32
store.SetUsageProbeFunc(func(_ context.Context, account *Account) error {
if account.DBID != paused.DBID {
t.Fatalf("usage probe account dbID=%d, want %d", account.DBID, paused.DBID)
}
atomic.AddInt32(&probed, 1)
return nil
})

store.parallelProbeUsage(context.Background())

if got := atomic.LoadInt32(&probed); got != 1 {
t.Fatalf("usage probe calls = %d, want 1", got)
}
}

func TestDispatchPausedDoesNotBlockRecoveryProbe(t *testing.T) {
paused := newFastSchedulerTestAccount(1, HealthTierBanned, 120, 1)
paused.RefreshToken = "rt"
paused.ExpiresAt = time.Now().Add(time.Hour)
atomic.StoreInt32(&paused.DispatchPaused, 1)

store := &Store{
accounts: []*Account{paused},
}
var probed int32
store.SetUsageProbeFunc(func(_ context.Context, account *Account) error {
if account.DBID != paused.DBID {
t.Fatalf("recovery probe account dbID=%d, want %d", account.DBID, paused.DBID)
}
atomic.AddInt32(&probed, 1)
return nil
})

store.parallelRecoveryProbe(context.Background())

if got := atomic.LoadInt32(&probed); got != 1 {
t.Fatalf("recovery probe calls = %d, want 1", got)
}
if atomic.LoadInt32(&paused.DispatchPaused) != 1 {
t.Fatal("recovery probe cleared DispatchPaused; enable/disable must remain independent")
}
}

func TestDispatchPausedDoesNotBlockAutoClean(t *testing.T) {
paused := newFastSchedulerTestAccount(1, HealthTierBanned, 120, 1)
atomic.StoreInt32(&paused.DispatchPaused, 1)

store := &Store{
accounts: []*Account{paused},
}

if cleaned := store.CleanByRuntimeStatus(context.Background(), "unauthorized"); cleaned != 1 {
t.Fatalf("CleanByRuntimeStatus cleaned %d accounts, want 1", cleaned)
}
if got := store.AccountCount(); got != 0 {
t.Fatalf("AccountCount() = %d, want 0", got)
}
}

func TestStoreNextExcludingWithFilterRespectsPlanFilter(t *testing.T) {
plus := newFastSchedulerTestAccount(1, HealthTierHealthy, 120, 1)
plus.PlanType = "plus"
Expand Down
21 changes: 21 additions & 0 deletions auth/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ type Account struct {
Disabled int32 // 原子标志,1 = 立即不可调度(401 时瞬间置位,无需等锁)
AddedAt int64 // 加入号池的时间(UnixNano),用于过期清理
Locked int32 // 原子标志,1 = 锁定,自动清理跳过此账号
DispatchPaused int32 // 原子标志,1 = 禁用调度选择,不影响刷新/探针/清理

// per-account 调度配置(nil = 跟随默认)
ScoreBiasOverride *int64
Expand Down Expand Up @@ -561,6 +562,9 @@ func (a *Account) IsAvailable() bool {
if atomic.LoadInt32(&a.Disabled) != 0 {
return false
}
if atomic.LoadInt32(&a.DispatchPaused) != 0 {
return false
}

a.mu.RLock()
defer a.mu.RUnlock()
Expand Down Expand Up @@ -1501,6 +1505,9 @@ func (s *Store) loadFromDB(ctx context.Context) error {
if row.Locked {
atomic.StoreInt32(&account.Locked, 1)
}
if !row.Enabled {
atomic.StoreInt32(&account.DispatchPaused, 1)
}
if row.Status == "error" {
account.Status = StatusError
account.ErrorMsg = row.ErrorMessage
Expand Down Expand Up @@ -2130,6 +2137,20 @@ func (s *Store) ApplyAccountAllowedAPIKeys(dbID int64, allowedAPIKeyIDs []int64)
return true
}

func (s *Store) ApplyAccountEnabled(dbID int64, enabled bool) bool {
acc := s.FindByID(dbID)
if acc == nil {
return false
}
if enabled {
atomic.StoreInt32(&acc.DispatchPaused, 0)
} else {
atomic.StoreInt32(&acc.DispatchPaused, 1)
}
s.fastSchedulerUpdate(acc)
return true
}

// MarkCooldown 标记账号进入冷却,并持久化到数据库
func (s *Store) MarkCooldown(acc *Account, duration time.Duration, reason string) {
if acc == nil {
Expand Down
24 changes: 22 additions & 2 deletions database/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type AccountRow struct {
CooldownReason string
CooldownUntil sql.NullTime
ErrorMessage string
Enabled bool
Locked bool
ScoreBiasOverride sql.NullInt64
BaseConcurrencyOverride sql.NullInt64
Expand Down Expand Up @@ -267,6 +268,7 @@ func (db *DB) migrate(ctx context.Context) error {

ALTER TABLE accounts ADD COLUMN IF NOT EXISTS cooldown_reason VARCHAR(50) DEFAULT '';
ALTER TABLE accounts ADD COLUMN IF NOT EXISTS cooldown_until TIMESTAMPTZ NULL;
ALTER TABLE accounts ADD COLUMN IF NOT EXISTS enabled BOOLEAN DEFAULT TRUE;
ALTER TABLE accounts ADD COLUMN IF NOT EXISTS locked BOOLEAN DEFAULT FALSE;
ALTER TABLE accounts ADD COLUMN IF NOT EXISTS score_bias_override INT NULL;
ALTER TABLE accounts ADD COLUMN IF NOT EXISTS base_concurrency_override INT NULL;
Expand Down Expand Up @@ -1849,7 +1851,7 @@ func (db *DB) GetAccountTimeRangeUsage(ctx context.Context, since time.Time) (ma
// ListActive 获取所有未删除账号。
func (db *DB) ListActive(ctx context.Context) ([]*AccountRow, error) {
query := `
SELECT id, name, platform, type, credentials, proxy_url, status, cooldown_reason, cooldown_until, error_message, COALESCE(locked, false), score_bias_override, base_concurrency_override, created_at, updated_at
SELECT id, name, platform, type, credentials, proxy_url, status, cooldown_reason, cooldown_until, error_message, COALESCE(enabled, true), COALESCE(locked, false), score_bias_override, base_concurrency_override, created_at, updated_at
FROM accounts
WHERE status <> 'deleted' AND COALESCE(error_message, '') <> 'deleted'
ORDER BY id
Expand Down Expand Up @@ -1878,6 +1880,7 @@ func (db *DB) ListActive(ctx context.Context) ([]*AccountRow, error) {
&a.CooldownReason,
&cooldownUntilRaw,
&a.ErrorMessage,
&a.Enabled,
&a.Locked,
&a.ScoreBiasOverride,
&a.BaseConcurrencyOverride,
Expand Down Expand Up @@ -1907,7 +1910,7 @@ func (db *DB) ListActive(ctx context.Context) ([]*AccountRow, error) {
// GetAccountByID 获取未删除账号的完整数据库行。
func (db *DB) GetAccountByID(ctx context.Context, id int64) (*AccountRow, error) {
query := `
SELECT id, name, platform, type, credentials, proxy_url, status, cooldown_reason, cooldown_until, error_message, COALESCE(locked, false), score_bias_override, base_concurrency_override, created_at, updated_at
SELECT id, name, platform, type, credentials, proxy_url, status, cooldown_reason, cooldown_until, error_message, COALESCE(enabled, true), COALESCE(locked, false), score_bias_override, base_concurrency_override, created_at, updated_at
FROM accounts
WHERE id = $1 AND status <> 'deleted' AND COALESCE(error_message, '') <> 'deleted'
LIMIT 1
Expand All @@ -1928,6 +1931,7 @@ func (db *DB) GetAccountByID(ctx context.Context, id int64) (*AccountRow, error)
&a.CooldownReason,
&cooldownUntilRaw,
&a.ErrorMessage,
&a.Enabled,
&a.Locked,
&a.ScoreBiasOverride,
&a.BaseConcurrencyOverride,
Expand Down Expand Up @@ -2022,6 +2026,22 @@ func nullableInt64Value(v sql.NullInt64) interface{} {
return v.Int64
}

// SetAccountEnabled 设置账号是否参与调度选择
func (db *DB) SetAccountEnabled(ctx context.Context, id int64, enabled bool) error {
res, err := db.conn.ExecContext(ctx, `UPDATE accounts SET enabled = $1, updated_at = CURRENT_TIMESTAMP WHERE id = $2`, enabled, id)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return sql.ErrNoRows
}
return nil
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

// SetAccountLocked 设置账号的锁定状态
func (db *DB) SetAccountLocked(ctx context.Context, id int64, locked bool) error {
_, err := db.conn.ExecContext(ctx, `UPDATE accounts SET locked = $1 WHERE id = $2`, locked, id)
Expand Down
1 change: 1 addition & 0 deletions database/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ func (db *DB) migrateSQLite(ctx context.Context) error {
{"system_settings", "prompt_filter_sensitive_words", "TEXT DEFAULT ''"},
{"system_settings", "prompt_filter_custom_patterns", "TEXT DEFAULT '[]'"},
{"system_settings", "prompt_filter_disabled_patterns", "TEXT DEFAULT '[]'"},
{"accounts", "enabled", "INTEGER DEFAULT 1"},
{"accounts", "locked", "INTEGER DEFAULT 0"},
{"accounts", "image_quota_remaining", "INTEGER NULL"},
{"accounts", "image_quota_total", "INTEGER NULL"},
Expand Down
45 changes: 45 additions & 0 deletions database/sqlite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,51 @@ func TestNewSQLiteInitializesFreshDatabase(t *testing.T) {
}
}

func TestSQLiteAccountsEnabledDefaultsAndCanToggle(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "codex2api.db")

db, err := New("sqlite", dbPath)
if err != nil {
t.Fatalf("New(sqlite) 返回错误: %v", err)
}
defer db.Close()

ctx := context.Background()
id, err := db.InsertAccount(ctx, "test", "rt", "")
if err != nil {
t.Fatalf("InsertAccount 返回错误: %v", err)
}

rows, err := db.ListActive(ctx)
if err != nil {
t.Fatalf("ListActive 返回错误: %v", err)
}
if len(rows) != 1 {
t.Fatalf("ListActive 返回 %d 条,want 1", len(rows))
}
if !rows[0].Enabled {
t.Fatal("new account Enabled = false, want true")
}

if err := db.SetAccountEnabled(ctx, id, false); err != nil {
t.Fatalf("SetAccountEnabled 返回错误: %v", err)
}
rows, err = db.ListActive(ctx)
if err != nil {
t.Fatalf("ListActive 返回错误: %v", err)
}
if len(rows) != 1 {
t.Fatalf("ListActive 返回 %d 条,want 1", len(rows))
}
if rows[0].Enabled {
t.Fatal("disabled account Enabled = true, want false")
}

if err := db.SetAccountEnabled(ctx, id+1, false); err != sql.ErrNoRows {
t.Fatalf("SetAccountEnabled missing account error = %v, want sql.ErrNoRows", err)
}
}

func TestSQLiteUsageLogsHasAPIKeyColumns(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "codex2api.db")

Expand Down
2 changes: 2 additions & 0 deletions frontend/src/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ export const api = {
request<MessageResponse>(`/accounts/${id}/refresh`, { method: 'POST' }),
updateAccountScheduler: (id: number, data: UpdateAccountSchedulerRequest) =>
request<MessageResponse>(`/accounts/${id}/scheduler`, { method: 'PATCH', body: JSON.stringify(data) }),
toggleAccountEnabled: (id: number, enabled: boolean) =>
request<MessageResponse>(`/accounts/${id}/enable`, { method: 'POST', body: JSON.stringify({ enabled }) }),
toggleAccountLock: (id: number, locked: boolean) =>
request<MessageResponse>(`/accounts/${id}/lock`, { method: 'POST', body: JSON.stringify({ locked }) }),
resetAccountStatus: (id: number) =>
Expand Down
Loading
Loading