From 6b51f00122196d32956ac9decdad4889f949715d Mon Sep 17 00:00:00 2001 From: DearGlory <2814703211@qq.com> Date: Mon, 18 May 2026 14:30:26 +0800 Subject: [PATCH] fix: sync codex plan type on account reset --- admin/handler.go | 66 ++++++++++++++++++++++++++++-------- admin/handler_test.go | 79 +++++++++++++++++++++++++++++++++++++++++++ auth/store.go | 33 ++++++++++++++++++ proxy/handler.go | 3 ++ proxy/handler_test.go | 44 ++++++++++++++++++++++++ 5 files changed, 210 insertions(+), 15 deletions(-) diff --git a/admin/handler.go b/admin/handler.go index 49926bcc..b0016a9c 100644 --- a/admin/handler.go +++ b/admin/handler.go @@ -38,21 +38,22 @@ import ( // Handler 管理后台 API 处理器 type Handler struct { - store *auth.Store - cache cache.TokenCache - db *database.DB - rateLimiter *proxy.RateLimiter - refreshAccount func(context.Context, int64) error - cpuSampler *cpuSampler - startedAt time.Time - pgMaxConns int - redisPoolSize int - databaseDriver string - databaseLabel string - cacheDriver string - cacheLabel string - adminSecretEnv string - imageProxy *proxy.Handler + store *auth.Store + cache cache.TokenCache + db *database.DB + rateLimiter *proxy.RateLimiter + refreshAccount func(context.Context, int64) error + syncAccountPlanOnReset func(context.Context, *auth.Account) error + cpuSampler *cpuSampler + startedAt time.Time + pgMaxConns int + redisPoolSize int + databaseDriver string + databaseLabel string + cacheDriver string + cacheLabel string + adminSecretEnv string + imageProxy *proxy.Handler // 图表聚合内存缓存(10秒 TTL) chartCacheMu sync.RWMutex @@ -170,6 +171,7 @@ func NewHandler(store *auth.Store, db *database.DB, tc cache.TokenCache, rl *pro handler.imageProxy.SetRuntimeCache(tc) } handler.refreshAccount = handler.refreshSingleAccount + handler.syncAccountPlanOnReset = handler.syncSingleAccountPlanOnReset if db != nil { if err := db.MarkInterruptedImageJobs(context.Background()); err != nil { log.Printf("标记中断生图任务失败: %v", err) @@ -2453,6 +2455,7 @@ func (h *Handler) ResetAccountStatus(c *gin.Context) { h.store.ClearCooldown(acc) acc.ClearUsageCache() + h.syncAccountPlanAfterReset(c.Request.Context(), acc) writeMessage(c, http.StatusOK, "账号状态已重置") } @@ -2476,6 +2479,7 @@ func (h *Handler) BatchResetStatus(c *gin.Context) { } h.store.ClearCooldown(acc) acc.ClearUsageCache() + h.syncAccountPlanAfterReset(c.Request.Context(), acc) success++ } @@ -2486,6 +2490,38 @@ func (h *Handler) BatchResetStatus(c *gin.Context) { }) } +func (h *Handler) syncAccountPlanAfterReset(parent context.Context, acc *auth.Account) { + if h == nil || h.syncAccountPlanOnReset == nil || acc == nil { + return + } + if parent == nil { + parent = context.Background() + } + ctx, cancel := context.WithTimeout(parent, 15*time.Second) + defer cancel() + if err := h.syncAccountPlanOnReset(ctx, acc); err != nil { + log.Printf("[account %d] sync Codex plan type after reset failed: %v", acc.DBID, err) + } +} + +func (h *Handler) syncSingleAccountPlanOnReset(ctx context.Context, acc *auth.Account) error { + if h == nil || h.store == nil || acc == nil || acc.IsOpenAIResponsesAPI() || acc.GetAccessToken() == "" { + return nil + } + model, err := h.connectionTestModelForAccount(ctx, acc, "") + if err != nil { + return err + } + resp, err := proxy.ExecuteRequest(ctx, acc, buildTestPayload(model), "", h.store.ResolveProxyForAccount(acc), "", nil, nil) + if err != nil { + return err + } + defer resp.Body.Close() + _, _ = io.Copy(io.Discard, resp.Body) + proxy.SyncCodexUsageState(h.store, acc, resp) + return nil +} + func (h *Handler) refreshSingleAccount(ctx context.Context, id int64) error { if h == nil || h.store == nil { return fmt.Errorf("账号池未初始化") diff --git a/admin/handler_test.go b/admin/handler_test.go index 346e1acf..d4431a7f 100644 --- a/admin/handler_test.go +++ b/admin/handler_test.go @@ -147,6 +147,85 @@ func TestRefreshAccountReturnsRefreshFailure(t *testing.T) { } } +func TestResetAccountStatusSyncsPlanMetadata(t *testing.T) { + gin.SetMode(gin.TestMode) + + store := auth.NewStore(nil, nil, nil) + account := &auth.Account{DBID: 42, AccessToken: "at", PlanType: "free"} + account.SetUsageSnapshot(88, time.Now().Add(time.Hour)) + store.AddAccount(account) + + var called bool + handler := &Handler{ + store: store, + syncAccountPlanOnReset: func(_ context.Context, acc *auth.Account) error { + called = true + if acc == nil || acc.DBID != 42 { + t.Fatalf("sync account = %#v, want DBID 42", acc) + } + return nil + }, + } + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Params = gin.Params{{Key: "id", Value: "42"}} + ctx.Request = httptest.NewRequest(http.MethodPost, "/api/admin/accounts/42/reset-status", nil) + + handler.ResetAccountStatus(ctx) + + if recorder.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", recorder.Code, http.StatusOK, recorder.Body.String()) + } + if !called { + t.Fatal("expected reset to sync plan metadata") + } + if _, ok := account.GetUsagePercent7d(); ok { + t.Fatal("expected reset to clear cached usage") + } +} + +func TestBatchResetStatusSyncsEachResolvedAccount(t *testing.T) { + gin.SetMode(gin.TestMode) + + store := auth.NewStore(nil, nil, nil) + store.AddAccount(&auth.Account{DBID: 11, AccessToken: "at-11", PlanType: "free"}) + store.AddAccount(&auth.Account{DBID: 22, AccessToken: "at-22", PlanType: "plus"}) + + var gotIDs []int64 + handler := &Handler{ + store: store, + syncAccountPlanOnReset: func(_ context.Context, acc *auth.Account) error { + gotIDs = append(gotIDs, acc.DBID) + if acc.DBID == 22 { + return errors.New("temporary upstream failure") + } + return nil + }, + } + + recorder := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(recorder) + ctx.Request = httptest.NewRequest(http.MethodPost, "/api/admin/accounts/batch-reset-status", strings.NewReader(`{"ids":[11,99,22]}`)) + ctx.Request.Header.Set("Content-Type", "application/json") + + handler.BatchResetStatus(ctx) + + if recorder.Code != http.StatusOK { + t.Fatalf("status = %d, want %d, body=%s", recorder.Code, http.StatusOK, recorder.Body.String()) + } + if fmt.Sprint(gotIDs) != "[11 22]" { + t.Fatalf("synced ids = %v, want [11 22]", gotIDs) + } + var payload map[string]interface{} + if err := json.Unmarshal(recorder.Body.Bytes(), &payload); err != nil { + t.Fatalf("decode response: %v", err) + } + if payload["success"] != float64(2) || payload["failed"] != float64(1) { + t.Fatalf("payload = %#v, want success=2 failed=1", payload) + } +} + func TestCreateAPIKeyPersistsQuotaAndExpiration(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/auth/store.go b/auth/store.go index 4420b744..34cd74a3 100644 --- a/auth/store.go +++ b/auth/store.go @@ -3559,6 +3559,39 @@ func (s *Store) PersistUsageSnapshot(acc *Account, pct7d float64) { } } +// UpdateAccountPlanType persists the latest Codex plan type observed from upstream headers. +func (s *Store) UpdateAccountPlanType(acc *Account, planType string) bool { + if s == nil || acc == nil { + return false + } + plan := strings.ToLower(strings.TrimSpace(planType)) + if plan == "" { + return false + } + + acc.mu.Lock() + changed := acc.PlanType != plan + if changed { + acc.PlanType = plan + acc.recomputeSchedulerLocked(atomic.LoadInt64(&s.maxConcurrency)) + } + acc.mu.Unlock() + if changed { + s.fastSchedulerUpdate(acc) + } + + if s.db == nil || !changed { + return changed + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if err := s.db.UpdateCredentials(ctx, acc.DBID, map[string]interface{}{"plan_type": plan}); err != nil { + log.Printf("[璐﹀彿 %d] 鎸佷箙鍖?plan_type 澶辫触: %v", acc.DBID, err) + } + return changed +} + // ApplyUsageLimitMetadata applies metadata returned by Codex usage_limit_reached errors. func (s *Store) ApplyUsageLimitMetadata(acc *Account, planType string, resetAt time.Time) { if acc == nil { diff --git a/proxy/handler.go b/proxy/handler.go index 984e45e9..dceb82f3 100644 --- a/proxy/handler.go +++ b/proxy/handler.go @@ -2953,6 +2953,9 @@ func SyncCodexUsageState(store *auth.Store, account *auth.Account, resp *http.Re if account == nil || resp == nil { return result } + if store != nil { + store.UpdateAccountPlanType(account, resp.Header.Get("x-codex-plan-type")) + } result.Used5hHeaders = responseHasCodex5hHeaders(resp) result.UsagePct7d, result.HasUsage7d = parseCodexUsageHeaders(resp, account) diff --git a/proxy/handler_test.go b/proxy/handler_test.go index c6584cc6..e49e4389 100644 --- a/proxy/handler_test.go +++ b/proxy/handler_test.go @@ -769,6 +769,50 @@ func TestApply429CooldownUsageLimitUpdatesFreePlanMetadata(t *testing.T) { } } +func TestSyncCodexUsageStateUpdatesPlanTypeFromHeader(t *testing.T) { + ctx := context.Background() + dbPath := filepath.Join(t.TempDir(), "codex2api.db") + db, err := database.New("sqlite", dbPath) + if err != nil { + t.Fatalf("database.New returned error: %v", err) + } + defer db.Close() + + id, err := db.InsertAccountWithCredentials(ctx, "plan-header-account", map[string]interface{}{ + "plan_type": "free", + }, "") + if err != nil { + t.Fatalf("InsertAccountWithCredentials returned error: %v", err) + } + + store := auth.NewStore(db, nil, &database.SystemSettings{MaxConcurrency: 2, TestConcurrency: 1, TestModel: "gpt-5.4"}) + account := &auth.Account{DBID: id, AccessToken: "at", PlanType: "free"} + resp := &http.Response{Header: make(http.Header)} + resp.Header.Set("x-codex-plan-type", "Enterprise") + resp.Header.Set("x-codex-primary-used-percent", "12") + resp.Header.Set("x-codex-primary-window-minutes", "300") + resp.Header.Set("x-codex-primary-reset-after-seconds", "1200") + resp.Header.Set("x-codex-secondary-used-percent", "3") + resp.Header.Set("x-codex-secondary-window-minutes", "10080") + resp.Header.Set("x-codex-secondary-reset-after-seconds", "600000") + + result := SyncCodexUsageState(store, account, resp) + + if got := account.GetPlanType(); got != "enterprise" { + t.Fatalf("account plan_type = %q, want enterprise", got) + } + if !result.Used5hHeaders || !result.HasUsage5h || !result.HasUsage7d { + t.Fatalf("usage sync result = %#v, want 5h and 7d headers detected", result) + } + row, err := db.GetAccountByID(ctx, id) + if err != nil { + t.Fatalf("GetAccountByID returned error: %v", err) + } + if got := row.GetCredential("plan_type"); got != "enterprise" { + t.Fatalf("persisted plan_type = %q, want enterprise", got) + } +} + func TestApply429CooldownUnknown429UsesModelCooldown(t *testing.T) { store := auth.NewStore(nil, nil, &database.SystemSettings{MaxConcurrency: 2, TestConcurrency: 1, TestModel: "gpt-5.4"}) account := &auth.Account{DBID: 102, PlanType: "pro"}