Skip to content

Commit 28d317c

Browse files
Complete API key and group management polish
1 parent 5b23e67 commit 28d317c

16 files changed

Lines changed: 1077 additions & 198 deletions

File tree

admin/handler.go

Lines changed: 80 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3281,6 +3281,10 @@ func (h *Handler) CreateAPIKey(c *gin.Context) {
32813281

32823282
type updateAPIKeyReq struct {
32833283
Name *string `json:"name"`
3284+
QuotaLimit json.RawMessage `json:"quota_limit"`
3285+
Quota json.RawMessage `json:"quota"`
3286+
ExpiresAt json.RawMessage `json:"expires_at"`
3287+
ExpiresInDays *int `json:"expires_in_days"`
32843288
AllowedGroupIDs json.RawMessage `json:"allowed_group_ids"`
32853289
}
32863290

@@ -3302,6 +3306,16 @@ func (h *Handler) UpdateAPIKey(c *gin.Context) {
33023306
writeError(c, http.StatusBadRequest, err.Error())
33033307
return
33043308
}
3309+
quotaLimit, quotaLimitSet, err := parseOptionalAPIKeyQuota(req.QuotaLimit, req.Quota)
3310+
if err != nil {
3311+
writeError(c, http.StatusBadRequest, err.Error())
3312+
return
3313+
}
3314+
expiresAt, expiresAtSet, err := parseOptionalAPIKeyExpiration(req.ExpiresAt, req.ExpiresInDays)
3315+
if err != nil {
3316+
writeError(c, http.StatusBadRequest, err.Error())
3317+
return
3318+
}
33053319
ctx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Second)
33063320
defer cancel()
33073321
row, err := h.db.GetAPIKeyByID(ctx, id)
@@ -3327,11 +3341,15 @@ func (h *Handler) UpdateAPIKey(c *gin.Context) {
33273341
writeError(c, http.StatusBadRequest, "名称包含非法字符")
33283342
return
33293343
}
3330-
if err := h.db.UpdateAPIKeyName(ctx, id, name); err != nil {
3331-
writeInternalError(c, err)
3344+
req.Name = &name
3345+
}
3346+
if quotaLimitSet {
3347+
if quotaLimit > 1000000000 {
3348+
writeError(c, http.StatusBadRequest, "额度限制不能超过 1000000000")
33323349
return
33333350
}
33343351
}
3352+
var allowedGroupValues []int64
33353353
if allowedGroupIDs.Set {
33363354
missing, err := h.db.VerifyAccountGroupIDs(ctx, allowedGroupIDs.Values)
33373355
if err != nil {
@@ -3346,19 +3364,71 @@ func (h *Handler) UpdateAPIKey(c *gin.Context) {
33463364
writeError(c, http.StatusBadRequest, "allowed_group_ids 包含不存在的分组 ID: "+strings.Join(values, ", "))
33473365
return
33483366
}
3349-
values := dedupeInt64(allowedGroupIDs.Values)
3350-
if err := h.db.UpdateAPIKeyAllowedGroupIDs(ctx, id, values); err != nil {
3351-
writeInternalError(c, err)
3352-
return
3353-
}
3354-
if h.store != nil {
3355-
h.store.SetAPIKeyAllowedGroups(id, values)
3356-
}
3367+
allowedGroupValues = dedupeInt64(allowedGroupIDs.Values)
3368+
}
3369+
update := database.APIKeyUpdate{
3370+
QuotaLimit: quotaLimit,
3371+
QuotaLimitSet: quotaLimitSet,
3372+
ExpiresAt: expiresAt,
3373+
ExpiresAtSet: expiresAtSet,
3374+
AllowedGroupIDs: allowedGroupValues,
3375+
AllowedGroupIDsSet: allowedGroupIDs.Set,
3376+
}
3377+
if req.Name != nil {
3378+
update.Name = *req.Name
3379+
update.NameSet = true
3380+
}
3381+
if err := h.db.UpdateAPIKey(ctx, id, update); err != nil {
3382+
writeInternalError(c, err)
3383+
return
3384+
}
3385+
if allowedGroupIDs.Set && h.store != nil {
3386+
h.store.SetAPIKeyAllowedGroups(id, allowedGroupValues)
33573387
}
33583388
h.invalidateAPIKeyRuntimeCaches(ctx, row.Key)
33593389
writeMessage(c, http.StatusOK, "API Key 已更新")
33603390
}
33613391

