Skip to content

Commit 852a703

Browse files
Fix review issues in docs and scheduler management
1 parent 28d317c commit 852a703

15 files changed

Lines changed: 7549 additions & 4629 deletions

File tree

admin/account_groups.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ func (h *Handler) UpdateAccountGroup(c *gin.Context) {
158158
ctx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Second)
159159
defer cancel()
160160
if err := h.db.UpdateAccountGroup(ctx, id, req.Name, req.Description, req.Color, req.SortOrder); err != nil {
161-
if err == sql.ErrNoRows {
161+
if errors.Is(err, sql.ErrNoRows) {
162162
writeError(c, http.StatusNotFound, "分组不存在")
163163
return
164164
}
@@ -182,7 +182,7 @@ func (h *Handler) DeleteAccountGroup(c *gin.Context) {
182182
ctx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Second)
183183
defer cancel()
184184
if err := h.db.DeleteAccountGroup(ctx, id, force); err != nil {
185-
if err == sql.ErrNoRows {
185+
if errors.Is(err, sql.ErrNoRows) {
186186
writeError(c, http.StatusNotFound, "分组不存在")
187187
return
188188
}

admin/handler.go

Lines changed: 17 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"encoding/base64"
1010
"encoding/hex"
1111
"encoding/json"
12+
"errors"
1213
"fmt"
1314
"io"
1415
"log"
@@ -720,8 +721,12 @@ func (h *Handler) UpdateAccountScheduler(c *gin.Context) {
720721
}
721722
}
722723

723-
if err := h.db.UpdateAccountSchedulerConfig(ctx, id, scoreBiasOverride, baseConcurrencyOverride, allowedAPIKeyIDs); err != nil {
724-
if err == sql.ErrNoRows {
724+
proxyURL := database.OptionalString{}
725+
if req.ProxyURL != nil {
726+
proxyURL = database.OptionalString{Set: true, Value: *req.ProxyURL}
727+
}
728+
if err := h.db.UpdateAccountSchedulerMetadata(ctx, id, scoreBiasOverride, baseConcurrencyOverride, allowedAPIKeyIDs, database.OptionalStringSlice{Set: tags.Set, Values: tags.Values}, groupIDs, proxyURL); err != nil {
729+
if errors.Is(err, sql.ErrNoRows) {
725730
writeError(c, http.StatusNotFound, "账号不存在")
726731
return
727732
}
@@ -757,32 +762,14 @@ func (h *Handler) UpdateAccountScheduler(c *gin.Context) {
757762
h.store.ApplyAccountAllowedAPIKeys(id, allowedAPIKeyIDs.Values)
758763
}
759764
}
760-
if tags.Set {
761-
if err := h.db.UpdateAccountTags(ctx, id, tags.Values); err != nil {
762-
writeError(c, http.StatusInternalServerError, "更新账号标签失败: "+err.Error())
763-
return
764-
}
765-
if h.store != nil {
766-
h.store.ApplyAccountTags(id, tags.Values)
767-
}
765+
if h.store != nil && tags.Set {
766+
h.store.ApplyAccountTags(id, tags.Values)
768767
}
769-
if groupIDs.Set {
770-
if err := h.db.SetAccountGroups(ctx, id, groupIDs.Values); err != nil {
771-
writeError(c, http.StatusInternalServerError, "更新账号分组失败: "+err.Error())
772-
return
773-
}
774-
if h.store != nil {
775-
h.store.ApplyAccountGroups(id, groupIDs.Values)
776-
}
768+
if h.store != nil && groupIDs.Set {
769+
h.store.ApplyAccountGroups(id, groupIDs.Values)
777770
}
778-
if req.ProxyURL != nil {
779-
if err := h.db.UpdateAccountProxyURL(ctx, id, *req.ProxyURL); err != nil {
780-
writeError(c, http.StatusInternalServerError, "更新账号代理失败: "+err.Error())
781-
return
782-
}
783-
if h.store != nil {
784-
h.store.ApplyAccountProxyURL(id, *req.ProxyURL)
785-
}
771+
if h.store != nil && req.ProxyURL != nil {
772+
h.store.ApplyAccountProxyURL(id, *req.ProxyURL)
786773
}
787774

788775
writeMessage(c, http.StatusOK, "账号调度配置已更新")
@@ -3932,13 +3919,13 @@ func (h *Handler) UpdateSettings(c *gin.Context) {
39323919
}
39333920

39343921
if req.ProxyPoolEnabled != nil {
3935-
h.store.SetProxyPoolEnabled(*req.ProxyPoolEnabled)
39363922
if *req.ProxyPoolEnabled {
39373923
if err := h.store.ReloadProxyPool(); err != nil {
39383924
writeError(c, http.StatusInternalServerError, "代理池刷新失败: "+err.Error())
39393925
return
39403926
}
39413927
}
3928+
h.store.SetProxyPoolEnabled(*req.ProxyPoolEnabled)
39423929
log.Printf("设置已更新: proxy_pool_enabled = %t", *req.ProxyPoolEnabled)
39433930
}
39443931

@@ -4758,10 +4745,8 @@ func (h *Handler) AddProxies(c *gin.Context) {
47584745
return
47594746
}
47604747

4761-
// 刷新代理池
47624748
if err := h.store.ReloadProxyPool(); err != nil {
4763-
writeError(c, http.StatusInternalServerError, "代理池刷新失败: "+err.Error())
4764-
return
4749+
log.Printf("代理已添加,但代理池刷新失败: %v", err)
47654750
}
47664751

47674752
c.JSON(http.StatusOK, gin.H{
@@ -4792,8 +4777,7 @@ func (h *Handler) DeleteProxy(c *gin.Context) {
47924777
}
47934778

47944779
if err := h.store.ReloadProxyPool(); err != nil {
4795-
writeError(c, http.StatusInternalServerError, "代理池刷新失败: "+err.Error())
4796-
return
4780+
log.Printf("代理已删除,但代理池刷新失败: %v", err)
47974781
}
47984782
c.JSON(http.StatusOK, gin.H{"message": "代理已删除"})
47994783
}
@@ -4829,8 +4813,7 @@ func (h *Handler) UpdateProxy(c *gin.Context) {
48294813
}
48304814

48314815
if err := h.store.ReloadProxyPool(); err != nil {
4832-
writeError(c, http.StatusInternalServerError, "代理池刷新失败: "+err.Error())
4833-
return
4816+
log.Printf("代理已更新,但代理池刷新失败: %v", err)
48344817
}
48354818
c.JSON(http.StatusOK, gin.H{"message": "代理已更新"})
48364819
}

auth/store.go

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1456,6 +1456,7 @@ type Store struct {
14561456
tokenCache cache.TokenCache
14571457
apiKeyGroupsMu sync.RWMutex
14581458
apiKeyAllowedGroups map[int64][]int64
1459+
apiKeyAllowedGroupSets map[int64]map[int64]struct{}
14591460
usageProbeMu sync.RWMutex
14601461
usageProbe func(context.Context, *Account) error
14611462
usageProbeBatch atomic.Bool
@@ -3082,14 +3083,20 @@ func (s *Store) SetAPIKeyAllowedGroups(apiKeyID int64, groupIDs []int64) {
30823083
if apiKeyID <= 0 {
30833084
return
30843085
}
3086+
normalized := normalizeAllowedGroupIDs(groupIDs)
30853087
s.apiKeyGroupsMu.Lock()
30863088
if s.apiKeyAllowedGroups == nil {
30873089
s.apiKeyAllowedGroups = make(map[int64][]int64)
30883090
}
3089-
if len(groupIDs) == 0 {
3091+
if s.apiKeyAllowedGroupSets == nil {
3092+
s.apiKeyAllowedGroupSets = make(map[int64]map[int64]struct{})
3093+
}
3094+
if len(normalized) == 0 {
30903095
delete(s.apiKeyAllowedGroups, apiKeyID)
3096+
delete(s.apiKeyAllowedGroupSets, apiKeyID)
30913097
} else {
3092-
s.apiKeyAllowedGroups[apiKeyID] = cloneInt64Slice(groupIDs)
3098+
s.apiKeyAllowedGroups[apiKeyID] = cloneInt64Slice(normalized)
3099+
s.apiKeyAllowedGroupSets[apiKeyID] = int64Set(normalized)
30933100
}
30943101
s.apiKeyGroupsMu.Unlock()
30953102
s.rebuildFastScheduler()
@@ -3114,9 +3121,12 @@ func (s *Store) LoadAPIKeyAllowedGroups(ctx context.Context) error {
31143121
}
31153122
s.apiKeyGroupsMu.Lock()
31163123
s.apiKeyAllowedGroups = make(map[int64][]int64, len(keys))
3124+
s.apiKeyAllowedGroupSets = make(map[int64]map[int64]struct{}, len(keys))
31173125
for _, key := range keys {
3118-
if len(key.AllowedGroupIDs) > 0 {
3119-
s.apiKeyAllowedGroups[key.ID] = cloneInt64Slice(key.AllowedGroupIDs)
3126+
normalized := normalizeAllowedGroupIDs(key.AllowedGroupIDs)
3127+
if len(normalized) > 0 {
3128+
s.apiKeyAllowedGroups[key.ID] = cloneInt64Slice(normalized)
3129+
s.apiKeyAllowedGroupSets[key.ID] = int64Set(normalized)
31203130
}
31213131
}
31223132
s.apiKeyGroupsMu.Unlock()
@@ -3128,14 +3138,12 @@ func (s *Store) APIKeyAllowsAccount(apiKeyID int64, acc *Account) bool {
31283138
if s == nil || apiKeyID <= 0 || acc == nil {
31293139
return true
31303140
}
3131-
allowed := s.GetAPIKeyAllowedGroups(apiKeyID)
3132-
if len(allowed) == 0 {
3141+
s.apiKeyGroupsMu.RLock()
3142+
allowedSet := s.apiKeyAllowedGroupSets[apiKeyID]
3143+
s.apiKeyGroupsMu.RUnlock()
3144+
if len(allowedSet) == 0 {
31333145
return true
31343146
}
3135-
allowedSet := make(map[int64]struct{}, len(allowed))
3136-
for _, id := range allowed {
3137-
allowedSet[id] = struct{}{}
3138-
}
31393147
acc.mu.RLock()
31403148
defer acc.mu.RUnlock()
31413149
for _, id := range acc.GroupIDs {
@@ -3146,6 +3154,31 @@ func (s *Store) APIKeyAllowsAccount(apiKeyID int64, acc *Account) bool {
31463154
return false
31473155
}
31483156

3157+
func normalizeAllowedGroupIDs(groupIDs []int64) []int64 {
3158+
out := make([]int64, 0, len(groupIDs))
3159+
seen := make(map[int64]struct{}, len(groupIDs))
3160+
for _, id := range groupIDs {
3161+
if id <= 0 {
3162+
continue
3163+
}
3164+
if _, ok := seen[id]; ok {
3165+
continue
3166+
}
3167+
seen[id] = struct{}{}
3168+
out = append(out, id)
3169+
}
3170+
sort.Slice(out, func(i, j int) bool { return out[i] < out[j] })
3171+
return out
3172+
}
3173+
3174+
func int64Set(values []int64) map[int64]struct{} {
3175+
out := make(map[int64]struct{}, len(values))
3176+
for _, value := range values {
3177+
out[value] = struct{}{}
3178+
}
3179+
return out
3180+
}
3181+
31493182
func (s *Store) accountAllowedForAPIKey(acc *Account, apiKeyID int64) bool {
31503183
if acc == nil {
31513184
return false

database/account_groups.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,12 @@ type AccountGroup struct {
2222
func (db *DB) ListAccountGroups(ctx context.Context) ([]AccountGroup, error) {
2323
rows, err := db.conn.QueryContext(ctx, `
2424
SELECT g.id, g.name, g.description, g.color, g.sort_order,
25-
COALESCE(COUNT(m.account_id), 0), g.created_at, g.updated_at
25+
COALESCE(COUNT(a.id), 0), g.created_at, g.updated_at
2626
FROM account_groups g
2727
LEFT JOIN account_group_members m ON m.group_id = g.id
28+
LEFT JOIN accounts a ON a.id = m.account_id
29+
AND a.status <> 'deleted'
30+
AND COALESCE(a.error_message, '') <> 'deleted'
2831
GROUP BY g.id, g.name, g.description, g.color, g.sort_order, g.created_at, g.updated_at
2932
ORDER BY g.sort_order, g.name`)
3033
if err != nil {
@@ -147,7 +150,12 @@ func (db *DB) DeleteAccountGroup(ctx context.Context, id int64, force ...bool) e
147150
ph = "?"
148151
}
149152
var count int64
150-
if err := tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM account_group_members WHERE group_id = "+ph, id).Scan(&count); err != nil {
153+
memberCountQuery := `
154+
SELECT COUNT(*)
155+
FROM account_group_members m
156+
JOIN accounts a ON a.id = m.account_id
157+
WHERE m.group_id = ` + ph + ` AND a.status <> 'deleted' AND COALESCE(a.error_message, '') <> 'deleted'`
158+
if err := tx.QueryRowContext(ctx, memberCountQuery, id).Scan(&count); err != nil {
151159
return err
152160
}
153161
if count > 0 && !allowMembers {

0 commit comments

Comments
 (0)