Skip to content

Commit ae09c9b

Browse files
committed
fix(admin): record unauthorized errors and reduce batch render load
1 parent 78cf742 commit ae09c9b

9 files changed

Lines changed: 307 additions & 60 deletions

File tree

admin/batch_test_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,3 +211,39 @@ func TestRunSingleBatchTestSuccessRecoversBannedAccount(t *testing.T) {
211211
t.Fatal("successful batch test should record scheduler success")
212212
}
213213
}
214+
215+
func TestRunSingleBatchTestUnauthorizedRecordsErrorMessage(t *testing.T) {
216+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
217+
w.Header().Set("Content-Type", "application/json")
218+
w.WriteHeader(http.StatusUnauthorized)
219+
_, _ = w.Write([]byte(`{"error":{"code":"token_invalidated","message":"token invalidated"},"status":401}`))
220+
}))
221+
defer server.Close()
222+
223+
store := auth.NewStore(nil, nil, nil)
224+
account := &auth.Account{
225+
DBID: 1,
226+
UpstreamType: auth.UpstreamOpenAIResponses,
227+
BaseURL: server.URL,
228+
APIKey: "test-key",
229+
Models: []string{"gpt-4o-mini"},
230+
Status: auth.StatusReady,
231+
HealthTier: auth.HealthTierHealthy,
232+
}
233+
store.AddAccount(account)
234+
handler := &Handler{store: store}
235+
236+
status, msg := handler.runSingleBatchTest(context.Background(), account)
237+
if status != "banned" {
238+
t.Fatalf("status = %q, message = %q, want banned", status, msg)
239+
}
240+
if got := account.RuntimeStatus(); got != "unauthorized" {
241+
t.Fatalf("RuntimeStatus() = %q, want unauthorized", got)
242+
}
243+
account.Mu().RLock()
244+
errorMsg := account.ErrorMsg
245+
account.Mu().RUnlock()
246+
if !strings.Contains(errorMsg, "token_invalidated") {
247+
t.Fatalf("ErrorMsg = %q, want token_invalidated", errorMsg)
248+
}
249+
}

admin/test_connection.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,17 +94,18 @@ func (h *Handler) TestConnection(c *gin.Context) {
9494
proxy.SyncCodexUsageState(h.store, account, resp)
9595
}
9696
errBody, _ := io.ReadAll(resp.Body)
97+
errMsg := fmt.Sprintf("上游返回 %d: %s", resp.StatusCode, truncate(string(errBody), 500))
9798
switch resp.StatusCode {
9899
case http.StatusUnauthorized:
99-
h.store.MarkCooldown(account, 24*time.Hour, "unauthorized")
100+
h.store.MarkCooldownWithError(account, 24*time.Hour, "unauthorized", errMsg)
100101
case http.StatusTooManyRequests:
101102
if isOpenAIResponsesAccount {
102103
h.store.MarkCooldown(account, time.Minute, "rate_limited")
103104
} else {
104105
proxy.Apply429Cooldown(h.store, account, errBody, resp, testModel)
105106
}
106107
}
107-
sendTestEvent(c, testEvent{Type: "error", Error: fmt.Sprintf("上游返回 %d: %s", resp.StatusCode, truncate(string(errBody), 500))})
108+
sendTestEvent(c, testEvent{Type: "error", Error: errMsg})
108109
return
109110
}
110111