3392+
func parseOptionalAPIKeyQuota(quotaLimitRaw, quotaRaw json.RawMessage) (float64, bool, error) {
3393+
raw := quotaLimitRaw
3394+
if len(raw) == 0 {
3395+
raw = quotaRaw
3396+
}
3397+
if len(raw) == 0 {
3398+
return 0, false, nil
3399+
}
3400+
if bytes.Equal(bytes.TrimSpace(raw), []byte("null")) {
3401+
return 0, true, nil
3402+
}
3403+
var value float64
3404+
if err := json.Unmarshal(raw, &value); err != nil {
3405+
return 0, true, fmt.Errorf("额度限制必须是数字")
3406+
}
3407+
if value < 0 {
3408+
return 0, true, fmt.Errorf("额度限制不能小于 0")
3409+
}
3410+
return value, true, nil
3411+
}
3412+
3413+
func parseOptionalAPIKeyExpiration(raw json.RawMessage, expiresInDays *int) (sql.NullTime, bool, error) {
3414+
if expiresInDays != nil {
3415+
expiresAt, err := parseAPIKeyExpiresAt("", expiresInDays)
3416+
return expiresAt, true, err
3417+
}
3418+
if len(raw) == 0 {
3419+
return sql.NullTime{}, false, nil
3420+
}
3421+
if bytes.Equal(bytes.TrimSpace(raw), []byte("null")) {
3422+
return sql.NullTime{}, true, nil
3423+
}
3424+
var value string
3425+
if err := json.Unmarshal(raw, &value); err != nil {
3426+
return sql.NullTime{}, true, fmt.Errorf("过期时间格式无效")
3427+
}
3428+
expiresAt, err := parseAPIKeyExpiresAt(value, nil)
3429+
return expiresAt, true, err
3430+
}
3431+
33623432
func parseAPIKeyExpiresAt(raw string, expiresInDays *int) (sql.NullTime, error) {
33633433
if expiresInDays != nil {
33643434
if *expiresInDays < 0 {

admin/handler_test.go

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@ import (
1212
"path/filepath"
1313
"strings"
1414
"testing"
15+
"time"
1516

1617
"github.com/codex2api/auth"
18+
"github.com/codex2api/cache"
1719
"github.com/codex2api/database"
1820
"github.com/gin-gonic/gin"
1921
)
@@ -185,6 +187,124 @@ func TestCreateAPIKeyPersistsQuotaAndExpiration(t *testing.T) {
185187
}
186188
}
187189

190+
func TestUpdateAPIKeyPreservesOmittedFieldsAndUpdatesLimits(t *testing.T) {
191+
gin.SetMode(gin.TestMode)
192+
193+
dbPath := filepath.Join(t.TempDir(), "codex2api.db")
194+
db, err := database.New("sqlite", dbPath)
195+
if err != nil {
196+
t.Fatalf("database.New 返回错误: %v", err)
197+
}
198+
defer db.Close()
199+
200+
expiresAt := sql.NullTime{Time: time.Now().AddDate(0, 0, 3), Valid: true}
201+
id, err := db.InsertAPIKeyWithOptions(context.Background(), database.APIKeyInput{
202+
Name: "Client A",
203+
Key: "sk-test-update-client-1234567890",
204+
QuotaLimit: 0.25,
205+
ExpiresAt: expiresAt,
206+
})
207+
if err != nil {
208+
t.Fatalf("InsertAPIKeyWithOptions 返回错误: %v", err)
209+
}
210+
211+
handler := &Handler{db: db}
212+
recorder := httptest.NewRecorder()
213+
ctx, _ := gin.CreateTestContext(recorder)
214+
ctx.Params = gin.Params{{Key: "id", Value: fmt.Sprintf("%d", id)}}
215+
ctx.Request = httptest.NewRequest(http.MethodPatch, "/api/admin/keys/1", strings.NewReader(`{"name":"Client B"}`))
216+
ctx.Request.Header.Set("Content-Type", "application/json")
217+
218+
handler.UpdateAPIKey(ctx)
219+
220+
if recorder.Code != http.StatusOK {
221+
t.Fatalf("status = %d, want %d, body=%s", recorder.Code, http.StatusOK, recorder.Body.String())
222+
}
223+
row, err := db.GetAPIKeyByID(context.Background(), id)
224+
if err != nil {
225+
t.Fatalf("GetAPIKeyByID 返回错误: %v", err)
226+
}
227+
if row.Name != "Client B" || row.QuotaLimit != 0.25 || !row.ExpiresAt.Valid {
228+
t.Fatalf("row = %#v, want renamed with quota/expiration preserved", row)
229+
}
230+
231+
recorder = httptest.NewRecorder()
232+
ctx, _ = gin.CreateTestContext(recorder)
233+
ctx.Params = gin.Params{{Key: "id", Value: fmt.Sprintf("%d", id)}}
234+
ctx.Request = httptest.NewRequest(http.MethodPatch, "/api/admin/keys/1", strings.NewReader(`{"quota_limit":0,"expires_at":null}`))
235+
ctx.Request.Header.Set("Content-Type", "application/json")
236+
237+
handler.UpdateAPIKey(ctx)
238+
239+
if recorder.Code != http.StatusOK {
240+
t.Fatalf("status = %d, want %d, body=%s", recorder.Code, http.StatusOK, recorder.Body.String())
241+
}
242+
row, err = db.GetAPIKeyByID(context.Background(), id)
243+
if err != nil {
244+
t.Fatalf("GetAPIKeyByID 返回错误: %v", err)
245+
}
246+
if row.Name != "Client B" || row.QuotaLimit != 0 || row.ExpiresAt.Valid {
247+
t.Fatalf("row = %#v, want quota/expiration cleared with name preserved", row)
248+
}
249+
}
250+
251+
func TestUpdateAPIKeyRefreshesRuntimeStoreAndCache(t *testing.T) {
252+
gin.SetMode(gin.TestMode)
253+
254+
dbPath := filepath.Join(t.TempDir(), "codex2api.db")
255+
db, err := database.New("sqlite", dbPath)
256+
if err != nil {
257+
t.Fatalf("database.New 返回错误: %v", err)
258+
}
259+
defer db.Close()
260+
261+
ctx := context.Background()
262+
groupID, err := db.CreateAccountGroup(ctx, "Team", "", "#2563eb", 0)
263+
if err != nil {
264+
t.Fatalf("CreateAccountGroup 返回错误: %v", err)
265+
}
266+
key := "sk-test-runtime-refresh-1234567890"
267+
keyID, err := db.InsertAPIKey(ctx, "Client A", key)
268+
if err != nil {
269+
t.Fatalf("InsertAPIKey 返回错误: %v", err)
270+
}
271+
store := auth.NewStore(nil, nil, nil)
272+
tc := cache.NewMemory(1)
273+
handler := &Handler{db: db, store: store, cache: tc}
274+
payload, err := json.Marshal(map[string]interface{}{
275+
"id": keyID,
276+
"name": "Client A",
277+
"created_at": time.Now().UTC(),
278+
})
279+
if err != nil {
280+
t.Fatalf("marshal runtime cache: %v", err)
281+
}
282+
if err := tc.SetRuntime(ctx, adminAPIKeyCacheNamespace, key, payload, time.Minute); err != nil {
283+
t.Fatalf("SetRuntime api key: %v", err)
284+
}
285+
286+
recorder := httptest.NewRecorder()
287+
ginCtx, _ := gin.CreateTestContext(recorder)
288+
ginCtx.Params = gin.Params{{Key: "id", Value: fmt.Sprintf("%d", keyID)}}
289+
ginCtx.Request = httptest.NewRequest(http.MethodPatch, "/api/admin/keys/1", strings.NewReader(fmt.Sprintf(`{"allowed_group_ids":[%d]}`, groupID)))
290+
ginCtx.Request.Header.Set("Content-Type", "application/json")
291+
292+
handler.UpdateAPIKey(ginCtx)
293+
294+
if recorder.Code != http.StatusOK {
295+
t.Fatalf("status = %d, want %d, body=%s", recorder.Code, http.StatusOK, recorder.Body.String())
296+
}
297+
if got := store.GetAPIKeyAllowedGroups(keyID); len(got) != 1 || got[0] != groupID {
298+
t.Fatalf("runtime store allowed groups = %v, want [%d]", got, groupID)
299+
}
300+
if _, ok, err := tc.GetRuntime(ctx, adminAPIKeyCacheNamespace, key); err != nil || ok {
301+
t.Fatalf("runtime api key cache after update ok=%v err=%v, want miss", ok, err)
302+
}
303+
if _, ok, err := tc.GetRuntime(ctx, adminAPIKeyCountNamespace, "all"); err != nil || ok {
304+
t.Fatalf("runtime api key count cache after update ok=%v err=%v, want miss", ok, err)
305+
}
306+
}
307+
188308
func TestGetAccountAuthJSONRejectsInvalidID(t *testing.T) {
189309
gin.SetMode(gin.TestMode)
190310

database/postgres.go

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,11 @@ func (db *DB) migrate(ctx context.Context) error {
472472
created_at TIMESTAMPTZ DEFAULT NOW(),
473473
updated_at TIMESTAMPTZ DEFAULT NOW()
474474
);
475+
ALTER TABLE account_groups ADD COLUMN IF NOT EXISTS description TEXT DEFAULT '';
476+
ALTER TABLE account_groups ADD COLUMN IF NOT EXISTS color VARCHAR(20) DEFAULT '';
477+
ALTER TABLE account_groups ADD COLUMN IF NOT EXISTS sort_order INT DEFAULT 0;
478+
ALTER TABLE account_groups ADD COLUMN IF NOT EXISTS created_at TIMESTAMPTZ DEFAULT NOW();
479+
ALTER TABLE account_groups ADD COLUMN IF NOT EXISTS updated_at TIMESTAMPTZ DEFAULT NOW();
475480
476481
CREATE TABLE IF NOT EXISTS account_group_members (
477482
account_id BIGINT NOT NULL,
@@ -796,6 +801,17 @@ type APIKeyInput struct {
796801
AllowedGroupIDs []int64
797802
}
798803

804+
type APIKeyUpdate struct {
805+
Name string
806+
NameSet bool
807+
QuotaLimit float64
808+
QuotaLimitSet bool
809+
ExpiresAt sql.NullTime
810+
ExpiresAtSet bool
811+
AllowedGroupIDs []int64
812+
AllowedGroupIDsSet bool
813+
}
814+
799815
const apiKeySelectColumns = `id, name, key, created_at, COALESCE(quota_limit, 0), COALESCE(quota_used, 0), expires_at, COALESCE(allowed_group_ids, '[]')`
800816

801817
// ListAPIKeys 获取所有 API 密钥
@@ -893,6 +909,41 @@ func (db *DB) UpdateAPIKeyName(ctx context.Context, id int64, name string) error
893909
return nil
894910
}
895911

912+
// UpdateAPIKeyQuotaLimit updates the quota ceiling. A non-positive value clears the limit.
913+
func (db *DB) UpdateAPIKeyQuotaLimit(ctx context.Context, id int64, quotaLimit float64) error {
914+
if quotaLimit < 0 {
915+
quotaLimit = 0
916+
}
917+
res, err := db.conn.ExecContext(ctx, `UPDATE api_keys SET quota_limit = $1 WHERE id = $2`, quotaLimit, id)
918+
if err != nil {
919+
return err
920+
}
921+
affected, err := res.RowsAffected()
922+
if err != nil {
923+
return err
924+
}
925+
if affected == 0 {
926+
return sql.ErrNoRows
927+
}
928+
return nil
929+
}
930+
931+
// UpdateAPIKeyExpiresAt updates or clears the key expiration.
932+
func (db *DB) UpdateAPIKeyExpiresAt(ctx context.Context, id int64, expiresAt sql.NullTime) error {
933+
res, err := db.conn.ExecContext(ctx, `UPDATE api_keys SET expires_at = $1 WHERE id = $2`, nullableTimeArg(expiresAt), id)
934+
if err != nil {
935+
return err
936+
}
937+
affected, err := res.RowsAffected()
938+
if err != nil {
939+
return err
940+
}
941+
if affected == 0 {
942+
return sql.ErrNoRows
943+
}
944+
return nil
945+
}
946+
896947
// UpdateAPIKeyAllowedGroups persists the allowed-group scope for an API key.
897948
// Empty slice clears the scope (key may schedule any account).
898949
func (db *DB) UpdateAPIKeyAllowedGroups(ctx context.Context, id int64, groupIDs []int64) error {
@@ -923,6 +974,64 @@ func (db *DB) UpdateAPIKeyAllowedGroupIDs(ctx context.Context, id int64, groupID
923974
return db.UpdateAPIKeyAllowedGroups(ctx, id, groupIDs)
924975
}
925976

977+
// UpdateAPIKey applies multiple editable fields in one transaction.
978+
// Omitted fields keep their existing values.
979+
func (db *DB) UpdateAPIKey(ctx context.Context, id int64, update APIKeyUpdate) error {
980+
sets := make([]string, 0, 4)
981+
args := make([]interface{}, 0, 5)
982+
placeholder := func() string {
983+
args = append(args, nil)
984+
if db.isSQLite() {
985+
return "?"
986+
}
987+
return fmt.Sprintf("$%d", len(args))
988+
}
989+
setArg := func(value interface{}) string {
990+
ph := placeholder()
991+
args[len(args)-1] = value
992+
return ph
993+
}
994+
if update.NameSet {
995+
sets = append(sets, "name = "+setArg(update.Name))
996+
}
997+
if update.QuotaLimitSet {
998+
quotaLimit := update.QuotaLimit
999+
if quotaLimit < 0 {
1000+
quotaLimit = 0
1001+
}
1002+
sets = append(sets, "quota_limit = "+setArg(quotaLimit))
1003+
}
1004+
if update.ExpiresAtSet {
1005+
sets = append(sets, "expires_at = "+setArg(nullableTimeArg(update.ExpiresAt)))
1006+
}
1007+
if update.AllowedGroupIDsSet {
1008+
payload := encodeInt64SliceJSON(update.AllowedGroupIDs)
1009+
ph := setArg(payload)
1010+
if db.isSQLite() {
1011+
sets = append(sets, "allowed_group_ids = "+ph)
1012+
} else {
1013+
sets = append(sets, "allowed_group_ids = "+ph+"::jsonb")
1014+
}
1015+
}
1016+
if len(sets) == 0 {
1017+
return nil
1018+
}
1019+
idPlaceholder := placeholder()
1020+
args[len(args)-1] = id
1021+
res, err := db.conn.ExecContext(ctx, "UPDATE api_keys SET "+strings.Join(sets, ", ")+" WHERE id = "+idPlaceholder, args...)
1022+
if err != nil {
1023+
return err
1024+
}
1025+
affected, err := res.RowsAffected()
1026+
if err != nil {
1027+
return err
1028+
}
1029+
if affected == 0 {
1030+
return sql.ErrNoRows
1031+
}
1032+
return nil
1033+
}
1034+
9261035
// ==================== System Settings ====================
9271036

9281037
const DefaultSiteName = "CodexProxy"

0 commit comments

Comments
 (0)