Skip to content

Commit 5b23e67

Browse files
Complete account group permission management
1 parent f9b5bf9 commit 5b23e67

11 files changed

Lines changed: 619 additions & 85 deletions

File tree

admin/account_groups.go

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -201,35 +201,27 @@ func (h *Handler) DeleteAccountGroup(c *gin.Context) {
201201
h.store.ApplyAccountGroups(acc.DBID, groups)
202202
}
203203
}
204-
if err := h.removeDeletedGroupFromAPIKeyScopes(ctx, id); err != nil {
205-
writeInternalError(c, err)
206-
return
207-
}
204+
h.refreshAPIKeyAllowedGroupsAfterGroupDelete(ctx, id)
208205
writeMessage(c, http.StatusOK, "分组已删除")
209206
}
210207

211-
func (h *Handler) removeDeletedGroupFromAPIKeyScopes(ctx context.Context, groupID int64) error {
208+
func (h *Handler) refreshAPIKeyAllowedGroupsAfterGroupDelete(ctx context.Context, groupID int64) {
212209
if h == nil || h.db == nil || groupID <= 0 {
213-
return nil
210+
return
214211
}
215212
keys, err := h.db.ListAPIKeys(ctx)
216213
if err != nil {
217-
return err
214+
return
218215
}
219216
for _, key := range keys {
220-
if key == nil || !containsInt64(key.AllowedGroupIDs, groupID) {
217+
if key == nil {
221218
continue
222219
}
223-
next := removeInt64(key.AllowedGroupIDs, groupID)
224-
if err := h.db.UpdateAPIKeyAllowedGroupIDs(ctx, key.ID, next); err != nil {
225-
return err
226-
}
227220
if h.store != nil {
228-
h.store.SetAPIKeyAllowedGroups(key.ID, next)
221+
h.store.SetAPIKeyAllowedGroups(key.ID, key.AllowedGroupIDs)
229222
}
230223
h.invalidateAPIKeyRuntimeCaches(ctx, key.Key)
231224
}
232-
return nil
233225
}
234226

235227
func sanitizeAccountGroupName(raw string) (string, error) {

admin/handler.go

Lines changed: 79 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,30 @@ func (h *Handler) UpdateAccountScheduler(c *gin.Context) {
729729
return
730730
}
731731
if h.store != nil {
732-
h.store.ApplyAccountSchedulerOverrides(id, nullableInt64Pointer(scoreBiasOverride), nullableInt64Pointer(baseConcurrencyOverride))
732+
if scoreBiasOverride.Set || baseConcurrencyOverride.Set {
733+
current := h.store.FindByID(id)
734+
var score *int64
735+
var concurrency *int64
736+
if current != nil {
737+
current.Mu().RLock()
738+
if current.ScoreBiasOverride != nil {
739+
value := *current.ScoreBiasOverride
740+
score = &value
741+
}
742+
if current.BaseConcurrencyOverride != nil {
743+
value := *current.BaseConcurrencyOverride
744+
concurrency = &value
745+
}
746+
current.Mu().RUnlock()
747+
}
748+
if scoreBiasOverride.Set {
749+
score = nullableInt64Pointer(scoreBiasOverride.Value)
750+
}
751+
if baseConcurrencyOverride.Set {
752+
concurrency = nullableInt64Pointer(baseConcurrencyOverride.Value)
753+
}
754+
h.store.ApplyAccountSchedulerOverrides(id, score, concurrency)
755+
}
733756
if allowedAPIKeyIDs.Set {
734757
h.store.ApplyAccountAllowedAPIKeys(id, allowedAPIKeyIDs.Values)
735758
}
@@ -804,23 +827,26 @@ func parseOptionalStringSliceField(raw json.RawMessage, field string) (optionalS
804827
return optionalStringSlice{Set: true, Values: out}, nil
805828
}
806829

807-
func parseOptionalIntegerField(raw json.RawMessage, field string, minValue, maxValue int64) (sql.NullInt64, error) {
808-
if len(raw) == 0 || string(raw) == "null" {
809-
return sql.NullInt64{}, nil
830+
func parseOptionalIntegerField(raw json.RawMessage, field string, minValue, maxValue int64) (database.OptionalNullInt64, error) {
831+
if len(raw) == 0 {
832+
return database.OptionalNullInt64{}, nil
833+
}
834+
if string(raw) == "null" {
835+
return database.OptionalNullInt64{Set: true}, nil
810836
}
811837

812838
var number json.Number
813839
if err := json.Unmarshal(raw, &number); err != nil {
814-
return sql.NullInt64{}, fmt.Errorf("%s 必须是整数或 null", field)
840+
return database.OptionalNullInt64{}, fmt.Errorf("%s 必须是整数或 null", field)
815841
}
816842
value, err := number.Int64()
817843
if err != nil {
818-
return sql.NullInt64{}, fmt.Errorf("%s 必须是整数或 null", field)
844+
return database.OptionalNullInt64{}, fmt.Errorf("%s 必须是整数或 null", field)
819845
}
820846
if value < minValue || value > maxValue {
821-
return sql.NullInt64{}, fmt.Errorf("%s 超出范围,必须在 %d..%d 之间", field, minValue, maxValue)
847+
return database.OptionalNullInt64{}, fmt.Errorf("%s 超出范围,必须在 %d..%d 之间", field, minValue, maxValue)
822848
}
823-
return sql.NullInt64{Int64: value, Valid: true}, nil
849+
return database.OptionalNullInt64{Set: true, Value: sql.NullInt64{Int64: value, Valid: true}}, nil
824850
}
825851

826852
func parseOptionalIntegerSliceField(raw json.RawMessage, field string) (database.OptionalInt64Slice, error) {
@@ -3117,12 +3143,13 @@ func (h *Handler) ListAPIKeys(c *gin.Context) {
31173143
}
31183144

31193145
type createKeyReq struct {
3120-
Name string `json:"name"`
3121-
Key string `json:"key"`
3122-
QuotaLimit *float64 `json:"quota_limit"`
3123-
Quota *float64 `json:"quota"`
3124-
ExpiresAt string `json:"expires_at"`
3125-
ExpiresInDays *int `json:"expires_in_days"`
3146+
Name string `json:"name"`
3147+
Key string `json:"key"`
3148+
QuotaLimit *float64 `json:"quota_limit"`
3149+
Quota *float64 `json:"quota"`
3150+
ExpiresAt string `json:"expires_at"`
3151+
ExpiresInDays *int `json:"expires_in_days"`
3152+
AllowedGroupIDs json.RawMessage `json:"allowed_group_ids"`
31263153
}
31273154

31283155
// generateKey 生成随机 API Key
@@ -3178,6 +3205,11 @@ func (h *Handler) CreateAPIKey(c *gin.Context) {
31783205
writeError(c, http.StatusBadRequest, err.Error())
31793206
return
31803207
}
3208+
allowedGroupIDs, err := parseOptionalIntegerSliceField(req.AllowedGroupIDs, "allowed_group_ids")
3209+
if err != nil {
3210+
writeError(c, http.StatusBadRequest, err.Error())
3211+
return
3212+
}
31813213

31823214
key := req.Key
31833215
if key == "" {
@@ -3193,17 +3225,39 @@ func (h *Handler) CreateAPIKey(c *gin.Context) {
31933225

31943226
ctx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Second)
31953227
defer cancel()
3228+
if allowedGroupIDs.Set {
3229+
missing, err := h.db.VerifyAccountGroupIDs(ctx, allowedGroupIDs.Values)
3230+
if err != nil {
3231+
writeInternalError(c, err)
3232+
return
3233+
}
3234+
if len(missing) > 0 {
3235+
values := make([]string, 0, len(missing))
3236+
for _, value := range missing {
3237+
values = append(values, strconv.FormatInt(value, 10))
3238+
}
3239+
writeError(c, http.StatusBadRequest, "allowed_group_ids 包含不存在的分组 ID: "+strings.Join(values, ", "))
3240+
return
3241+
}
3242+
}
31963243

31973244
id, err := h.db.InsertAPIKeyWithOptions(ctx, database.APIKeyInput{
3198-
Name: req.Name,
3199-
Key: key,
3200-
QuotaLimit: quotaLimit,
3201-
ExpiresAt: expiresAt,
3245+
Name: req.Name,
3246+
Key: key,
3247+
QuotaLimit: quotaLimit,
3248+
ExpiresAt: expiresAt,
3249+
AllowedGroupIDs: allowedGroupIDs.Values,
32023250
})
32033251
if err != nil {
32043252
writeError(c, http.StatusInternalServerError, "创建失败: "+err.Error())
32053253
return
32063254
}
3255+
if allowedGroupIDs.Set {
3256+
values := dedupeInt64(allowedGroupIDs.Values)
3257+
if h.store != nil {
3258+
h.store.SetAPIKeyAllowedGroups(id, values)
3259+
}
3260+
}
32073261
h.invalidateAPIKeyRuntimeCaches(ctx, key)
32083262

32093263
// 记录安全审计日志
@@ -3215,12 +3269,13 @@ func (h *Handler) CreateAPIKey(c *gin.Context) {
32153269
expiresAtResponse = &formatted
32163270
}
32173271
c.JSON(http.StatusOK, createAPIKeyResponse{
3218-
ID: id,
3219-
Key: key,
3220-
Name: req.Name,
3221-
QuotaLimit: quotaLimit,
3222-
QuotaUsed: 0,
3223-
ExpiresAt: expiresAtResponse,
3272+
ID: id,
3273+
Key: key,
3274+
Name: req.Name,
3275+
QuotaLimit: quotaLimit,
3276+
QuotaUsed: 0,
3277+
ExpiresAt: expiresAtResponse,
3278+
AllowedGroupIDs: dedupeInt64(allowedGroupIDs.Values),
32243279
})
32253280
}
32263281

admin/handler_test.go

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,7 @@ func TestUpdateAccountSchedulerResetsToAutoOnNull(t *testing.T) {
491491
db := newTestAdminDB(t)
492492
accountID := insertTestAccount(t, db)
493493
ctx := context.Background()
494-
if err := db.UpdateAccountSchedulerConfig(ctx, accountID, sql.NullInt64{Int64: 20, Valid: true}, sql.NullInt64{Int64: 4, Valid: true}, database.OptionalInt64Slice{}); err != nil {
494+
if err := db.UpdateAccountSchedulerConfig(ctx, accountID, database.OptionalNullInt64{Set: true, Value: sql.NullInt64{Int64: 20, Valid: true}}, database.OptionalNullInt64{Set: true, Value: sql.NullInt64{Int64: 4, Valid: true}}, database.OptionalInt64Slice{}); err != nil {
495495
t.Fatalf("seed scheduler config: %v", err)
496496
}
497497

@@ -523,6 +523,49 @@ func TestUpdateAccountSchedulerResetsToAutoOnNull(t *testing.T) {
523523
}
524524
}
525525

526+
func TestUpdateAccountSchedulerPartialMetadataPatchPreservesSchedulerConfig(t *testing.T) {
527+
gin.SetMode(gin.TestMode)
528+
529+
db := newTestAdminDB(t)
530+
accountID := insertTestAccount(t, db)
531+
keyID := insertTestAPIKey(t, db, "Team A")
532+
ctx := context.Background()
533+
if err := db.UpdateAccountSchedulerConfig(ctx, accountID,
534+
database.OptionalNullInt64{Set: true, Value: sql.NullInt64{Int64: 20, Valid: true}},
535+
database.OptionalNullInt64{Set: true, Value: sql.NullInt64{Int64: 4, Valid: true}},
536+
database.OptionalInt64Slice{Set: true, Values: []int64{keyID}},
537+
); err != nil {
538+
t.Fatalf("seed scheduler config: %v", err)
539+
}
540+
541+
handler := &Handler{db: db}
542+
recorder := httptest.NewRecorder()
543+
ginCtx, _ := gin.CreateTestContext(recorder)
544+
ginCtx.Params = gin.Params{{Key: "id", Value: fmt.Sprintf("%d", accountID)}}
545+
ginCtx.Request = httptest.NewRequest(http.MethodPatch, fmt.Sprintf("/api/admin/accounts/%d/scheduler", accountID), strings.NewReader(`{"tags":["ops"]}`))
546+
ginCtx.Request.Header.Set("Content-Type", "application/json")
547+
548+
handler.UpdateAccountScheduler(ginCtx)
549+
550+
if recorder.Code != http.StatusOK {
551+
t.Fatalf("status = %d, want %d", recorder.Code, http.StatusOK)
552+
}
553+
554+
rows, err := db.ListActive(context.Background())
555+
if err != nil {
556+
t.Fatalf("ListActive: %v", err)
557+
}
558+
if !rows[0].ScoreBiasOverride.Valid || rows[0].ScoreBiasOverride.Int64 != 20 {
559+
t.Fatalf("score_bias_override = %+v, want 20", rows[0].ScoreBiasOverride)
560+
}
561+
if !rows[0].BaseConcurrencyOverride.Valid || rows[0].BaseConcurrencyOverride.Int64 != 4 {
562+
t.Fatalf("base_concurrency_override = %+v, want 4", rows[0].BaseConcurrencyOverride)
563+
}
564+
if got := rows[0].GetCredentialInt64Slice("allowed_api_key_ids"); len(got) != 1 || got[0] != keyID {
565+
t.Fatalf("allowed_api_key_ids = %v, want [%d]", got, keyID)
566+
}
567+
}
568+
526569
func TestUpdateAccountSchedulerClearsAllowedAPIKeyIDsOnNull(t *testing.T) {
527570
gin.SetMode(gin.TestMode)
528571

admin/responses.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,13 @@ func NewMaskedAPIKeyRow(row *database.APIKeyRow) *MaskedAPIKeyRow {
8989
}
9090

9191
type createAPIKeyResponse struct {
92-
ID int64 `json:"id"`
93-
Key string `json:"key"`
94-
Name string `json:"name"`
95-
QuotaLimit float64 `json:"quota_limit"`
96-
QuotaUsed float64 `json:"quota_used"`
97-
ExpiresAt *string `json:"expires_at"`
92+
ID int64 `json:"id"`
93+
Key string `json:"key"`
94+
Name string `json:"name"`
95+
QuotaLimit float64 `json:"quota_limit"`
96+
QuotaUsed float64 `json:"quota_used"`
97+
ExpiresAt *string `json:"expires_at"`
98+
AllowedGroupIDs []int64 `json:"allowed_group_ids"`
9899
}
99100

100101
type opsOverviewResponse struct {

database/account_groups.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,9 @@ func (db *DB) DeleteAccountGroup(ctx context.Context, id int64, force ...bool) e
156156
if _, err := tx.ExecContext(ctx, "DELETE FROM account_group_members WHERE group_id = "+ph, id); err != nil {
157157
return err
158158
}
159+
if err := pruneDeletedGroupFromAPIKeyScopes(ctx, tx, db.isSQLite(), id); err != nil {
160+
return err
161+
}
159162
res, err := tx.ExecContext(ctx, "DELETE FROM account_groups WHERE id = "+ph, id)
160163
if err != nil {
161164
return err
@@ -170,6 +173,65 @@ func (db *DB) DeleteAccountGroup(ctx context.Context, id int64, force ...bool) e
170173
return tx.Commit()
171174
}
172175

176+
func pruneDeletedGroupFromAPIKeyScopes(ctx context.Context, tx *sql.Tx, sqlite bool, groupID int64) error {
177+
rows, err := tx.QueryContext(ctx, `SELECT id, COALESCE(allowed_group_ids, '[]') FROM api_keys`)
178+
if err != nil {
179+
return err
180+
}
181+
defer rows.Close()
182+
183+
type update struct {
184+
id int64
185+
groups []int64
186+
}
187+
updates := make([]update, 0)
188+
for rows.Next() {
189+
var id int64
190+
var raw interface{}
191+
if err := rows.Scan(&id, &raw); err != nil {
192+
return err
193+
}
194+
groups := decodeInt64SliceValue(raw)
195+
if !containsInt64(groups, groupID) {
196+
continue
197+
}
198+
updates = append(updates, update{id: id, groups: removeInt64(groups, groupID)})
199+
}
200+
if err := rows.Err(); err != nil {
201+
return err
202+
}
203+
204+
query := `UPDATE api_keys SET allowed_group_ids = $1::jsonb WHERE id = $2`
205+
if sqlite {
206+
query = `UPDATE api_keys SET allowed_group_ids = ? WHERE id = ?`
207+
}
208+
for _, item := range updates {
209+
if _, err := tx.ExecContext(ctx, query, encodeInt64SliceJSON(item.groups), item.id); err != nil {
210+
return err
211+
}
212+
}
213+
return nil
214+
}
215+
216+
func removeInt64(slice []int64, target int64) []int64 {
217+
out := make([]int64, 0, len(slice))
218+
for _, v := range slice {
219+
if v != target {
220+
out = append(out, v)
221+
}
222+
}
223+
return out
224+
}
225+
226+
func containsInt64(slice []int64, target int64) bool {
227+
for _, v := range slice {
228+
if v == target {
229+
return true
230+
}
231+
}
232+
return false
233+
}
234+
173235
func (db *DB) SetAccountGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
174236
tx, err := db.conn.BeginTx(ctx, nil)
175237
if err != nil {

0 commit comments

Comments
 (0)