@@ -754,7 +755,7 @@ func (h *Handler) runSingleBatchTest(ctx context.Context, acc *auth.Account) (st
754755
if !acc.IsOpenAIResponsesAPI() {
755756
proxy.SyncCodexUsageState(h.store, acc, resp)
756757
}
757-
h.store.MarkCooldown(acc, 24*time.Hour, "unauthorized")
758+
h.store.MarkCooldownWithError(acc, 24*time.Hour, "unauthorized", fmt.Sprintf("上游返回 %d: %s", resp.StatusCode, truncate(string(body), 300)))
758759
return "banned", "账号授权失败"
759760
case http.StatusTooManyRequests:
760761
if acc.IsOpenAIResponsesAPI() {

admin/test_connection_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
package admin
22

33
import (
4+
"net/http"
5+
"net/http/httptest"
46
"strings"
57
"testing"
68
"time"
79

10+
"github.com/codex2api/auth"
811
"github.com/codex2api/proxy"
12+
"github.com/gin-gonic/gin"
913
"github.com/tidwall/gjson"
1014
)
1115

@@ -48,6 +52,52 @@ func TestFormatUsageLimitedTestErrorReportsSuccessfulProbeAsLimited(t *testing.T
4852
}
4953
}
5054

55+
func TestConnectionUnauthorizedRecordsErrorMessage(t *testing.T) {
56+
gin.SetMode(gin.TestMode)
57+
upstreamBody := `{"error":{"message":"Your authentication token has been invalidated.","code":"token_invalidated"},"status":401}`
58+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
59+
w.Header().Set("Content-Type", "application/json")
60+
w.WriteHeader(http.StatusUnauthorized)
61+
_, _ = w.Write([]byte(upstreamBody))
62+
}))
63+
defer server.Close()
64+
65+
store := auth.NewStore(nil, nil, nil)
66+
account := &auth.Account{
67+
DBID: 42,
68+
UpstreamType: auth.UpstreamOpenAIResponses,
69+
BaseURL: server.URL,
70+
APIKey: "sk-test",
71+
Models: []string{"gpt-4o-mini"},
72+
Status: auth.StatusReady,
73+
HealthTier: auth.HealthTierHealthy,
74+
}
75+
store.AddAccount(account)
76+
handler := &Handler{store: store}
77+
router := gin.New()
78+
router.GET("/api/admin/accounts/:id/test", handler.TestConnection)
79+
80+
recorder := httptest.NewRecorder()
81+
request := httptest.NewRequest(http.MethodGet, "/api/admin/accounts/42/test", nil)
82+
router.ServeHTTP(recorder, request)
83+
84+
if recorder.Code != http.StatusOK {
85+
t.Fatalf("status = %d, want 200", recorder.Code)
86+
}
87+
if !strings.Contains(recorder.Body.String(), "token_invalidated") {
88+
t.Fatalf("SSE response %q does not contain token_invalidated", recorder.Body.String())
89+
}
90+
if got := account.RuntimeStatus(); got != "unauthorized" {
91+
t.Fatalf("RuntimeStatus() = %q, want unauthorized", got)
92+
}
93+
account.Mu().RLock()
94+
errorMsg := account.ErrorMsg
95+
account.Mu().RUnlock()
96+
if !strings.Contains(errorMsg, "token_invalidated") {
97+
t.Fatalf("ErrorMsg = %q, want token_invalidated", errorMsg)
98+
}
99+
}
100+
51101
func TestExtractCompletedOutputText(t *testing.T) {
52102
event := []byte(`{
53103
"type":"response.completed",

admin/usage_probe.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,13 @@ func (h *Handler) ProbeUsageSnapshot(ctx context.Context, account *auth.Account)
4646
func (h *Handler) probeUsageViaWham(ctx context.Context, account *auth.Account) error {
4747
usage, resp, err := proxy.QueryWhamUsage(ctx, account, h.store.ResolveProxyForAccount(account))
4848
if resp != nil {
49-
// QueryWhamUsage 在非 200 时不会读 body;这里关闭,并按状态码做冷却
50-
_, _ = io.Copy(io.Discard, resp.Body)
49+
// QueryWhamUsage 在非 200 时不会读 body;这里读取一小段用于账号错误详情。
50+
body, _ := io.ReadAll(io.LimitReader(resp.Body, 64<<10))
5151
_ = resp.Body.Close()
5252
switch resp.StatusCode {
5353
case http.StatusUnauthorized:
5454
h.store.ReportRequestFailure(account, "client", 0)
55-
h.store.MarkCooldown(account, 24*time.Hour, "unauthorized")
55+
h.store.MarkCooldownWithError(account, 24*time.Hour, "unauthorized", fmt.Sprintf("用量探针 wham 上游返回 %d: %s", resp.StatusCode, truncate(string(body), 300)))
5656
case http.StatusTooManyRequests:
5757
h.store.ReportRequestFailure(account, "client", 0)
5858
}
@@ -97,7 +97,7 @@ func (h *Handler) probeUsageViaResponses(ctx context.Context, account *auth.Acco
9797
return nil
9898
case http.StatusUnauthorized:
9999
h.store.ReportRequestFailure(account, "client", 0)
100-
h.store.MarkCooldown(account, 24*time.Hour, "unauthorized")
100+
h.store.MarkCooldownWithError(account, 24*time.Hour, "unauthorized", fmt.Sprintf("用量探针上游返回 %d: %s", resp.StatusCode, truncate(string(body), 300)))
101101
return nil
102102
case http.StatusTooManyRequests:
103103
h.store.ReportRequestFailure(account, "client", 0)

auth/runtime_status_test.go

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
package auth
22

3-
import "testing"
3+
import (
4+
"strings"
5+
"testing"
6+
"time"
7+
)
48

59
func TestRuntimeStatusShowsRefreshingForRTWithoutAccessToken(t *testing.T) {
610
acc := &Account{
@@ -43,3 +47,34 @@ func TestMarkErrorAndClearCooldownRoundTrip(t *testing.T) {
4347
t.Fatalf("RuntimeStatus() after ClearCooldown = %q, want active", got)
4448
}
4549
}
50+
51+
func TestMarkCooldownWithErrorKeepsUnauthorizedStatusAndMessage(t *testing.T) {
52+
store := NewStore(nil, nil, nil)
53+
acc := &Account{
54+
DBID: 1,
55+
AccessToken: "at-test",
56+
Status: StatusReady,
57+
HealthTier: HealthTierHealthy,
58+
}
59+
60+
store.MarkCooldownWithError(acc, 24*time.Hour, "unauthorized", "上游返回 401: token_invalidated")
61+
62+
if got := acc.RuntimeStatus(); got != "unauthorized" {
63+
t.Fatalf("RuntimeStatus() = %q, want unauthorized", got)
64+
}
65+
acc.Mu().RLock()
66+
errorMsg := acc.ErrorMsg
67+
cooldownReason := acc.CooldownReason
68+
cooldownUntil := acc.CooldownUtil
69+
status := acc.Status
70+
acc.Mu().RUnlock()
71+
if status != StatusCooldown {
72+
t.Fatalf("Status = %v, want cooldown", status)
73+
}
74+
if cooldownReason != "unauthorized" || cooldownUntil.IsZero() {
75+
t.Fatalf("cooldown = (%q, %s), want unauthorized with deadline", cooldownReason, cooldownUntil)
76+
}
77+
if !strings.Contains(errorMsg, "token_invalidated") {
78+
t.Fatalf("ErrorMsg = %q, want token_invalidated", errorMsg)
79+
}
80+
}

auth/store.go

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3770,12 +3770,33 @@ func (s *Store) ApplyAccountEnabled(dbID int64, enabled bool) bool {
37703770
return true
37713771
}
37723772

3773+
func normalizeAccountErrorMessage(errorMsg string, fallback string) string {
3774+
errorMsg = strings.TrimSpace(errorMsg)
3775+
if errorMsg == "" {
3776+
errorMsg = strings.TrimSpace(fallback)
3777+
}
3778+
if len(errorMsg) > 500 {
3779+
errorMsg = errorMsg[:500]
3780+
}
3781+
return errorMsg
3782+
}
3783+
37733784
// MarkCooldown 标记账号进入冷却,并持久化到数据库
37743785
func (s *Store) MarkCooldown(acc *Account, duration time.Duration, reason string) {
3786+
s.markCooldown(acc, duration, reason, "")
3787+
}
3788+
3789+
// MarkCooldownWithError 标记账号进入冷却,并同时记录本次上游错误详情。
3790+
func (s *Store) MarkCooldownWithError(acc *Account, duration time.Duration, reason string, errorMsg string) {
3791+
s.markCooldown(acc, duration, reason, errorMsg)
3792+
}
3793+
3794+
func (s *Store) markCooldown(acc *Account, duration time.Duration, reason string, errorMsg string) {
37753795
if acc == nil {
37763796
return
37773797
}
37783798

3799+
errorMsg = normalizeAccountErrorMessage(errorMsg, "")
37793800
now := time.Now()
37803801
acc.mu.Lock()
37813802
switch reason {
@@ -3801,6 +3822,9 @@ func (s *Store) MarkCooldown(acc *Account, duration time.Duration, reason string
38013822
acc.HealthTier = HealthTierRisky
38023823
}
38033824
}
3825+
if errorMsg != "" {
3826+
acc.ErrorMsg = errorMsg
3827+
}
38043828
acc.recomputeSchedulerLocked(atomic.LoadInt64(&s.maxConcurrency))
38053829
acc.mu.Unlock()
38063830

@@ -3815,7 +3839,13 @@ func (s *Store) MarkCooldown(acc *Account, duration time.Duration, reason string
38153839

38163840
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
38173841
defer cancel()
3818-
if err := s.db.SetCooldown(ctx, acc.DBID, reason, until); err != nil {
3842+
var err error
3843+
if errorMsg != "" {
3844+
err = s.db.SetCooldownWithError(ctx, acc.DBID, reason, until, errorMsg)
3845+
} else {
3846+
err = s.db.SetCooldown(ctx, acc.DBID, reason, until)
3847+
}
3848+
if err != nil {
38193849
log.Printf("[账号 %d] 持久化冷却状态失败: %v", acc.DBID, err)
38203850
}
38213851
}
@@ -3917,13 +3947,7 @@ func (s *Store) MarkError(acc *Account, errorMsg string) {
39173947
return
39183948
}
39193949

3920-
errorMsg = strings.TrimSpace(errorMsg)
3921-
if errorMsg == "" {
3922-
errorMsg = "账号测试失败"
3923-
}
3924-
if len(errorMsg) > 500 {
3925-
errorMsg = errorMsg[:500]
3926-
}
3950+
errorMsg = normalizeAccountErrorMessage(errorMsg, "账号测试失败")
39273951

39283952
now := time.Now()
39293953
acc.mu.Lock()

database/postgres.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3837,6 +3837,13 @@ func (db *DB) SetCooldown(ctx context.Context, id int64, reason string, until ti
38373837
return err
38383838
}
38393839

3840+
// SetCooldownWithError 持久化账号冷却状态,并保留本次错误详情。
3841+
func (db *DB) SetCooldownWithError(ctx context.Context, id int64, reason string, until time.Time, errorMsg string) error {
3842+
query := `UPDATE accounts SET cooldown_reason = $1, cooldown_until = $2, error_message = $3, updated_at = CURRENT_TIMESTAMP WHERE id = $4`
3843+
_, err := db.conn.ExecContext(ctx, query, reason, until, errorMsg, id)
3844+
return err
3845+
}
3846+
38403847
// ClearCooldown 清除账号冷却状态
38413848
func (db *DB) ClearCooldown(ctx context.Context, id int64) error {
38423849
query := `UPDATE accounts SET cooldown_reason = '', cooldown_until = NULL, updated_at = CURRENT_TIMESTAMP WHERE id = $1`

database/sqlite_test.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1262,6 +1262,42 @@ func TestListActiveIncludesErrorAccounts(t *testing.T) {
12621262
}
12631263
}
12641264

1265+
func TestSetCooldownWithErrorPersistsMessage(t *testing.T) {
1266+
dbPath := filepath.Join(t.TempDir(), "codex2api.db")
1267+
1268+
db, err := New("sqlite", dbPath)
1269+
if err != nil {
1270+
t.Fatalf("New(sqlite) 返回错误: %v", err)
1271+
}
1272+
defer db.Close()
1273+
1274+
ctx := context.Background()
1275+
id, err := db.InsertAccount(ctx, "cooldown-account", "rt-cooldown", "")
1276+
if err != nil {
1277+
t.Fatalf("InsertAccount 返回错误: %v", err)
1278+
}
1279+
until := time.Now().Add(time.Hour)
1280+
if err := db.SetCooldownWithError(ctx, id, "unauthorized", until, "上游返回 401: token_invalidated"); err != nil {
1281+
t.Fatalf("SetCooldownWithError 返回错误: %v", err)
1282+
}
1283+
1284+
var reason string
1285+
var errorMessage string
1286+
var cooldownUntil sql.NullTime
1287+
if err := db.conn.QueryRowContext(ctx, `SELECT cooldown_reason, error_message, cooldown_until FROM accounts WHERE id = $1`, id).Scan(&reason, &errorMessage, &cooldownUntil); err != nil {
1288+
t.Fatalf("查询账号冷却状态返回错误: %v", err)
1289+
}
1290+
if reason != "unauthorized" {
1291+
t.Fatalf("cooldown_reason = %q, want unauthorized", reason)
1292+
}
1293+
if errorMessage != "上游返回 401: token_invalidated" {
1294+
t.Fatalf("error_message = %q, want recorded upstream error", errorMessage)
1295+
}
1296+
if !cooldownUntil.Valid {
1297+
t.Fatal("cooldown_until 未写入")
1298+
}
1299+
}
1300+
12651301
func TestUsageLogsFilterByAPIKeyID(t *testing.T) {
12661302
dbPath := filepath.Join(t.TempDir(), "codex2api.db")
12671303

0 commit comments

Comments
 (0)