diff --git a/backend/internal/repository/usage_log_query.go b/backend/internal/repository/usage_log_query.go new file mode 100644 index 00000000000..f14c24bb0f5 --- /dev/null +++ b/backend/internal/repository/usage_log_query.go @@ -0,0 +1,547 @@ +package repository + +import ( + "context" + "database/sql" + "fmt" + "time" + + dbaccount "github.com/Wei-Shaw/sub2api/ent/account" + dbapikey "github.com/Wei-Shaw/sub2api/ent/apikey" + dbgroup "github.com/Wei-Shaw/sub2api/ent/group" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" + dbusersub "github.com/Wei-Shaw/sub2api/ent/usersubscription" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) { + query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1" + rows, err := r.sql.QueryContext(ctx, query, id) + if err != nil { + return nil, err + } + defer func() { + // 保持主错误优先;仅在无错误时回传 Close 失败。 + // 同时清空返回值,避免误用不完整结果。 + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + log = nil + } + }() + if !rows.Next() { + if err = rows.Err(); err != nil { + return nil, err + } + return nil, service.ErrUsageLogNotFound + } + log, err = scanUsageLog(rows) + if err != nil { + return nil, err + } + if err = rows.Err(); err != nil { + return nil, err + } + return log, nil +} + +func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { + return r.listUsageLogsWithPagination(ctx, "WHERE user_id = $1", []any{userID}, params) +} + +func (r *usageLogRepository) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { + return r.listUsageLogsWithPagination(ctx, "WHERE api_key_id = $1", []any{apiKeyID}, params) +} + +func (r 
*usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { + return r.listUsageLogsWithPagination(ctx, "WHERE account_id = $1", []any{accountID}, params) +} + +func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { + query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000" + logs, err := r.queryUsageLogs(ctx, query, userID, startTime, endTime) + return logs, nil, err +} + +func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { + query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000" + logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime) + return logs, nil, err +} + +func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { + query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000" + logs, err := r.queryUsageLogs(ctx, query, accountID, startTime, endTime) + return logs, nil, err +} + +func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { + query := fmt.Sprintf("SELECT %s FROM usage_logs WHERE %s = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000", usageLogSelectColumns, rawUsageLogModelColumn) + logs, err := r.queryUsageLogs(ctx, 
query, modelName, startTime, endTime) + return logs, nil, err +} + +// ListWithFilters lists usage logs with optional filters (for admin) +func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) { + conditions := make([]string, 0, 9) + args := make([]any, 0, 9) + + if filters.UserID > 0 { + conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1)) + args = append(args, filters.UserID) + } + if filters.APIKeyID > 0 { + conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", len(args)+1)) + args = append(args, filters.APIKeyID) + } + if filters.AccountID > 0 { + conditions = append(conditions, fmt.Sprintf("account_id = $%d", len(args)+1)) + args = append(args, filters.AccountID) + } + if filters.GroupID > 0 { + conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1)) + args = append(args, filters.GroupID) + } + conditions, args = appendRawUsageLogModelWhereCondition(conditions, args, filters.Model) + conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream) + if filters.BillingType != nil { + conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) + args = append(args, int16(*filters.BillingType)) + } + if filters.BillingMode != "" { + conditions = append(conditions, fmt.Sprintf("billing_mode = $%d", len(args)+1)) + args = append(args, filters.BillingMode) + } + if filters.StartTime != nil { + conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1)) + args = append(args, *filters.StartTime) + } + if filters.EndTime != nil { + conditions = append(conditions, fmt.Sprintf("created_at < $%d", len(args)+1)) + args = append(args, *filters.EndTime) + } + + whereClause := buildWhere(conditions) + var ( + logs []service.UsageLog + page *pagination.PaginationResult + err error + ) + if 
shouldUseFastUsageLogTotal(filters) { + logs, page, err = r.listUsageLogsWithFastPagination(ctx, whereClause, args, params) + } else { + logs, page, err = r.listUsageLogsWithPagination(ctx, whereClause, args, params) + } + if err != nil { + return nil, nil, err + } + + if err := r.hydrateUsageLogAssociations(ctx, logs); err != nil { + return nil, nil, err + } + return logs, page, nil +} + +func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { + countQuery := "SELECT COUNT(*) FROM usage_logs " + whereClause + var total int64 + if err := scanSingleRow(ctx, r.sql, countQuery, args, &total); err != nil { + return nil, nil, err + } + + limitPos := len(args) + 1 + offsetPos := len(args) + 2 + listArgs := append(append([]any{}, args...), params.Limit(), params.Offset()) + query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY %s LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, usageLogOrderBy(params), limitPos, offsetPos) + logs, err := r.queryUsageLogs(ctx, query, listArgs...) + if err != nil { + return nil, nil, err + } + return logs, paginationResultFromTotal(total, params), nil +} + +func (r *usageLogRepository) listUsageLogsWithFastPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { + limit := params.Limit() + offset := params.Offset() + + limitPos := len(args) + 1 + offsetPos := len(args) + 2 + listArgs := append(append([]any{}, args...), limit+1, offset) + query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY %s LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, usageLogOrderBy(params), limitPos, offsetPos) + + logs, err := r.queryUsageLogs(ctx, query, listArgs...) 
+ if err != nil { + return nil, nil, err + } + + hasMore := false + if len(logs) > limit { + hasMore = true + logs = logs[:limit] + } + + total := int64(offset) + int64(len(logs)) + if hasMore { + // 只保证“还有下一页”,避免对超大表做全量 COUNT(*)。 + total = int64(offset) + int64(limit) + 1 + } + + return logs, paginationResultFromTotal(total, params), nil +} + +func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) { + rows, err := r.sql.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer func() { + // 保持主错误优先;仅在无错误时回传 Close 失败。 + // 同时清空返回值,避免误用不完整结果。 + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + logs = nil + } + }() + + logs = make([]service.UsageLog, 0) + for rows.Next() { + var log *service.UsageLog + log, err = scanUsageLog(rows) + if err != nil { + return nil, err + } + logs = append(logs, *log) + } + if err = rows.Err(); err != nil { + return nil, err + } + return logs, nil +} + +func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, logs []service.UsageLog) error { + // 关联数据使用 Ent 批量加载,避免把复杂 SQL 继续膨胀。 + if len(logs) == 0 { + return nil + } + + ids := collectUsageLogIDs(logs) + users, err := r.loadUsers(ctx, ids.userIDs) + if err != nil { + return err + } + apiKeys, err := r.loadAPIKeys(ctx, ids.apiKeyIDs) + if err != nil { + return err + } + accounts, err := r.loadAccounts(ctx, ids.accountIDs) + if err != nil { + return err + } + groups, err := r.loadGroups(ctx, ids.groupIDs) + if err != nil { + return err + } + subs, err := r.loadSubscriptions(ctx, ids.subscriptionIDs) + if err != nil { + return err + } + + for i := range logs { + if user, ok := users[logs[i].UserID]; ok { + logs[i].User = user + } + if key, ok := apiKeys[logs[i].APIKeyID]; ok { + logs[i].APIKey = key + } + if acc, ok := accounts[logs[i].AccountID]; ok { + logs[i].Account = acc + } + if logs[i].GroupID != nil { + if group, ok := 
groups[*logs[i].GroupID]; ok { + logs[i].Group = group + } + } + if logs[i].SubscriptionID != nil { + if sub, ok := subs[*logs[i].SubscriptionID]; ok { + logs[i].Subscription = sub + } + } + } + return nil +} + +func (r *usageLogRepository) loadUsers(ctx context.Context, ids []int64) (map[int64]*service.User, error) { + out := make(map[int64]*service.User) + if len(ids) == 0 { + return out, nil + } + models, err := r.client.User.Query().Where(dbuser.IDIn(ids...)).All(ctx) + if err != nil { + return nil, err + } + for _, m := range models { + out[m.ID] = userEntityToService(m) + } + return out, nil +} + +func (r *usageLogRepository) loadAPIKeys(ctx context.Context, ids []int64) (map[int64]*service.APIKey, error) { + out := make(map[int64]*service.APIKey) + if len(ids) == 0 { + return out, nil + } + models, err := r.client.APIKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx) + if err != nil { + return nil, err + } + for _, m := range models { + out[m.ID] = apiKeyEntityToService(m) + } + return out, nil +} + +func (r *usageLogRepository) loadAccounts(ctx context.Context, ids []int64) (map[int64]*service.Account, error) { + out := make(map[int64]*service.Account) + if len(ids) == 0 { + return out, nil + } + models, err := r.client.Account.Query().Where(dbaccount.IDIn(ids...)).All(ctx) + if err != nil { + return nil, err + } + for _, m := range models { + out[m.ID] = accountEntityToService(m) + } + return out, nil +} + +func (r *usageLogRepository) loadGroups(ctx context.Context, ids []int64) (map[int64]*service.Group, error) { + out := make(map[int64]*service.Group) + if len(ids) == 0 { + return out, nil + } + models, err := r.client.Group.Query().Where(dbgroup.IDIn(ids...)).All(ctx) + if err != nil { + return nil, err + } + for _, m := range models { + out[m.ID] = groupEntityToService(m) + } + return out, nil +} + +func (r *usageLogRepository) loadSubscriptions(ctx context.Context, ids []int64) (map[int64]*service.UserSubscription, error) { + out := 
make(map[int64]*service.UserSubscription) + if len(ids) == 0 { + return out, nil + } + models, err := r.client.UserSubscription.Query().Where(dbusersub.IDIn(ids...)).All(ctx) + if err != nil { + return nil, err + } + for _, m := range models { + out[m.ID] = userSubscriptionEntityToService(m) + } + return out, nil +} + +func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, error) { + var ( + id int64 + userID int64 + apiKeyID int64 + accountID int64 + requestID sql.NullString + model string + requestedModel sql.NullString + upstreamModel sql.NullString + groupID sql.NullInt64 + subscriptionID sql.NullInt64 + inputTokens int + outputTokens int + cacheCreationTokens int + cacheReadTokens int + cacheCreation5m int + cacheCreation1h int + imageOutputTokens int + imageOutputCost float64 + inputCost float64 + outputCost float64 + cacheCreationCost float64 + cacheReadCost float64 + totalCost float64 + actualCost float64 + rateMultiplier float64 + accountRateMultiplier sql.NullFloat64 + billingType int16 + requestTypeRaw int16 + stream bool + openaiWSMode bool + durationMs sql.NullInt64 + firstTokenMs sql.NullInt64 + userAgent sql.NullString + ipAddress sql.NullString + imageCount int + imageSize sql.NullString + serviceTier sql.NullString + reasoningEffort sql.NullString + inboundEndpoint sql.NullString + upstreamEndpoint sql.NullString + cacheTTLOverridden bool + channelID sql.NullInt64 + modelMappingChain sql.NullString + billingTier sql.NullString + billingMode sql.NullString + accountStatsCost sql.NullFloat64 + createdAt time.Time + ) + + if err := scanner.Scan( + &id, + &userID, + &apiKeyID, + &accountID, + &requestID, + &model, + &requestedModel, + &upstreamModel, + &groupID, + &subscriptionID, + &inputTokens, + &outputTokens, + &cacheCreationTokens, + &cacheReadTokens, + &cacheCreation5m, + &cacheCreation1h, + &imageOutputTokens, + &imageOutputCost, + &inputCost, + &outputCost, + &cacheCreationCost, + &cacheReadCost, + &totalCost, + 
&actualCost, + &rateMultiplier, + &accountRateMultiplier, + &billingType, + &requestTypeRaw, + &stream, + &openaiWSMode, + &durationMs, + &firstTokenMs, + &userAgent, + &ipAddress, + &imageCount, + &imageSize, + &serviceTier, + &reasoningEffort, + &inboundEndpoint, + &upstreamEndpoint, + &cacheTTLOverridden, + &channelID, + &modelMappingChain, + &billingTier, + &billingMode, + &accountStatsCost, + &createdAt, + ); err != nil { + return nil, err + } + + log := &service.UsageLog{ + ID: id, + UserID: userID, + APIKeyID: apiKeyID, + AccountID: accountID, + Model: model, + RequestedModel: coalesceTrimmedString(requestedModel, model), + InputTokens: inputTokens, + OutputTokens: outputTokens, + CacheCreationTokens: cacheCreationTokens, + CacheReadTokens: cacheReadTokens, + CacheCreation5mTokens: cacheCreation5m, + CacheCreation1hTokens: cacheCreation1h, + ImageOutputTokens: imageOutputTokens, + ImageOutputCost: imageOutputCost, + InputCost: inputCost, + OutputCost: outputCost, + CacheCreationCost: cacheCreationCost, + CacheReadCost: cacheReadCost, + TotalCost: totalCost, + ActualCost: actualCost, + RateMultiplier: rateMultiplier, + AccountRateMultiplier: nullFloat64Ptr(accountRateMultiplier), + BillingType: int8(billingType), + RequestType: service.RequestTypeFromInt16(requestTypeRaw), + ImageCount: imageCount, + CacheTTLOverridden: cacheTTLOverridden, + CreatedAt: createdAt, + } + // 先回填 legacy 字段,再基于 legacy + request_type 计算最终请求类型,保证历史数据兼容。 + log.Stream = stream + log.OpenAIWSMode = openaiWSMode + log.RequestType = log.EffectiveRequestType() + log.Stream, log.OpenAIWSMode = service.ApplyLegacyRequestFields(log.RequestType, stream, openaiWSMode) + + if requestID.Valid { + log.RequestID = requestID.String + } + if groupID.Valid { + value := groupID.Int64 + log.GroupID = &value + } + if subscriptionID.Valid { + value := subscriptionID.Int64 + log.SubscriptionID = &value + } + if durationMs.Valid { + value := int(durationMs.Int64) + log.DurationMs = &value + } + if 
firstTokenMs.Valid { + value := int(firstTokenMs.Int64) + log.FirstTokenMs = &value + } + if userAgent.Valid { + log.UserAgent = &userAgent.String + } + if ipAddress.Valid { + log.IPAddress = &ipAddress.String + } + if imageSize.Valid { + log.ImageSize = &imageSize.String + } + if serviceTier.Valid { + log.ServiceTier = &serviceTier.String + } + if reasoningEffort.Valid { + log.ReasoningEffort = &reasoningEffort.String + } + if inboundEndpoint.Valid { + log.InboundEndpoint = &inboundEndpoint.String + } + if upstreamEndpoint.Valid { + log.UpstreamEndpoint = &upstreamEndpoint.String + } + if upstreamModel.Valid { + log.UpstreamModel = &upstreamModel.String + } + if channelID.Valid { + value := channelID.Int64 + log.ChannelID = &value + } + if modelMappingChain.Valid { + log.ModelMappingChain = &modelMappingChain.String + } + if billingTier.Valid { + log.BillingTier = &billingTier.String + } + if billingMode.Valid { + log.BillingMode = &billingMode.String + } + if accountStatsCost.Valid { + log.AccountStatsCost = &accountStatsCost.Float64 + } + + return log, nil +} diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index f2fb87da33e..9b34e20bf76 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -3,7 +3,6 @@ package repository import ( "context" "database/sql" - "encoding/json" "errors" "fmt" "os" @@ -14,11 +13,6 @@ import ( "time" dbent "github.com/Wei-Shaw/sub2api/ent" - dbaccount "github.com/Wei-Shaw/sub2api/ent/account" - dbapikey "github.com/Wei-Shaw/sub2api/ent/apikey" - dbgroup "github.com/Wei-Shaw/sub2api/ent/group" - dbuser "github.com/Wei-Shaw/sub2api/ent/user" - dbusersub "github.com/Wei-Shaw/sub2api/ent/usersubscription" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" @@ -156,24 +150,6 @@ const ( usageLogBestEffortRecentTTL = 30 * 
time.Second ) -type usageLogCreateRequest struct { - log *service.UsageLog - prepared usageLogInsertPrepared - shared *usageLogCreateShared - resultCh chan usageLogCreateResult -} - -type usageLogCreateResult struct { - inserted bool - err error -} - -type usageLogBestEffortRequest struct { - prepared usageLogInsertPrepared - apiKeyID int64 - resultCh chan error -} - type usageLogInsertPrepared struct { createdAt time.Time requestID string @@ -243,464 +219,6 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int return requestCount / 5, tokenCount / 5, nil } -func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) (bool, error) { - if log == nil { - return false, nil - } - - if tx := dbent.TxFromContext(ctx); tx != nil { - return r.createSingle(ctx, tx.Client(), log) - } - requestID := strings.TrimSpace(log.RequestID) - if requestID == "" { - return r.createSingle(ctx, r.sql, log) - } - log.RequestID = requestID - return r.createBatched(ctx, log) -} - -func (r *usageLogRepository) CreateBestEffort(ctx context.Context, log *service.UsageLog) error { - if log == nil { - return nil - } - - if tx := dbent.TxFromContext(ctx); tx != nil { - _, err := r.createSingle(ctx, tx.Client(), log) - return err - } - if r.db == nil { - _, err := r.createSingle(ctx, r.sql, log) - return err - } - - r.ensureBestEffortBatcher() - if r.bestEffortBatchCh == nil { - _, err := r.createSingle(ctx, r.sql, log) - return err - } - - req := usageLogBestEffortRequest{ - prepared: prepareUsageLogInsert(log), - apiKeyID: log.APIKeyID, - resultCh: make(chan error, 1), - } - if key, ok := r.bestEffortRecentKey(req.prepared.requestID, req.apiKeyID); ok { - if _, exists := r.bestEffortRecent.Get(key); exists { - return nil - } - } - - select { - case r.bestEffortBatchCh <- req: - case <-ctx.Done(): - return service.MarkUsageLogCreateDropped(ctx.Err()) - default: - return service.MarkUsageLogCreateDropped(errors.New("usage log best-effort queue 
full")) - } - - select { - case err := <-req.resultCh: - return err - case <-ctx.Done(): - return service.MarkUsageLogCreateDropped(ctx.Err()) - } -} - -func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, log *service.UsageLog) (bool, error) { - prepared := prepareUsageLogInsert(log) - if sqlq == nil { - sqlq = r.sql - } - if ctx != nil && ctx.Err() != nil { - return false, service.MarkUsageLogCreateNotPersisted(ctx.Err()) - } - - query := ` - INSERT INTO usage_logs ( - user_id, - api_key_id, - account_id, - request_id, - model, - requested_model, - upstream_model, - group_id, - subscription_id, - input_tokens, - output_tokens, - cache_creation_tokens, - cache_read_tokens, - cache_creation_5m_tokens, - cache_creation_1h_tokens, - image_output_tokens, - image_output_cost, - input_cost, - output_cost, - cache_creation_cost, - cache_read_cost, - total_cost, - actual_cost, - rate_multiplier, - account_rate_multiplier, - billing_type, - request_type, - stream, - openai_ws_mode, - duration_ms, - first_token_ms, - user_agent, - ip_address, - image_count, - image_size, - service_tier, - reasoning_effort, - inbound_endpoint, - upstream_endpoint, - cache_ttl_overridden, - channel_id, - model_mapping_chain, - billing_tier, - billing_mode, - account_stats_cost, - created_at - ) VALUES ( - $1, $2, $3, $4, $5, $6, $7, - $8, $9, - $10, $11, $12, $13, - $14, $15, $16, $17, - $18, $19, $20, $21, $22, $23, - $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46 - ) - ON CONFLICT (request_id, api_key_id) DO NOTHING - RETURNING id, created_at - ` - - if err := scanSingleRow(ctx, sqlq, query, prepared.args, &log.ID, &log.CreatedAt); err != nil { - if errors.Is(err, sql.ErrNoRows) && prepared.requestID != "" { - selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2" - if err := scanSingleRow(ctx, sqlq, selectQuery, []any{prepared.requestID, log.APIKeyID}, 
&log.ID, &log.CreatedAt); err != nil { - return false, err - } - log.RateMultiplier = prepared.rateMultiplier - return false, nil - } else { - return false, err - } - } - log.RateMultiplier = prepared.rateMultiplier - return true, nil -} - -func (r *usageLogRepository) createBatched(ctx context.Context, log *service.UsageLog) (bool, error) { - if r.db == nil { - return r.createSingle(ctx, r.sql, log) - } - r.ensureCreateBatcher() - if r.createBatchCh == nil { - return r.createSingle(ctx, r.sql, log) - } - - req := usageLogCreateRequest{ - log: log, - prepared: prepareUsageLogInsert(log), - shared: &usageLogCreateShared{}, - resultCh: make(chan usageLogCreateResult, 1), - } - - select { - case r.createBatchCh <- req: - case <-ctx.Done(): - return false, service.MarkUsageLogCreateNotPersisted(ctx.Err()) - default: - return false, service.MarkUsageLogCreateNotPersisted(errors.New("usage log create batch queue full")) - } - - select { - case res := <-req.resultCh: - return res.inserted, res.err - case <-ctx.Done(): - if req.shared != nil && req.shared.state.CompareAndSwap(usageLogCreateStateQueued, usageLogCreateStateCanceled) { - return false, service.MarkUsageLogCreateNotPersisted(ctx.Err()) - } - timer := time.NewTimer(usageLogCreateCancelWait) - defer timer.Stop() - select { - case res := <-req.resultCh: - return res.inserted, res.err - case <-timer.C: - return false, ctx.Err() - } - } -} - -func (r *usageLogRepository) ensureCreateBatcher() { - if r == nil || r.db == nil || r.createBatchCh != nil { - return - } - r.createBatchOnce.Do(func() { - r.createBatchCh = make(chan usageLogCreateRequest, usageLogCreateBatchQueueCap) - go r.runCreateBatcher(r.db) - }) -} - -func (r *usageLogRepository) ensureBestEffortBatcher() { - if r == nil || r.db == nil || r.bestEffortBatchCh != nil { - return - } - r.bestEffortBatchOnce.Do(func() { - r.bestEffortBatchCh = make(chan usageLogBestEffortRequest, usageLogBestEffortBatchQueueCap) - go r.runBestEffortBatcher(r.db) - }) -} - 
-func (r *usageLogRepository) runCreateBatcher(db *sql.DB) { - for { - first, ok := <-r.createBatchCh - if !ok { - return - } - - batch := make([]usageLogCreateRequest, 0, usageLogCreateBatchMaxSize) - batch = append(batch, first) - - timer := time.NewTimer(usageLogCreateBatchWindow) - batchLoop: - for len(batch) < usageLogCreateBatchMaxSize { - select { - case req, ok := <-r.createBatchCh: - if !ok { - break batchLoop - } - batch = append(batch, req) - case <-timer.C: - break batchLoop - } - } - if !timer.Stop() { - select { - case <-timer.C: - default: - } - } - - r.flushCreateBatch(db, batch) - } -} - -func (r *usageLogRepository) runBestEffortBatcher(db *sql.DB) { - for { - first, ok := <-r.bestEffortBatchCh - if !ok { - return - } - - batch := make([]usageLogBestEffortRequest, 0, usageLogBestEffortBatchMaxSize) - batch = append(batch, first) - - timer := time.NewTimer(usageLogBestEffortBatchWindow) - bestEffortLoop: - for len(batch) < usageLogBestEffortBatchMaxSize { - select { - case req, ok := <-r.bestEffortBatchCh: - if !ok { - break bestEffortLoop - } - batch = append(batch, req) - case <-timer.C: - break bestEffortLoop - } - } - if !timer.Stop() { - select { - case <-timer.C: - default: - } - } - - r.flushBestEffortBatch(db, batch) - } -} - -func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreateRequest) { - if len(batch) == 0 { - return - } - - uniqueOrder := make([]string, 0, len(batch)) - preparedByKey := make(map[string]usageLogInsertPrepared, len(batch)) - requestsByKey := make(map[string][]usageLogCreateRequest, len(batch)) - fallback := make([]usageLogCreateRequest, 0) - - for _, req := range batch { - if req.log == nil { - completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil}) - continue - } - if req.shared != nil && !req.shared.state.CompareAndSwap(usageLogCreateStateQueued, usageLogCreateStateProcessing) { - if req.shared.state.Load() == usageLogCreateStateCanceled { - 
completeUsageLogCreateRequest(req, usageLogCreateResult{ - inserted: false, - err: service.MarkUsageLogCreateNotPersisted(context.Canceled), - }) - continue - } - } - prepared := req.prepared - if prepared.requestID == "" { - fallback = append(fallback, req) - continue - } - key := usageLogBatchKey(prepared.requestID, req.log.APIKeyID) - if _, exists := requestsByKey[key]; !exists { - uniqueOrder = append(uniqueOrder, key) - preparedByKey[key] = prepared - } - requestsByKey[key] = append(requestsByKey[key], req) - } - - if len(uniqueOrder) > 0 { - insertedMap, stateMap, safeFallback, err := r.batchInsertUsageLogs(db, uniqueOrder, preparedByKey) - if err != nil { - if safeFallback { - for _, key := range uniqueOrder { - fallback = append(fallback, requestsByKey[key]...) - } - } else { - for _, key := range uniqueOrder { - reqs := requestsByKey[key] - state, hasState := stateMap[key] - inserted := insertedMap[key] - for idx, req := range reqs { - req.log.RateMultiplier = preparedByKey[key].rateMultiplier - if hasState { - req.log.ID = state.ID - req.log.CreatedAt = state.CreatedAt - } - switch { - case inserted && idx == 0: - completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: true, err: nil}) - case inserted: - completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil}) - case hasState: - completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil}) - case idx == 0: - completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: err}) - default: - completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil}) - } - } - } - } - } else { - for _, key := range uniqueOrder { - reqs := requestsByKey[key] - state, ok := stateMap[key] - if !ok { - for _, req := range reqs { - completeUsageLogCreateRequest(req, usageLogCreateResult{ - inserted: false, - err: fmt.Errorf("usage log batch state missing for key=%s", key), - }) - } - continue - } - for idx, req := range reqs { - 
req.log.ID = state.ID - req.log.CreatedAt = state.CreatedAt - req.log.RateMultiplier = preparedByKey[key].rateMultiplier - completeUsageLogCreateRequest(req, usageLogCreateResult{ - inserted: idx == 0 && insertedMap[key], - err: nil, - }) - } - } - } - } - - if len(fallback) == 0 { - return - } - - for _, req := range fallback { - fallbackCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - inserted, err := r.createSingle(fallbackCtx, db, req.log) - cancel() - completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: inserted, err: err}) - } -} - -func (r *usageLogRepository) flushBestEffortBatch(db *sql.DB, batch []usageLogBestEffortRequest) { - if len(batch) == 0 { - return - } - - type bestEffortGroup struct { - prepared usageLogInsertPrepared - apiKeyID int64 - key string - reqs []usageLogBestEffortRequest - } - - groupsByKey := make(map[string]*bestEffortGroup, len(batch)) - groupOrder := make([]*bestEffortGroup, 0, len(batch)) - preparedList := make([]usageLogInsertPrepared, 0, len(batch)) - - for idx, req := range batch { - prepared := req.prepared - key := fmt.Sprintf("__best_effort_%d", idx) - if prepared.requestID != "" { - key = usageLogBatchKey(prepared.requestID, req.apiKeyID) - } - group, exists := groupsByKey[key] - if !exists { - group = &bestEffortGroup{ - prepared: prepared, - apiKeyID: req.apiKeyID, - key: key, - } - groupsByKey[key] = group - groupOrder = append(groupOrder, group) - preparedList = append(preparedList, prepared) - } - group.reqs = append(group.reqs, req) - } - - if len(preparedList) == 0 { - for _, req := range batch { - sendUsageLogBestEffortResult(req.resultCh, nil) - } - return - } - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - query, args := buildUsageLogBestEffortInsertQuery(preparedList) - if _, err := db.ExecContext(ctx, query, args...); err != nil { - logger.LegacyPrintf("repository.usage_log", "best-effort batch insert failed: %v", err) - 
for _, group := range groupOrder { - singleErr := execUsageLogInsertNoResult(ctx, db, group.prepared) - if singleErr != nil { - logger.LegacyPrintf("repository.usage_log", "best-effort single fallback insert failed: %v", singleErr) - } else if group.prepared.requestID != "" && r != nil && r.bestEffortRecent != nil { - r.bestEffortRecent.SetDefault(group.key, struct{}{}) - } - for _, req := range group.reqs { - sendUsageLogBestEffortResult(req.resultCh, singleErr) - } - } - return - } - for _, group := range groupOrder { - if group.prepared.requestID != "" && r != nil && r.bestEffortRecent != nil { - r.bestEffortRecent.SetDefault(group.key, struct{}{}) - } - for _, req := range group.reqs { - sendUsageLogBestEffortResult(req.resultCh, nil) - } - } -} - func sendUsageLogBestEffortResult(ch chan error, err error) { if ch == nil { return @@ -718,588 +236,6 @@ func completeUsageLogCreateRequest(req usageLogCreateRequest, res usageLogCreate sendUsageLogCreateResult(req.resultCh, res) } -func (r *usageLogRepository) batchInsertUsageLogs(db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]bool, map[string]usageLogBatchState, bool, error) { - if len(keys) == 0 { - return map[string]bool{}, map[string]usageLogBatchState{}, false, nil - } - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - query, args := buildUsageLogBatchInsertQuery(keys, preparedByKey) - var payload []byte - if err := db.QueryRowContext(ctx, query, args...).Scan(&payload); err != nil { - return nil, nil, true, err - } - var rows []usageLogBatchRow - if err := json.Unmarshal(payload, &rows); err != nil { - return nil, nil, false, err - } - insertedMap := make(map[string]bool, len(keys)) - stateMap := make(map[string]usageLogBatchState, len(keys)) - for _, row := range rows { - key := usageLogBatchKey(row.RequestID, row.APIKeyID) - insertedMap[key] = row.Inserted - stateMap[key] = usageLogBatchState{ - ID: row.ID, - CreatedAt: 
row.CreatedAt, - } - } - if len(stateMap) != len(keys) { - return insertedMap, stateMap, false, fmt.Errorf("usage log batch state count mismatch: got=%d want=%d", len(stateMap), len(keys)) - } - return insertedMap, stateMap, false, nil -} - -func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usageLogInsertPrepared) (string, []any) { - var query strings.Builder - _, _ = query.WriteString(` - WITH input ( - input_idx, - user_id, - api_key_id, - account_id, - request_id, - model, - requested_model, - upstream_model, - group_id, - subscription_id, - input_tokens, - output_tokens, - cache_creation_tokens, - cache_read_tokens, - cache_creation_5m_tokens, - cache_creation_1h_tokens, - image_output_tokens, - image_output_cost, - input_cost, - output_cost, - cache_creation_cost, - cache_read_cost, - total_cost, - actual_cost, - rate_multiplier, - account_rate_multiplier, - billing_type, - request_type, - stream, - openai_ws_mode, - duration_ms, - first_token_ms, - user_agent, - ip_address, - image_count, - image_size, - service_tier, - reasoning_effort, - inbound_endpoint, - upstream_endpoint, - cache_ttl_overridden, - channel_id, - model_mapping_chain, - billing_tier, - billing_mode, - account_stats_cost, - created_at - ) AS (VALUES `) - - args := make([]any, 0, len(keys)*46) - argPos := 1 - for idx, key := range keys { - if idx > 0 { - _, _ = query.WriteString(",") - } - _, _ = query.WriteString("(") - _, _ = query.WriteString("$") - _, _ = query.WriteString(strconv.Itoa(argPos)) - args = append(args, idx) - argPos++ - prepared := preparedByKey[key] - for i := 0; i < len(prepared.args); i++ { - _, _ = query.WriteString(",") - _, _ = query.WriteString("$") - _, _ = query.WriteString(strconv.Itoa(argPos)) - if i < len(usageLogInsertArgTypes) { - _, _ = query.WriteString("::") - _, _ = query.WriteString(usageLogInsertArgTypes[i]) - } - argPos++ - } - _, _ = query.WriteString(")") - args = append(args, prepared.args...) 
- } - _, _ = query.WriteString(` - ), - inserted AS ( - INSERT INTO usage_logs ( - user_id, - api_key_id, - account_id, - request_id, - model, - requested_model, - upstream_model, - group_id, - subscription_id, - input_tokens, - output_tokens, - cache_creation_tokens, - cache_read_tokens, - cache_creation_5m_tokens, - cache_creation_1h_tokens, - image_output_tokens, - image_output_cost, - input_cost, - output_cost, - cache_creation_cost, - cache_read_cost, - total_cost, - actual_cost, - rate_multiplier, - account_rate_multiplier, - billing_type, - request_type, - stream, - openai_ws_mode, - duration_ms, - first_token_ms, - user_agent, - ip_address, - image_count, - image_size, - service_tier, - reasoning_effort, - inbound_endpoint, - upstream_endpoint, - cache_ttl_overridden, - channel_id, - model_mapping_chain, - billing_tier, - billing_mode, - account_stats_cost, - created_at - ) - SELECT - user_id, - api_key_id, - account_id, - request_id, - model, - requested_model, - upstream_model, - group_id, - subscription_id, - input_tokens, - output_tokens, - cache_creation_tokens, - cache_read_tokens, - cache_creation_5m_tokens, - cache_creation_1h_tokens, - image_output_tokens, - image_output_cost, - input_cost, - output_cost, - cache_creation_cost, - cache_read_cost, - total_cost, - actual_cost, - rate_multiplier, - account_rate_multiplier, - billing_type, - request_type, - stream, - openai_ws_mode, - duration_ms, - first_token_ms, - user_agent, - ip_address, - image_count, - image_size, - service_tier, - reasoning_effort, - inbound_endpoint, - upstream_endpoint, - cache_ttl_overridden, - channel_id, - model_mapping_chain, - billing_tier, - billing_mode, - account_stats_cost, - created_at - FROM input - ON CONFLICT (request_id, api_key_id) DO NOTHING - RETURNING request_id, api_key_id, id, created_at - ), - resolved AS ( - SELECT - input.input_idx, - input.request_id, - input.api_key_id, - COALESCE(inserted.id, existing.id) AS id, - COALESCE(inserted.created_at, 
existing.created_at) AS created_at, - (inserted.id IS NOT NULL) AS inserted - FROM input - LEFT JOIN inserted - ON inserted.request_id = input.request_id - AND inserted.api_key_id = input.api_key_id - LEFT JOIN usage_logs existing - ON existing.request_id = input.request_id - AND existing.api_key_id = input.api_key_id - ) - SELECT COALESCE( - json_agg( - json_build_object( - 'request_id', resolved.request_id, - 'api_key_id', resolved.api_key_id, - 'id', resolved.id, - 'created_at', resolved.created_at, - 'inserted', resolved.inserted - ) - ORDER BY resolved.input_idx - ), - '[]'::json - ) - FROM resolved - `) - return query.String(), args -} - -func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (string, []any) { - var query strings.Builder - _, _ = query.WriteString(` - WITH input ( - user_id, - api_key_id, - account_id, - request_id, - model, - requested_model, - upstream_model, - group_id, - subscription_id, - input_tokens, - output_tokens, - cache_creation_tokens, - cache_read_tokens, - cache_creation_5m_tokens, - cache_creation_1h_tokens, - image_output_tokens, - image_output_cost, - input_cost, - output_cost, - cache_creation_cost, - cache_read_cost, - total_cost, - actual_cost, - rate_multiplier, - account_rate_multiplier, - billing_type, - request_type, - stream, - openai_ws_mode, - duration_ms, - first_token_ms, - user_agent, - ip_address, - image_count, - image_size, - service_tier, - reasoning_effort, - inbound_endpoint, - upstream_endpoint, - cache_ttl_overridden, - channel_id, - model_mapping_chain, - billing_tier, - billing_mode, - account_stats_cost, - created_at - ) AS (VALUES `) - - args := make([]any, 0, len(preparedList)*46) - argPos := 1 - for idx, prepared := range preparedList { - if idx > 0 { - _, _ = query.WriteString(",") - } - _, _ = query.WriteString("(") - for i := 0; i < len(prepared.args); i++ { - if i > 0 { - _, _ = query.WriteString(",") - } - _, _ = query.WriteString("$") - _, _ = 
query.WriteString(strconv.Itoa(argPos)) - if i < len(usageLogInsertArgTypes) { - _, _ = query.WriteString("::") - _, _ = query.WriteString(usageLogInsertArgTypes[i]) - } - argPos++ - } - _, _ = query.WriteString(")") - args = append(args, prepared.args...) - } - - _, _ = query.WriteString(` - ) - INSERT INTO usage_logs ( - user_id, - api_key_id, - account_id, - request_id, - model, - requested_model, - upstream_model, - group_id, - subscription_id, - input_tokens, - output_tokens, - cache_creation_tokens, - cache_read_tokens, - cache_creation_5m_tokens, - cache_creation_1h_tokens, - image_output_tokens, - image_output_cost, - input_cost, - output_cost, - cache_creation_cost, - cache_read_cost, - total_cost, - actual_cost, - rate_multiplier, - account_rate_multiplier, - billing_type, - request_type, - stream, - openai_ws_mode, - duration_ms, - first_token_ms, - user_agent, - ip_address, - image_count, - image_size, - service_tier, - reasoning_effort, - inbound_endpoint, - upstream_endpoint, - cache_ttl_overridden, - channel_id, - model_mapping_chain, - billing_tier, - billing_mode, - account_stats_cost, - created_at - ) - SELECT - user_id, - api_key_id, - account_id, - request_id, - model, - requested_model, - upstream_model, - group_id, - subscription_id, - input_tokens, - output_tokens, - cache_creation_tokens, - cache_read_tokens, - cache_creation_5m_tokens, - cache_creation_1h_tokens, - image_output_tokens, - image_output_cost, - input_cost, - output_cost, - cache_creation_cost, - cache_read_cost, - total_cost, - actual_cost, - rate_multiplier, - account_rate_multiplier, - billing_type, - request_type, - stream, - openai_ws_mode, - duration_ms, - first_token_ms, - user_agent, - ip_address, - image_count, - image_size, - service_tier, - reasoning_effort, - inbound_endpoint, - upstream_endpoint, - cache_ttl_overridden, - channel_id, - model_mapping_chain, - billing_tier, - billing_mode, - account_stats_cost, - created_at - FROM input - ON CONFLICT (request_id, 
api_key_id) DO NOTHING - `) - - return query.String(), args -} - -func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared usageLogInsertPrepared) error { - _, err := sqlq.ExecContext(ctx, ` - INSERT INTO usage_logs ( - user_id, - api_key_id, - account_id, - request_id, - model, - requested_model, - upstream_model, - group_id, - subscription_id, - input_tokens, - output_tokens, - cache_creation_tokens, - cache_read_tokens, - cache_creation_5m_tokens, - cache_creation_1h_tokens, - image_output_tokens, - image_output_cost, - input_cost, - output_cost, - cache_creation_cost, - cache_read_cost, - total_cost, - actual_cost, - rate_multiplier, - account_rate_multiplier, - billing_type, - request_type, - stream, - openai_ws_mode, - duration_ms, - first_token_ms, - user_agent, - ip_address, - image_count, - image_size, - service_tier, - reasoning_effort, - inbound_endpoint, - upstream_endpoint, - cache_ttl_overridden, - channel_id, - model_mapping_chain, - billing_tier, - billing_mode, - account_stats_cost, - created_at - ) VALUES ( - $1, $2, $3, $4, $5, $6, $7, - $8, $9, - $10, $11, $12, $13, - $14, $15, $16, $17, - $18, $19, $20, $21, $22, $23, - $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46 - ) - ON CONFLICT (request_id, api_key_id) DO NOTHING - `, prepared.args...) 
- return err -} - -func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { - createdAt := log.CreatedAt - if createdAt.IsZero() { - createdAt = time.Now() - } - - requestID := strings.TrimSpace(log.RequestID) - log.RequestID = requestID - - rateMultiplier := log.RateMultiplier - log.SyncRequestTypeAndLegacyFields() - requestType := int16(log.RequestType) - - groupID := nullInt64(log.GroupID) - subscriptionID := nullInt64(log.SubscriptionID) - duration := nullInt(log.DurationMs) - firstToken := nullInt(log.FirstTokenMs) - userAgent := nullString(log.UserAgent) - ipAddress := nullString(log.IPAddress) - imageSize := nullString(log.ImageSize) - serviceTier := nullString(log.ServiceTier) - reasoningEffort := nullString(log.ReasoningEffort) - inboundEndpoint := nullString(log.InboundEndpoint) - upstreamEndpoint := nullString(log.UpstreamEndpoint) - channelID := nullInt64(log.ChannelID) - modelMappingChain := nullString(log.ModelMappingChain) - billingTier := nullString(log.BillingTier) - billingMode := nullString(log.BillingMode) - requestedModel := strings.TrimSpace(log.RequestedModel) - if requestedModel == "" { - requestedModel = strings.TrimSpace(log.Model) - } - upstreamModel := nullString(log.UpstreamModel) - - var requestIDArg any - if requestID != "" { - requestIDArg = requestID - } - - return usageLogInsertPrepared{ - createdAt: createdAt, - requestID: requestID, - rateMultiplier: rateMultiplier, - requestType: requestType, - args: []any{ - log.UserID, - log.APIKeyID, - log.AccountID, - requestIDArg, - log.Model, - nullString(&requestedModel), - upstreamModel, - groupID, - subscriptionID, - log.InputTokens, - log.OutputTokens, - log.CacheCreationTokens, - log.CacheReadTokens, - log.CacheCreation5mTokens, - log.CacheCreation1hTokens, - log.ImageOutputTokens, - log.ImageOutputCost, - log.InputCost, - log.OutputCost, - log.CacheCreationCost, - log.CacheReadCost, - log.TotalCost, - log.ActualCost, - rateMultiplier, - log.AccountRateMultiplier, - 
log.BillingType, - requestType, - log.Stream, - log.OpenAIWSMode, - duration, - firstToken, - userAgent, - ipAddress, - log.ImageCount, - imageSize, - serviceTier, - reasoningEffort, - inboundEndpoint, - upstreamEndpoint, - log.CacheTTLOverridden, - channelID, - modelMappingChain, - billingTier, - billingMode, - log.AccountStatsCost, // account_stats_cost - createdAt, - }, - } -} - func usageLogBatchKey(requestID string, apiKeyID int64) string { return requestID + "\x1f" + strconv.FormatInt(apiKeyID, 10) } @@ -1322,85 +258,6 @@ func (r *usageLogRepository) bestEffortRecentKey(requestID string, apiKeyID int6 return usageLogBatchKey(requestID, apiKeyID), true } -func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) { - query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1" - rows, err := r.sql.QueryContext(ctx, query, id) - if err != nil { - return nil, err - } - defer func() { - // 保持主错误优先;仅在无错误时回传 Close 失败。 - // 同时清空返回值,避免误用不完整结果。 - if closeErr := rows.Close(); closeErr != nil && err == nil { - err = closeErr - log = nil - } - }() - if !rows.Next() { - if err = rows.Err(); err != nil { - return nil, err - } - return nil, service.ErrUsageLogNotFound - } - log, err = scanUsageLog(rows) - if err != nil { - return nil, err - } - if err = rows.Err(); err != nil { - return nil, err - } - return log, nil -} - -func (r *usageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { - return r.listUsageLogsWithPagination(ctx, "WHERE user_id = $1", []any{userID}, params) -} - -func (r *usageLogRepository) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { - return r.listUsageLogsWithPagination(ctx, "WHERE api_key_id = $1", []any{apiKeyID}, params) -} - -// UserStats 用户使用统计 -type UserStats struct { - TotalRequests 
int64 `json:"total_requests"` - TotalTokens int64 `json:"total_tokens"` - TotalCost float64 `json:"total_cost"` - InputTokens int64 `json:"input_tokens"` - OutputTokens int64 `json:"output_tokens"` - CacheReadTokens int64 `json:"cache_read_tokens"` -} - -func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, startTime, endTime time.Time) (*UserStats, error) { - query := ` - SELECT - COUNT(*) as total_requests, - COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens, - COALESCE(SUM(actual_cost), 0) as total_cost, - COALESCE(SUM(input_tokens), 0) as input_tokens, - COALESCE(SUM(output_tokens), 0) as output_tokens, - COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens - FROM usage_logs - WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 - ` - - stats := &UserStats{} - if err := scanSingleRow( - ctx, - r.sql, - query, - []any{userID, startTime, endTime}, - &stats.TotalRequests, - &stats.TotalTokens, - &stats.TotalCost, - &stats.InputTokens, - &stats.OutputTokens, - &stats.CacheReadTokens, - ); err != nil { - return nil, err - } - return stats, nil -} - // DashboardStats 仪表盘统计 type DashboardStats = usagestats.DashboardStats @@ -1697,282 +554,27 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co return nil } -func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { - return r.listUsageLogsWithPagination(ctx, "WHERE account_id = $1", []any{accountID}, params) +// resolveUsageStatsTimezone 获取用于 SQL 分组的时区名称。 +// 优先使用应用初始化的时区,其次尝试读取 TZ 环境变量,最后回落为 UTC。 +func resolveUsageStatsTimezone() string { + tzName := timezone.Name() + if tzName != "" && tzName != "Local" { + return tzName + } + if envTZ := strings.TrimSpace(os.Getenv("TZ")); envTZ != "" { + return envTZ + } + return "UTC" } -func (r *usageLogRepository) ListByUserAndTimeRange(ctx 
context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { - query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000" - logs, err := r.queryUsageLogs(ctx, query, userID, startTime, endTime) - return logs, nil, err +func (r *usageLogRepository) Delete(ctx context.Context, id int64) error { + _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE id = $1", id) + return err } - -// GetUserStatsAggregated returns aggregated usage statistics for a user using database-level aggregation -func (r *usageLogRepository) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { - query := ` - SELECT - COUNT(*) as total_requests, - COALESCE(SUM(input_tokens), 0) as total_input_tokens, - COALESCE(SUM(output_tokens), 0) as total_output_tokens, - COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens, - COALESCE(SUM(total_cost), 0) as total_cost, - COALESCE(SUM(actual_cost), 0) as total_actual_cost, - COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms - FROM usage_logs - WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 - ` - - var stats usagestats.UsageStats - if err := scanSingleRow( - ctx, - r.sql, - query, - []any{userID, startTime, endTime}, - &stats.TotalRequests, - &stats.TotalInputTokens, - &stats.TotalOutputTokens, - &stats.TotalCacheTokens, - &stats.TotalCost, - &stats.TotalActualCost, - &stats.AverageDurationMs, - ); err != nil { - return nil, err - } - stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens - return &stats, nil -} - -// GetAPIKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation -func (r *usageLogRepository) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime 
time.Time) (*usagestats.UsageStats, error) { - query := ` - SELECT - COUNT(*) as total_requests, - COALESCE(SUM(input_tokens), 0) as total_input_tokens, - COALESCE(SUM(output_tokens), 0) as total_output_tokens, - COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens, - COALESCE(SUM(total_cost), 0) as total_cost, - COALESCE(SUM(actual_cost), 0) as total_actual_cost, - COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms - FROM usage_logs - WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 - ` - - var stats usagestats.UsageStats - if err := scanSingleRow( - ctx, - r.sql, - query, - []any{apiKeyID, startTime, endTime}, - &stats.TotalRequests, - &stats.TotalInputTokens, - &stats.TotalOutputTokens, - &stats.TotalCacheTokens, - &stats.TotalCost, - &stats.TotalActualCost, - &stats.AverageDurationMs, - ); err != nil { - return nil, err - } - stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens - return &stats, nil -} - -// GetAccountStatsAggregated 使用 SQL 聚合统计账号使用数据 -// -// 性能优化说明: -// 原实现先查询所有日志记录,再在应用层循环计算统计值: -// 1. 需要传输大量数据到应用层 -// 2. 应用层循环计算增加 CPU 和内存开销 -// -// 新实现使用 SQL 聚合函数: -// 1. 在数据库层完成 COUNT/SUM/AVG 计算 -// 2. 只返回单行聚合结果,大幅减少数据传输量 -// 3. 
利用数据库索引优化聚合查询性能 -func (r *usageLogRepository) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { - query := ` - SELECT - COUNT(*) as total_requests, - COALESCE(SUM(input_tokens), 0) as total_input_tokens, - COALESCE(SUM(output_tokens), 0) as total_output_tokens, - COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens, - COALESCE(SUM(total_cost), 0) as total_cost, - COALESCE(SUM(actual_cost), 0) as total_actual_cost, - COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms - FROM usage_logs - WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 - ` - - var stats usagestats.UsageStats - if err := scanSingleRow( - ctx, - r.sql, - query, - []any{accountID, startTime, endTime}, - &stats.TotalRequests, - &stats.TotalInputTokens, - &stats.TotalOutputTokens, - &stats.TotalCacheTokens, - &stats.TotalCost, - &stats.TotalActualCost, - &stats.AverageDurationMs, - ); err != nil { - return nil, err - } - stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens - return &stats, nil -} - -// GetModelStatsAggregated 使用 SQL 聚合统计模型使用数据 -// 性能优化:数据库层聚合计算,避免应用层循环统计 -func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) { - query := fmt.Sprintf(` - SELECT - COUNT(*) as total_requests, - COALESCE(SUM(input_tokens), 0) as total_input_tokens, - COALESCE(SUM(output_tokens), 0) as total_output_tokens, - COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens, - COALESCE(SUM(total_cost), 0) as total_cost, - COALESCE(SUM(actual_cost), 0) as total_actual_cost, - COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms - FROM usage_logs - WHERE %s = $1 AND created_at >= $2 AND created_at < $3 - `, rawUsageLogModelColumn) - - var stats usagestats.UsageStats - if err := scanSingleRow( - ctx, - r.sql, - query, - 
[]any{modelName, startTime, endTime}, - &stats.TotalRequests, - &stats.TotalInputTokens, - &stats.TotalOutputTokens, - &stats.TotalCacheTokens, - &stats.TotalCost, - &stats.TotalActualCost, - &stats.AverageDurationMs, - ); err != nil { - return nil, err - } - stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens - return &stats, nil -} - -// GetDailyStatsAggregated 使用 SQL 聚合统计用户的每日使用数据 -// 性能优化:使用 GROUP BY 在数据库层按日期分组聚合,避免应用层循环分组统计 -func (r *usageLogRepository) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (result []map[string]any, err error) { - tzName := resolveUsageStatsTimezone() - query := ` - SELECT - -- 使用应用时区分组,避免数据库会话时区导致日边界偏移。 - TO_CHAR(created_at AT TIME ZONE $4, 'YYYY-MM-DD') as date, - COUNT(*) as total_requests, - COALESCE(SUM(input_tokens), 0) as total_input_tokens, - COALESCE(SUM(output_tokens), 0) as total_output_tokens, - COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens, - COALESCE(SUM(total_cost), 0) as total_cost, - COALESCE(SUM(actual_cost), 0) as total_actual_cost, - COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms - FROM usage_logs - WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 - GROUP BY 1 - ORDER BY 1 - ` - - rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime, tzName) - if err != nil { - return nil, err - } - defer func() { - if closeErr := rows.Close(); closeErr != nil && err == nil { - err = closeErr - result = nil - } - }() - - result = make([]map[string]any, 0) - for rows.Next() { - var ( - date string - totalRequests int64 - totalInputTokens int64 - totalOutputTokens int64 - totalCacheTokens int64 - totalCost float64 - totalActualCost float64 - avgDurationMs float64 - ) - if err = rows.Scan( - &date, - &totalRequests, - &totalInputTokens, - &totalOutputTokens, - &totalCacheTokens, - &totalCost, - &totalActualCost, - &avgDurationMs, - ); err != nil { - return nil, err - } - 
result = append(result, map[string]any{ - "date": date, - "total_requests": totalRequests, - "total_input_tokens": totalInputTokens, - "total_output_tokens": totalOutputTokens, - "total_cache_tokens": totalCacheTokens, - "total_tokens": totalInputTokens + totalOutputTokens + totalCacheTokens, - "total_cost": totalCost, - "total_actual_cost": totalActualCost, - "average_duration_ms": avgDurationMs, - }) - } - - if err = rows.Err(); err != nil { - return nil, err - } - - return result, nil -} - -// resolveUsageStatsTimezone 获取用于 SQL 分组的时区名称。 -// 优先使用应用初始化的时区,其次尝试读取 TZ 环境变量,最后回落为 UTC。 -func resolveUsageStatsTimezone() string { - tzName := timezone.Name() - if tzName != "" && tzName != "Local" { - return tzName - } - if envTZ := strings.TrimSpace(os.Getenv("TZ")); envTZ != "" { - return envTZ - } - return "UTC" -} - -func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { - query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000" - logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime) - return logs, nil, err -} - -func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { - query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000" - logs, err := r.queryUsageLogs(ctx, query, accountID, startTime, endTime) - return logs, nil, err -} - -func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { - query := fmt.Sprintf("SELECT %s FROM usage_logs WHERE %s = $1 AND created_at >= $2 AND created_at < 
$3 ORDER BY id DESC LIMIT 10000", usageLogSelectColumns, rawUsageLogModelColumn) - logs, err := r.queryUsageLogs(ctx, query, modelName, startTime, endTime) - return logs, nil, err -} - -func (r *usageLogRepository) Delete(ctx context.Context, id int64) error { - _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE id = $1", id) - return err -} - -// GetAccountTodayStats 获取账号今日统计 -func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) { - today := timezone.Today() + +// GetAccountTodayStats 获取账号今日统计 +func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) { + today := timezone.Today() query := ` SELECT @@ -2156,556 +758,41 @@ type UserSpendingRankingResponse = usagestats.UserSpendingRankingResponse // APIKeyUsageTrendPoint represents API key usage trend data point type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint -// GetAPIKeyUsageTrend returns usage trend data grouped by API key and date -func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) { - dateFormat := safeDateFormat(granularity) +// UserDashboardStats 用户仪表盘统计 +type UserDashboardStats = usagestats.UserDashboardStats - query := fmt.Sprintf(` - WITH top_keys AS ( - SELECT api_key_id - FROM usage_logs - WHERE created_at >= $1 AND created_at < $2 - GROUP BY api_key_id - ORDER BY SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) DESC - LIMIT $3 - ) +// getPerformanceStatsByAPIKey 获取指定 API Key 的 RPM 和 TPM(近5分钟平均值) +func (r *usageLogRepository) getPerformanceStatsByAPIKey(ctx context.Context, apiKeyID int64) (rpm, tpm int64, err error) { + fiveMinutesAgo := time.Now().Add(-5 * time.Minute) + query := ` SELECT - TO_CHAR(u.created_at, '%s') as date, - u.api_key_id, - COALESCE(k.name, '') as key_name, - COUNT(*) as 
requests, - COALESCE(SUM(u.input_tokens + u.output_tokens + u.cache_creation_tokens + u.cache_read_tokens), 0) as tokens - FROM usage_logs u - LEFT JOIN api_keys k ON u.api_key_id = k.id - WHERE u.api_key_id IN (SELECT api_key_id FROM top_keys) - AND u.created_at >= $4 AND u.created_at < $5 - GROUP BY date, u.api_key_id, k.name - ORDER BY date ASC, tokens DESC - `, dateFormat) - - rows, err := r.sql.QueryContext(ctx, query, startTime, endTime, limit, startTime, endTime) - if err != nil { - return nil, err - } - defer func() { - // 保持主错误优先;仅在无错误时回传 Close 失败。 - // 同时清空返回值,避免误用不完整结果。 - if closeErr := rows.Close(); closeErr != nil && err == nil { - err = closeErr - results = nil - } - }() + COUNT(*) as request_count, + COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as token_count + FROM usage_logs + WHERE created_at >= $1 AND api_key_id = $2` + args := []any{fiveMinutesAgo, apiKeyID} - results = make([]APIKeyUsageTrendPoint, 0) - for rows.Next() { - var row APIKeyUsageTrendPoint - if err = rows.Scan(&row.Date, &row.APIKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil { - return nil, err - } - results = append(results, row) - } - if err = rows.Err(); err != nil { - return nil, err + var requestCount int64 + var tokenCount int64 + if err := scanSingleRow(ctx, r.sql, query, args, &requestCount, &tokenCount); err != nil { + return 0, 0, err } - - return results, nil + return requestCount / 5, tokenCount / 5, nil } -// GetUserUsageTrend returns usage trend data grouped by user and date -func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []UserUsageTrendPoint, err error) { - dateFormat := safeDateFormat(granularity) - - query := fmt.Sprintf(` - WITH top_users AS ( - SELECT user_id - FROM usage_logs - WHERE created_at >= $1 AND created_at < $2 - GROUP BY user_id - ORDER BY SUM(input_tokens + output_tokens + cache_creation_tokens + 
cache_read_tokens) DESC - LIMIT $3 - ) - SELECT - TO_CHAR(u.created_at, '%s') as date, - u.user_id, - COALESCE(us.email, '') as email, - COALESCE(us.username, '') as username, - COUNT(*) as requests, - COALESCE(SUM(u.input_tokens + u.output_tokens + u.cache_creation_tokens + u.cache_read_tokens), 0) as tokens, - COALESCE(SUM(u.total_cost), 0) as cost, - COALESCE(SUM(u.actual_cost), 0) as actual_cost - FROM usage_logs u - LEFT JOIN users us ON u.user_id = us.id - WHERE u.user_id IN (SELECT user_id FROM top_users) - AND u.created_at >= $4 AND u.created_at < $5 - GROUP BY date, u.user_id, us.email, us.username - ORDER BY date ASC, tokens DESC - `, dateFormat) - - rows, err := r.sql.QueryContext(ctx, query, startTime, endTime, limit, startTime, endTime) - if err != nil { - return nil, err - } - defer func() { - // 保持主错误优先;仅在无错误时回传 Close 失败。 - // 同时清空返回值,避免误用不完整结果。 - if closeErr := rows.Close(); closeErr != nil && err == nil { - err = closeErr - results = nil - } - }() +// UsageLogFilters represents filters for usage log queries +type UsageLogFilters = usagestats.UsageLogFilters - results = make([]UserUsageTrendPoint, 0) - for rows.Next() { - var row UserUsageTrendPoint - if err = rows.Scan(&row.Date, &row.UserID, &row.Email, &row.Username, &row.Requests, &row.Tokens, &row.Cost, &row.ActualCost); err != nil { - return nil, err - } - results = append(results, row) - } - if err = rows.Err(); err != nil { - return nil, err +func shouldUseFastUsageLogTotal(filters UsageLogFilters) bool { + if filters.ExactTotal { + return false } - - return results, nil + // 强选择过滤下记录集通常较小,保留精确总数。 + return filters.UserID == 0 && filters.APIKeyID == 0 && filters.AccountID == 0 } -// GetUserSpendingRanking returns user spending ranking aggregated within the time range. 
-func (r *usageLogRepository) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (result *UserSpendingRankingResponse, err error) { - if limit <= 0 { - limit = 12 - } - - query := ` - WITH user_spend AS ( - SELECT - u.user_id, - COALESCE(us.email, '') as email, - COALESCE(SUM(u.actual_cost), 0) as actual_cost, - COUNT(*) as requests, - COALESCE(SUM(u.input_tokens + u.output_tokens + u.cache_creation_tokens + u.cache_read_tokens), 0) as tokens - FROM usage_logs u - LEFT JOIN users us ON u.user_id = us.id - WHERE u.created_at >= $1 AND u.created_at < $2 - GROUP BY u.user_id, us.email - ), - ranked AS ( - SELECT - user_id, - email, - actual_cost, - requests, - tokens, - COALESCE(SUM(actual_cost) OVER (), 0) as total_actual_cost, - COALESCE(SUM(requests) OVER (), 0) as total_requests, - COALESCE(SUM(tokens) OVER (), 0) as total_tokens - FROM user_spend - ORDER BY actual_cost DESC, tokens DESC, user_id ASC - LIMIT $3 - ) - SELECT - user_id, - email, - actual_cost, - requests, - tokens, - total_actual_cost, - total_requests, - total_tokens - FROM ranked - ORDER BY actual_cost DESC, tokens DESC, user_id ASC - ` - - rows, err := r.sql.QueryContext(ctx, query, startTime, endTime, limit) - if err != nil { - return nil, err - } - defer func() { - if closeErr := rows.Close(); closeErr != nil && err == nil { - err = closeErr - result = nil - } - }() - - ranking := make([]UserSpendingRankingItem, 0) - totalActualCost := 0.0 - totalRequests := int64(0) - totalTokens := int64(0) - for rows.Next() { - var row UserSpendingRankingItem - if err = rows.Scan(&row.UserID, &row.Email, &row.ActualCost, &row.Requests, &row.Tokens, &totalActualCost, &totalRequests, &totalTokens); err != nil { - return nil, err - } - ranking = append(ranking, row) - } - if err = rows.Err(); err != nil { - return nil, err - } - - return &UserSpendingRankingResponse{ - Ranking: ranking, - TotalActualCost: totalActualCost, - TotalRequests: totalRequests, - TotalTokens: 
totalTokens, - }, nil -} - -// UserDashboardStats 用户仪表盘统计 -type UserDashboardStats = usagestats.UserDashboardStats - -// GetUserDashboardStats 获取用户专属的仪表盘统计 -func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID int64) (*UserDashboardStats, error) { - stats := &UserDashboardStats{} - today := timezone.Today() - - // API Key 统计 - if err := scanSingleRow( - ctx, - r.sql, - "SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL", - []any{userID}, - &stats.TotalAPIKeys, - ); err != nil { - return nil, err - } - if err := scanSingleRow( - ctx, - r.sql, - "SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL", - []any{userID, service.StatusActive}, - &stats.ActiveAPIKeys, - ); err != nil { - return nil, err - } - - // 累计 Token 统计 - totalStatsQuery := ` - SELECT - COUNT(*) as total_requests, - COALESCE(SUM(input_tokens), 0) as total_input_tokens, - COALESCE(SUM(output_tokens), 0) as total_output_tokens, - COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens, - COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens, - COALESCE(SUM(total_cost), 0) as total_cost, - COALESCE(SUM(actual_cost), 0) as total_actual_cost, - COALESCE(AVG(duration_ms), 0) as avg_duration_ms - FROM usage_logs - WHERE user_id = $1 - ` - if err := scanSingleRow( - ctx, - r.sql, - totalStatsQuery, - []any{userID}, - &stats.TotalRequests, - &stats.TotalInputTokens, - &stats.TotalOutputTokens, - &stats.TotalCacheCreationTokens, - &stats.TotalCacheReadTokens, - &stats.TotalCost, - &stats.TotalActualCost, - &stats.AverageDurationMs, - ); err != nil { - return nil, err - } - stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens - - // 今日 Token 统计 - todayStatsQuery := ` - SELECT - COUNT(*) as today_requests, - COALESCE(SUM(input_tokens), 0) as today_input_tokens, - COALESCE(SUM(output_tokens), 0) as today_output_tokens, - 
COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens, - COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens, - COALESCE(SUM(total_cost), 0) as today_cost, - COALESCE(SUM(actual_cost), 0) as today_actual_cost - FROM usage_logs - WHERE user_id = $1 AND created_at >= $2 - ` - if err := scanSingleRow( - ctx, - r.sql, - todayStatsQuery, - []any{userID, today}, - &stats.TodayRequests, - &stats.TodayInputTokens, - &stats.TodayOutputTokens, - &stats.TodayCacheCreationTokens, - &stats.TodayCacheReadTokens, - &stats.TodayCost, - &stats.TodayActualCost, - ); err != nil { - return nil, err - } - stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens - - // 性能指标:RPM 和 TPM(最近1分钟,仅统计该用户的请求) - rpm, tpm, err := r.getPerformanceStats(ctx, userID) - if err != nil { - return nil, err - } - stats.Rpm = rpm - stats.Tpm = tpm - - return stats, nil -} - -// getPerformanceStatsByAPIKey 获取指定 API Key 的 RPM 和 TPM(近5分钟平均值) -func (r *usageLogRepository) getPerformanceStatsByAPIKey(ctx context.Context, apiKeyID int64) (rpm, tpm int64, err error) { - fiveMinutesAgo := time.Now().Add(-5 * time.Minute) - query := ` - SELECT - COUNT(*) as request_count, - COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as token_count - FROM usage_logs - WHERE created_at >= $1 AND api_key_id = $2` - args := []any{fiveMinutesAgo, apiKeyID} - - var requestCount int64 - var tokenCount int64 - if err := scanSingleRow(ctx, r.sql, query, args, &requestCount, &tokenCount); err != nil { - return 0, 0, err - } - return requestCount / 5, tokenCount / 5, nil -} - -// GetAPIKeyDashboardStats 获取指定 API Key 的仪表盘统计(按 api_key_id 过滤) -func (r *usageLogRepository) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*UserDashboardStats, error) { - stats := &UserDashboardStats{} - today := timezone.Today() - - // API Key 维度不需要统计 key 数量,设为 1 - stats.TotalAPIKeys = 1 - 
stats.ActiveAPIKeys = 1 - - // 累计 Token 统计 - totalStatsQuery := ` - SELECT - COUNT(*) as total_requests, - COALESCE(SUM(input_tokens), 0) as total_input_tokens, - COALESCE(SUM(output_tokens), 0) as total_output_tokens, - COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens, - COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens, - COALESCE(SUM(total_cost), 0) as total_cost, - COALESCE(SUM(actual_cost), 0) as total_actual_cost, - COALESCE(AVG(duration_ms), 0) as avg_duration_ms - FROM usage_logs - WHERE api_key_id = $1 - ` - if err := scanSingleRow( - ctx, - r.sql, - totalStatsQuery, - []any{apiKeyID}, - &stats.TotalRequests, - &stats.TotalInputTokens, - &stats.TotalOutputTokens, - &stats.TotalCacheCreationTokens, - &stats.TotalCacheReadTokens, - &stats.TotalCost, - &stats.TotalActualCost, - &stats.AverageDurationMs, - ); err != nil { - return nil, err - } - stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens - - // 今日 Token 统计 - todayStatsQuery := ` - SELECT - COUNT(*) as today_requests, - COALESCE(SUM(input_tokens), 0) as today_input_tokens, - COALESCE(SUM(output_tokens), 0) as today_output_tokens, - COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens, - COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens, - COALESCE(SUM(total_cost), 0) as today_cost, - COALESCE(SUM(actual_cost), 0) as today_actual_cost - FROM usage_logs - WHERE api_key_id = $1 AND created_at >= $2 - ` - if err := scanSingleRow( - ctx, - r.sql, - todayStatsQuery, - []any{apiKeyID, today}, - &stats.TodayRequests, - &stats.TodayInputTokens, - &stats.TodayOutputTokens, - &stats.TodayCacheCreationTokens, - &stats.TodayCacheReadTokens, - &stats.TodayCost, - &stats.TodayActualCost, - ); err != nil { - return nil, err - } - stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens - - // 性能指标:RPM 和 
TPM(最近5分钟,按 API Key 过滤) - rpm, tpm, err := r.getPerformanceStatsByAPIKey(ctx, apiKeyID) - if err != nil { - return nil, err - } - stats.Rpm = rpm - stats.Tpm = tpm - - return stats, nil -} - -// GetUserUsageTrendByUserID 获取指定用户的使用趋势 -func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) { - dateFormat := safeDateFormat(granularity) - - query := fmt.Sprintf(` - SELECT - TO_CHAR(created_at, '%s') as date, - COUNT(*) as requests, - COALESCE(SUM(input_tokens), 0) as input_tokens, - COALESCE(SUM(output_tokens), 0) as output_tokens, - COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens, - COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens, - COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens, - COALESCE(SUM(total_cost), 0) as cost, - COALESCE(SUM(actual_cost), 0) as actual_cost - FROM usage_logs - WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 - GROUP BY date - ORDER BY date ASC - `, dateFormat) - - rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime) - if err != nil { - return nil, err - } - defer func() { - // 保持主错误优先;仅在无错误时回传 Close 失败。 - // 同时清空返回值,避免误用不完整结果。 - if closeErr := rows.Close(); closeErr != nil && err == nil { - err = closeErr - results = nil - } - }() - - results, err = scanTrendRows(rows) - if err != nil { - return nil, err - } - return results, nil -} - -// GetUserModelStats 获取指定用户的模型统计 -func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) (results []ModelStat, err error) { - query := ` - SELECT - model, - COUNT(*) as requests, - COALESCE(SUM(input_tokens), 0) as input_tokens, - COALESCE(SUM(output_tokens), 0) as output_tokens, - COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens, - COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens, - 
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens, - COALESCE(SUM(total_cost), 0) as cost, - COALESCE(SUM(actual_cost), 0) as actual_cost, - COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as account_cost - FROM usage_logs - WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 - GROUP BY model - ORDER BY total_tokens DESC - ` - - rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime) - if err != nil { - return nil, err - } - defer func() { - // 保持主错误优先;仅在无错误时回传 Close 失败。 - // 同时清空返回值,避免误用不完整结果。 - if closeErr := rows.Close(); closeErr != nil && err == nil { - err = closeErr - results = nil - } - }() - - results, err = scanModelStatsRows(rows) - if err != nil { - return nil, err - } - return results, nil -} - -// UsageLogFilters represents filters for usage log queries -type UsageLogFilters = usagestats.UsageLogFilters - -// ListWithFilters lists usage logs with optional filters (for admin) -func (r *usageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) { - conditions := make([]string, 0, 9) - args := make([]any, 0, 9) - - if filters.UserID > 0 { - conditions = append(conditions, fmt.Sprintf("user_id = $%d", len(args)+1)) - args = append(args, filters.UserID) - } - if filters.APIKeyID > 0 { - conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", len(args)+1)) - args = append(args, filters.APIKeyID) - } - if filters.AccountID > 0 { - conditions = append(conditions, fmt.Sprintf("account_id = $%d", len(args)+1)) - args = append(args, filters.AccountID) - } - if filters.GroupID > 0 { - conditions = append(conditions, fmt.Sprintf("group_id = $%d", len(args)+1)) - args = append(args, filters.GroupID) - } - conditions, args = appendRawUsageLogModelWhereCondition(conditions, args, filters.Model) - conditions, args = 
appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream) - if filters.BillingType != nil { - conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) - args = append(args, int16(*filters.BillingType)) - } - if filters.BillingMode != "" { - conditions = append(conditions, fmt.Sprintf("billing_mode = $%d", len(args)+1)) - args = append(args, filters.BillingMode) - } - if filters.StartTime != nil { - conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1)) - args = append(args, *filters.StartTime) - } - if filters.EndTime != nil { - conditions = append(conditions, fmt.Sprintf("created_at < $%d", len(args)+1)) - args = append(args, *filters.EndTime) - } - - whereClause := buildWhere(conditions) - var ( - logs []service.UsageLog - page *pagination.PaginationResult - err error - ) - if shouldUseFastUsageLogTotal(filters) { - logs, page, err = r.listUsageLogsWithFastPagination(ctx, whereClause, args, params) - } else { - logs, page, err = r.listUsageLogsWithPagination(ctx, whereClause, args, params) - } - if err != nil { - return nil, nil, err - } - - if err := r.hydrateUsageLogAssociations(ctx, logs); err != nil { - return nil, nil, err - } - return logs, page, nil -} - -func shouldUseFastUsageLogTotal(filters UsageLogFilters) bool { - if filters.ExactTotal { - return false - } - // 强选择过滤下记录集通常较小,保留精确总数。 - return filters.UserID == 0 && filters.APIKeyID == 0 && filters.AccountID == 0 -} - -// UsageStats represents usage statistics -type UsageStats = usagestats.UsageStats +// UsageStats represents usage statistics +type UsageStats = usagestats.UsageStats // BatchUserUsageStats represents usage stats for a single user type BatchUserUsageStats = usagestats.BatchUserUsageStats @@ -2715,212 +802,23 @@ func normalizePositiveInt64IDs(ids []int64) []int64 { return nil } seen := make(map[int64]struct{}, len(ids)) - out := make([]int64, 0, len(ids)) - for _, id := range ids { - if id <= 0 { - 
continue - } - if _, ok := seen[id]; ok { - continue - } - seen[id] = struct{}{} - out = append(out, id) - } - return out -} - -// GetBatchUserUsageStats gets today and total actual_cost for multiple users within a time range. -// If startTime is zero, defaults to 30 days ago. -func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*BatchUserUsageStats, error) { - result := make(map[int64]*BatchUserUsageStats) - normalizedUserIDs := normalizePositiveInt64IDs(userIDs) - if len(normalizedUserIDs) == 0 { - return result, nil - } - - // 默认最近 30 天 - if startTime.IsZero() { - startTime = time.Now().AddDate(0, 0, -30) - } - if endTime.IsZero() { - endTime = time.Now() - } - - for _, id := range normalizedUserIDs { - result[id] = &BatchUserUsageStats{UserID: id} - } - - query := ` - SELECT - user_id, - COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $2 AND created_at < $3), 0) as total_cost, - COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $4), 0) as today_cost - FROM usage_logs - WHERE user_id = ANY($1) - AND created_at >= LEAST($2, $4) - GROUP BY user_id - ` - today := timezone.Today() - rows, err := r.sql.QueryContext(ctx, query, pq.Array(normalizedUserIDs), startTime, endTime, today) - if err != nil { - return nil, err - } - for rows.Next() { - var userID int64 - var total float64 - var todayTotal float64 - if err := rows.Scan(&userID, &total, &todayTotal); err != nil { - _ = rows.Close() - return nil, err - } - if stats, ok := result[userID]; ok { - stats.TotalActualCost = total - stats.TodayActualCost = todayTotal - } - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - - return result, nil -} - -// BatchAPIKeyUsageStats represents usage stats for a single API key -type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats - -// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API 
keys within a time range. -// If startTime is zero, defaults to 30 days ago. -func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*BatchAPIKeyUsageStats, error) { - result := make(map[int64]*BatchAPIKeyUsageStats) - normalizedAPIKeyIDs := normalizePositiveInt64IDs(apiKeyIDs) - if len(normalizedAPIKeyIDs) == 0 { - return result, nil - } - - // 默认最近 30 天 - if startTime.IsZero() { - startTime = time.Now().AddDate(0, 0, -30) - } - if endTime.IsZero() { - endTime = time.Now() - } - - for _, id := range normalizedAPIKeyIDs { - result[id] = &BatchAPIKeyUsageStats{APIKeyID: id} - } - - query := ` - SELECT - api_key_id, - COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $2 AND created_at < $3), 0) as total_cost, - COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $4), 0) as today_cost - FROM usage_logs - WHERE api_key_id = ANY($1) - AND created_at >= LEAST($2, $4) - GROUP BY api_key_id - ` - today := timezone.Today() - rows, err := r.sql.QueryContext(ctx, query, pq.Array(normalizedAPIKeyIDs), startTime, endTime, today) - if err != nil { - return nil, err - } - for rows.Next() { - var apiKeyID int64 - var total float64 - var todayTotal float64 - if err := rows.Scan(&apiKeyID, &total, &todayTotal); err != nil { - _ = rows.Close() - return nil, err - } - if stats, ok := result[apiKeyID]; ok { - stats.TotalActualCost = total - stats.TodayActualCost = todayTotal - } - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - - return result, nil -} - -// GetUsageTrendWithFilters returns usage trend data with optional filters -func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []TrendDataPoint, err error) { - if 
shouldUsePreaggregatedTrend(granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType) { - aggregated, aggregatedErr := r.getUsageTrendFromAggregates(ctx, startTime, endTime, granularity) - if aggregatedErr == nil && len(aggregated) > 0 { - return aggregated, nil - } - } - - dateFormat := safeDateFormat(granularity) - - query := fmt.Sprintf(` - SELECT - TO_CHAR(created_at, '%s') as date, - COUNT(*) as requests, - COALESCE(SUM(input_tokens), 0) as input_tokens, - COALESCE(SUM(output_tokens), 0) as output_tokens, - COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens, - COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens, - COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens, - COALESCE(SUM(total_cost), 0) as cost, - COALESCE(SUM(actual_cost), 0) as actual_cost - FROM usage_logs - WHERE created_at >= $1 AND created_at < $2 - `, dateFormat) - - args := []any{startTime, endTime} - if userID > 0 { - query += fmt.Sprintf(" AND user_id = $%d", len(args)+1) - args = append(args, userID) - } - if apiKeyID > 0 { - query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1) - args = append(args, apiKeyID) - } - if accountID > 0 { - query += fmt.Sprintf(" AND account_id = $%d", len(args)+1) - args = append(args, accountID) - } - if groupID > 0 { - query += fmt.Sprintf(" AND group_id = $%d", len(args)+1) - args = append(args, groupID) - } - query, args = appendRawUsageLogModelQueryFilter(query, args, model) - query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream) - if billingType != nil { - query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) - args = append(args, int16(*billingType)) - } - query += " GROUP BY date ORDER BY date ASC" - - rows, err := r.sql.QueryContext(ctx, query, args...) 
- if err != nil { - return nil, err - } - defer func() { - // 保持主错误优先;仅在无错误时回传 Close 失败。 - // 同时清空返回值,避免误用不完整结果。 - if closeErr := rows.Close(); closeErr != nil && err == nil { - err = closeErr - results = nil + out := make([]int64, 0, len(ids)) + for _, id := range ids { + if id <= 0 { + continue } - }() - - results, err = scanTrendRows(rows) - if err != nil { - return nil, err + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + out = append(out, id) } - return results, nil + return out } +// BatchAPIKeyUsageStats represents usage stats for a single API key +type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats + func shouldUsePreaggregatedTrend(granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) bool { if granularity != "day" && granularity != "hour" { return false @@ -2995,11 +893,6 @@ func (r *usageLogRepository) getUsageTrendFromAggregates(ctx context.Context, st return results, nil } -// GetModelStatsWithFilters returns model statistics with optional filters -func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) { - return r.getModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, usagestats.ModelSourceRequested) -} - // GetModelStatsWithFiltersBySource returns model statistics with optional filters and model source dimension. // source: requested | upstream | mapping. 
func (r *usageLogRepository) GetModelStatsWithFiltersBySource(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8, source string) (results []ModelStat, err error) { @@ -3075,207 +968,6 @@ func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Contex return results, nil } -// GetGroupStatsWithFilters returns group usage statistics with optional filters -func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []usagestats.GroupStat, err error) { - query := ` - SELECT - COALESCE(ul.group_id, 0) as group_id, - COALESCE(g.name, '') as group_name, - COUNT(*) as requests, - COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens, - COALESCE(SUM(ul.total_cost), 0) as cost, - COALESCE(SUM(ul.actual_cost), 0) as actual_cost, - COALESCE(SUM(COALESCE(ul.account_stats_cost, ul.total_cost) * COALESCE(ul.account_rate_multiplier, 1)), 0) as account_cost - FROM usage_logs ul - LEFT JOIN groups g ON g.id = ul.group_id - WHERE ul.created_at >= $1 AND ul.created_at < $2 - ` - - args := []any{startTime, endTime} - if userID > 0 { - query += fmt.Sprintf(" AND ul.user_id = $%d", len(args)+1) - args = append(args, userID) - } - if apiKeyID > 0 { - query += fmt.Sprintf(" AND ul.api_key_id = $%d", len(args)+1) - args = append(args, apiKeyID) - } - if accountID > 0 { - query += fmt.Sprintf(" AND ul.account_id = $%d", len(args)+1) - args = append(args, accountID) - } - if groupID > 0 { - query += fmt.Sprintf(" AND ul.group_id = $%d", len(args)+1) - args = append(args, groupID) - } - query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream) - if billingType != nil { - query += fmt.Sprintf(" AND ul.billing_type = $%d", len(args)+1) - args = 
append(args, int16(*billingType)) - } - query += " GROUP BY ul.group_id, g.name ORDER BY total_tokens DESC" - - rows, err := r.sql.QueryContext(ctx, query, args...) - if err != nil { - return nil, err - } - defer func() { - if closeErr := rows.Close(); closeErr != nil && err == nil { - err = closeErr - results = nil - } - }() - - results = make([]usagestats.GroupStat, 0) - for rows.Next() { - var row usagestats.GroupStat - if err := rows.Scan( - &row.GroupID, - &row.GroupName, - &row.Requests, - &row.TotalTokens, - &row.Cost, - &row.ActualCost, - &row.AccountCost, - ); err != nil { - return nil, err - } - results = append(results, row) - } - if err := rows.Err(); err != nil { - return nil, err - } - return results, nil -} - -// GetUserBreakdownStats returns per-user usage breakdown within a specific dimension. -func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) (results []usagestats.UserBreakdownItem, err error) { - query := ` - SELECT - COALESCE(ul.user_id, 0) as user_id, - COALESCE(u.email, '') as email, - COUNT(*) as requests, - COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens, - COALESCE(SUM(ul.total_cost), 0) as cost, - COALESCE(SUM(ul.actual_cost), 0) as actual_cost, - COALESCE(SUM(COALESCE(ul.account_stats_cost, ul.total_cost) * COALESCE(ul.account_rate_multiplier, 1)), 0) as account_cost - FROM usage_logs ul - LEFT JOIN users u ON u.id = ul.user_id - WHERE ul.created_at >= $1 AND ul.created_at < $2 - ` - args := []any{startTime, endTime} - - if dim.GroupID > 0 { - query += fmt.Sprintf(" AND ul.group_id = $%d", len(args)+1) - args = append(args, dim.GroupID) - } - if dim.Model != "" { - query += fmt.Sprintf(" AND %s = $%d", resolveModelDimensionExpression(dim.ModelType), len(args)+1) - args = append(args, dim.Model) - } - if dim.Endpoint != "" { - col := 
resolveEndpointColumn(dim.EndpointType) - query += fmt.Sprintf(" AND %s = $%d", col, len(args)+1) - args = append(args, dim.Endpoint) - } - if dim.UserID > 0 { - query += fmt.Sprintf(" AND ul.user_id = $%d", len(args)+1) - args = append(args, dim.UserID) - } - if dim.APIKeyID > 0 { - query += fmt.Sprintf(" AND ul.api_key_id = $%d", len(args)+1) - args = append(args, dim.APIKeyID) - } - if dim.AccountID > 0 { - query += fmt.Sprintf(" AND ul.account_id = $%d", len(args)+1) - args = append(args, dim.AccountID) - } - if dim.RequestType != nil { - query += fmt.Sprintf(" AND ul.request_type = $%d", len(args)+1) - args = append(args, *dim.RequestType) - } - if dim.Stream != nil { - query += fmt.Sprintf(" AND ul.stream = $%d", len(args)+1) - args = append(args, *dim.Stream) - } - if dim.BillingType != nil { - query += fmt.Sprintf(" AND ul.billing_type = $%d", len(args)+1) - args = append(args, *dim.BillingType) - } - - query += " GROUP BY ul.user_id, u.email ORDER BY actual_cost DESC" - if limit > 0 { - query += fmt.Sprintf(" LIMIT %d", limit) - } - - rows, err := r.sql.QueryContext(ctx, query, args...) - if err != nil { - return nil, err - } - defer func() { - if closeErr := rows.Close(); closeErr != nil && err == nil { - err = closeErr - results = nil - } - }() - - results = make([]usagestats.UserBreakdownItem, 0) - for rows.Next() { - var row usagestats.UserBreakdownItem - if err := rows.Scan( - &row.UserID, - &row.Email, - &row.Requests, - &row.TotalTokens, - &row.Cost, - &row.ActualCost, - &row.AccountCost, - ); err != nil { - return nil, err - } - results = append(results, row) - } - if err := rows.Err(); err != nil { - return nil, err - } - return results, nil -} - -// GetAllGroupUsageSummary returns today's and cumulative actual_cost for every group. -// todayStart is the start-of-day in the caller's timezone (UTC-based). -// TODO(perf): This query scans ALL usage_logs rows for total_cost aggregation. 
-// When usage_logs exceeds ~1M rows, consider adding a short-lived cache (30s) -// or a materialized view / pre-aggregation table for cumulative costs. -func (r *usageLogRepository) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) { - query := ` - SELECT - g.id AS group_id, - COALESCE(SUM(ul.actual_cost), 0) AS total_cost, - COALESCE(SUM(CASE WHEN ul.created_at >= $1 THEN ul.actual_cost ELSE 0 END), 0) AS today_cost - FROM groups g - LEFT JOIN usage_logs ul ON ul.group_id = g.id - GROUP BY g.id - ` - - rows, err := r.sql.QueryContext(ctx, query, todayStart) - if err != nil { - return nil, err - } - defer func() { _ = rows.Close() }() - var results []usagestats.GroupUsageSummary - for rows.Next() { - var row usagestats.GroupUsageSummary - if err := rows.Scan(&row.GroupID, &row.TotalCost, &row.TodayCost); err != nil { - return nil, err - } - results = append(results, row) - } - if err := rows.Err(); err != nil { - return nil, err - } - return results, nil -} - // resolveModelDimensionExpression maps model source type to a safe SQL expression. 
func resolveModelDimensionExpression(modelType string) string { requestedExpr := "COALESCE(NULLIF(TRIM(requested_model), ''), model)" @@ -3785,53 +1477,6 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID return resp, nil } -func (r *usageLogRepository) listUsageLogsWithPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { - countQuery := "SELECT COUNT(*) FROM usage_logs " + whereClause - var total int64 - if err := scanSingleRow(ctx, r.sql, countQuery, args, &total); err != nil { - return nil, nil, err - } - - limitPos := len(args) + 1 - offsetPos := len(args) + 2 - listArgs := append(append([]any{}, args...), params.Limit(), params.Offset()) - query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY %s LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, usageLogOrderBy(params), limitPos, offsetPos) - logs, err := r.queryUsageLogs(ctx, query, listArgs...) - if err != nil { - return nil, nil, err - } - return logs, paginationResultFromTotal(total, params), nil -} - -func (r *usageLogRepository) listUsageLogsWithFastPagination(ctx context.Context, whereClause string, args []any, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { - limit := params.Limit() - offset := params.Offset() - - limitPos := len(args) + 1 - offsetPos := len(args) + 2 - listArgs := append(append([]any{}, args...), limit+1, offset) - query := fmt.Sprintf("SELECT %s FROM usage_logs %s ORDER BY %s LIMIT $%d OFFSET $%d", usageLogSelectColumns, whereClause, usageLogOrderBy(params), limitPos, offsetPos) - - logs, err := r.queryUsageLogs(ctx, query, listArgs...) 
- if err != nil { - return nil, nil, err - } - - hasMore := false - if len(logs) > limit { - hasMore = true - logs = logs[:limit] - } - - total := int64(offset) + int64(len(logs)) - if hasMore { - // 只保证“还有下一页”,避免对超大表做全量 COUNT(*)。 - total = int64(offset) + int64(limit) + 1 - } - - return logs, paginationResultFromTotal(total, params), nil -} - func usageLogOrderBy(params pagination.PaginationParams) string { sortBy := strings.ToLower(strings.TrimSpace(params.SortBy)) sortOrder := strings.ToUpper(params.NormalizedSortOrder(pagination.SortOrderDesc)) @@ -3852,87 +1497,6 @@ func usageLogOrderBy(params pagination.PaginationParams) string { return fmt.Sprintf("%s %s, id %s", column, sortOrder, sortOrder) } -func (r *usageLogRepository) queryUsageLogs(ctx context.Context, query string, args ...any) (logs []service.UsageLog, err error) { - rows, err := r.sql.QueryContext(ctx, query, args...) - if err != nil { - return nil, err - } - defer func() { - // 保持主错误优先;仅在无错误时回传 Close 失败。 - // 同时清空返回值,避免误用不完整结果。 - if closeErr := rows.Close(); closeErr != nil && err == nil { - err = closeErr - logs = nil - } - }() - - logs = make([]service.UsageLog, 0) - for rows.Next() { - var log *service.UsageLog - log, err = scanUsageLog(rows) - if err != nil { - return nil, err - } - logs = append(logs, *log) - } - if err = rows.Err(); err != nil { - return nil, err - } - return logs, nil -} - -func (r *usageLogRepository) hydrateUsageLogAssociations(ctx context.Context, logs []service.UsageLog) error { - // 关联数据使用 Ent 批量加载,避免把复杂 SQL 继续膨胀。 - if len(logs) == 0 { - return nil - } - - ids := collectUsageLogIDs(logs) - users, err := r.loadUsers(ctx, ids.userIDs) - if err != nil { - return err - } - apiKeys, err := r.loadAPIKeys(ctx, ids.apiKeyIDs) - if err != nil { - return err - } - accounts, err := r.loadAccounts(ctx, ids.accountIDs) - if err != nil { - return err - } - groups, err := r.loadGroups(ctx, ids.groupIDs) - if err != nil { - return err - } - subs, err := r.loadSubscriptions(ctx, 
ids.subscriptionIDs) - if err != nil { - return err - } - - for i := range logs { - if user, ok := users[logs[i].UserID]; ok { - logs[i].User = user - } - if key, ok := apiKeys[logs[i].APIKeyID]; ok { - logs[i].APIKey = key - } - if acc, ok := accounts[logs[i].AccountID]; ok { - logs[i].Account = acc - } - if logs[i].GroupID != nil { - if group, ok := groups[*logs[i].GroupID]; ok { - logs[i].Group = group - } - } - if logs[i].SubscriptionID != nil { - if sub, ok := subs[*logs[i].SubscriptionID]; ok { - logs[i].Subscription = sub - } - } - } - return nil -} - type usageLogIDs struct { userIDs []int64 apiKeyIDs []int64 @@ -3971,282 +1535,6 @@ func collectUsageLogIDs(logs []service.UsageLog) usageLogIDs { } } -func (r *usageLogRepository) loadUsers(ctx context.Context, ids []int64) (map[int64]*service.User, error) { - out := make(map[int64]*service.User) - if len(ids) == 0 { - return out, nil - } - models, err := r.client.User.Query().Where(dbuser.IDIn(ids...)).All(ctx) - if err != nil { - return nil, err - } - for _, m := range models { - out[m.ID] = userEntityToService(m) - } - return out, nil -} - -func (r *usageLogRepository) loadAPIKeys(ctx context.Context, ids []int64) (map[int64]*service.APIKey, error) { - out := make(map[int64]*service.APIKey) - if len(ids) == 0 { - return out, nil - } - models, err := r.client.APIKey.Query().Where(dbapikey.IDIn(ids...)).All(ctx) - if err != nil { - return nil, err - } - for _, m := range models { - out[m.ID] = apiKeyEntityToService(m) - } - return out, nil -} - -func (r *usageLogRepository) loadAccounts(ctx context.Context, ids []int64) (map[int64]*service.Account, error) { - out := make(map[int64]*service.Account) - if len(ids) == 0 { - return out, nil - } - models, err := r.client.Account.Query().Where(dbaccount.IDIn(ids...)).All(ctx) - if err != nil { - return nil, err - } - for _, m := range models { - out[m.ID] = accountEntityToService(m) - } - return out, nil -} - -func (r *usageLogRepository) loadGroups(ctx 
context.Context, ids []int64) (map[int64]*service.Group, error) { - out := make(map[int64]*service.Group) - if len(ids) == 0 { - return out, nil - } - models, err := r.client.Group.Query().Where(dbgroup.IDIn(ids...)).All(ctx) - if err != nil { - return nil, err - } - for _, m := range models { - out[m.ID] = groupEntityToService(m) - } - return out, nil -} - -func (r *usageLogRepository) loadSubscriptions(ctx context.Context, ids []int64) (map[int64]*service.UserSubscription, error) { - out := make(map[int64]*service.UserSubscription) - if len(ids) == 0 { - return out, nil - } - models, err := r.client.UserSubscription.Query().Where(dbusersub.IDIn(ids...)).All(ctx) - if err != nil { - return nil, err - } - for _, m := range models { - out[m.ID] = userSubscriptionEntityToService(m) - } - return out, nil -} - -func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, error) { - var ( - id int64 - userID int64 - apiKeyID int64 - accountID int64 - requestID sql.NullString - model string - requestedModel sql.NullString - upstreamModel sql.NullString - groupID sql.NullInt64 - subscriptionID sql.NullInt64 - inputTokens int - outputTokens int - cacheCreationTokens int - cacheReadTokens int - cacheCreation5m int - cacheCreation1h int - imageOutputTokens int - imageOutputCost float64 - inputCost float64 - outputCost float64 - cacheCreationCost float64 - cacheReadCost float64 - totalCost float64 - actualCost float64 - rateMultiplier float64 - accountRateMultiplier sql.NullFloat64 - billingType int16 - requestTypeRaw int16 - stream bool - openaiWSMode bool - durationMs sql.NullInt64 - firstTokenMs sql.NullInt64 - userAgent sql.NullString - ipAddress sql.NullString - imageCount int - imageSize sql.NullString - serviceTier sql.NullString - reasoningEffort sql.NullString - inboundEndpoint sql.NullString - upstreamEndpoint sql.NullString - cacheTTLOverridden bool - channelID sql.NullInt64 - modelMappingChain sql.NullString - billingTier sql.NullString - 
billingMode sql.NullString - accountStatsCost sql.NullFloat64 - createdAt time.Time - ) - - if err := scanner.Scan( - &id, - &userID, - &apiKeyID, - &accountID, - &requestID, - &model, - &requestedModel, - &upstreamModel, - &groupID, - &subscriptionID, - &inputTokens, - &outputTokens, - &cacheCreationTokens, - &cacheReadTokens, - &cacheCreation5m, - &cacheCreation1h, - &imageOutputTokens, - &imageOutputCost, - &inputCost, - &outputCost, - &cacheCreationCost, - &cacheReadCost, - &totalCost, - &actualCost, - &rateMultiplier, - &accountRateMultiplier, - &billingType, - &requestTypeRaw, - &stream, - &openaiWSMode, - &durationMs, - &firstTokenMs, - &userAgent, - &ipAddress, - &imageCount, - &imageSize, - &serviceTier, - &reasoningEffort, - &inboundEndpoint, - &upstreamEndpoint, - &cacheTTLOverridden, - &channelID, - &modelMappingChain, - &billingTier, - &billingMode, - &accountStatsCost, - &createdAt, - ); err != nil { - return nil, err - } - - log := &service.UsageLog{ - ID: id, - UserID: userID, - APIKeyID: apiKeyID, - AccountID: accountID, - Model: model, - RequestedModel: coalesceTrimmedString(requestedModel, model), - InputTokens: inputTokens, - OutputTokens: outputTokens, - CacheCreationTokens: cacheCreationTokens, - CacheReadTokens: cacheReadTokens, - CacheCreation5mTokens: cacheCreation5m, - CacheCreation1hTokens: cacheCreation1h, - ImageOutputTokens: imageOutputTokens, - ImageOutputCost: imageOutputCost, - InputCost: inputCost, - OutputCost: outputCost, - CacheCreationCost: cacheCreationCost, - CacheReadCost: cacheReadCost, - TotalCost: totalCost, - ActualCost: actualCost, - RateMultiplier: rateMultiplier, - AccountRateMultiplier: nullFloat64Ptr(accountRateMultiplier), - BillingType: int8(billingType), - RequestType: service.RequestTypeFromInt16(requestTypeRaw), - ImageCount: imageCount, - CacheTTLOverridden: cacheTTLOverridden, - CreatedAt: createdAt, - } - // 先回填 legacy 字段,再基于 legacy + request_type 计算最终请求类型,保证历史数据兼容。 - log.Stream = stream - log.OpenAIWSMode = 
openaiWSMode - log.RequestType = log.EffectiveRequestType() - log.Stream, log.OpenAIWSMode = service.ApplyLegacyRequestFields(log.RequestType, stream, openaiWSMode) - - if requestID.Valid { - log.RequestID = requestID.String - } - if groupID.Valid { - value := groupID.Int64 - log.GroupID = &value - } - if subscriptionID.Valid { - value := subscriptionID.Int64 - log.SubscriptionID = &value - } - if durationMs.Valid { - value := int(durationMs.Int64) - log.DurationMs = &value - } - if firstTokenMs.Valid { - value := int(firstTokenMs.Int64) - log.FirstTokenMs = &value - } - if userAgent.Valid { - log.UserAgent = &userAgent.String - } - if ipAddress.Valid { - log.IPAddress = &ipAddress.String - } - if imageSize.Valid { - log.ImageSize = &imageSize.String - } - if serviceTier.Valid { - log.ServiceTier = &serviceTier.String - } - if reasoningEffort.Valid { - log.ReasoningEffort = &reasoningEffort.String - } - if inboundEndpoint.Valid { - log.InboundEndpoint = &inboundEndpoint.String - } - if upstreamEndpoint.Valid { - log.UpstreamEndpoint = &upstreamEndpoint.String - } - if upstreamModel.Valid { - log.UpstreamModel = &upstreamModel.String - } - if channelID.Valid { - value := channelID.Int64 - log.ChannelID = &value - } - if modelMappingChain.Valid { - log.ModelMappingChain = &modelMappingChain.String - } - if billingTier.Valid { - log.BillingTier = &billingTier.String - } - if billingMode.Valid { - log.BillingMode = &billingMode.String - } - if accountStatsCost.Valid { - log.AccountStatsCost = &accountStatsCost.Float64 - } - - return log, nil -} - func scanTrendRows(rows *sql.Rows) ([]TrendDataPoint, error) { results := make([]TrendDataPoint, 0) for rows.Next() { diff --git a/backend/internal/repository/usage_log_stats.go b/backend/internal/repository/usage_log_stats.go new file mode 100644 index 00000000000..53267ed91a2 --- /dev/null +++ b/backend/internal/repository/usage_log_stats.go @@ -0,0 +1,1129 @@ +package repository + +import ( + "context" + "fmt" + "time" + + 
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/lib/pq" +) + +// UserStats 用户使用统计 +type UserStats struct { + TotalRequests int64 `json:"total_requests"` + TotalTokens int64 `json:"total_tokens"` + TotalCost float64 `json:"total_cost"` + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` + CacheReadTokens int64 `json:"cache_read_tokens"` +} + +func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, startTime, endTime time.Time) (*UserStats, error) { + query := ` + SELECT + COUNT(*) as total_requests, + COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens, + COALESCE(SUM(actual_cost), 0) as total_cost, + COALESCE(SUM(input_tokens), 0) as input_tokens, + COALESCE(SUM(output_tokens), 0) as output_tokens, + COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens + FROM usage_logs + WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 + ` + + stats := &UserStats{} + if err := scanSingleRow( + ctx, + r.sql, + query, + []any{userID, startTime, endTime}, + &stats.TotalRequests, + &stats.TotalTokens, + &stats.TotalCost, + &stats.InputTokens, + &stats.OutputTokens, + &stats.CacheReadTokens, + ); err != nil { + return nil, err + } + return stats, nil +} + +// GetUserStatsAggregated returns aggregated usage statistics for a user using database-level aggregation +func (r *usageLogRepository) GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { + query := ` + SELECT + COUNT(*) as total_requests, + COALESCE(SUM(input_tokens), 0) as total_input_tokens, + COALESCE(SUM(output_tokens), 0) as total_output_tokens, + COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens, + COALESCE(SUM(total_cost), 0) as total_cost, + COALESCE(SUM(actual_cost), 0) as 
total_actual_cost, + COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms + FROM usage_logs + WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 + ` + + var stats usagestats.UsageStats + if err := scanSingleRow( + ctx, + r.sql, + query, + []any{userID, startTime, endTime}, + &stats.TotalRequests, + &stats.TotalInputTokens, + &stats.TotalOutputTokens, + &stats.TotalCacheTokens, + &stats.TotalCost, + &stats.TotalActualCost, + &stats.AverageDurationMs, + ); err != nil { + return nil, err + } + stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens + return &stats, nil +} + +// GetAPIKeyStatsAggregated returns aggregated usage statistics for an API key using database-level aggregation +func (r *usageLogRepository) GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { + query := ` + SELECT + COUNT(*) as total_requests, + COALESCE(SUM(input_tokens), 0) as total_input_tokens, + COALESCE(SUM(output_tokens), 0) as total_output_tokens, + COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens, + COALESCE(SUM(total_cost), 0) as total_cost, + COALESCE(SUM(actual_cost), 0) as total_actual_cost, + COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms + FROM usage_logs + WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 + ` + + var stats usagestats.UsageStats + if err := scanSingleRow( + ctx, + r.sql, + query, + []any{apiKeyID, startTime, endTime}, + &stats.TotalRequests, + &stats.TotalInputTokens, + &stats.TotalOutputTokens, + &stats.TotalCacheTokens, + &stats.TotalCost, + &stats.TotalActualCost, + &stats.AverageDurationMs, + ); err != nil { + return nil, err + } + stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens + return &stats, nil +} + +// GetAccountStatsAggregated 使用 SQL 聚合统计账号使用数据 +// +// 性能优化说明: +// 原实现先查询所有日志记录,再在应用层循环计算统计值: +// 1. 需要传输大量数据到应用层 +// 2. 
应用层循环计算增加 CPU 和内存开销 +// +// 新实现使用 SQL 聚合函数: +// 1. 在数据库层完成 COUNT/SUM/AVG 计算 +// 2. 只返回单行聚合结果,大幅减少数据传输量 +// 3. 利用数据库索引优化聚合查询性能 +func (r *usageLogRepository) GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error) { + query := ` + SELECT + COUNT(*) as total_requests, + COALESCE(SUM(input_tokens), 0) as total_input_tokens, + COALESCE(SUM(output_tokens), 0) as total_output_tokens, + COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens, + COALESCE(SUM(total_cost), 0) as total_cost, + COALESCE(SUM(actual_cost), 0) as total_actual_cost, + COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms + FROM usage_logs + WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 + ` + + var stats usagestats.UsageStats + if err := scanSingleRow( + ctx, + r.sql, + query, + []any{accountID, startTime, endTime}, + &stats.TotalRequests, + &stats.TotalInputTokens, + &stats.TotalOutputTokens, + &stats.TotalCacheTokens, + &stats.TotalCost, + &stats.TotalActualCost, + &stats.AverageDurationMs, + ); err != nil { + return nil, err + } + stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens + return &stats, nil +} + +// GetModelStatsAggregated 使用 SQL 聚合统计模型使用数据 +// 性能优化:数据库层聚合计算,避免应用层循环统计 +func (r *usageLogRepository) GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error) { + query := fmt.Sprintf(` + SELECT + COUNT(*) as total_requests, + COALESCE(SUM(input_tokens), 0) as total_input_tokens, + COALESCE(SUM(output_tokens), 0) as total_output_tokens, + COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens, + COALESCE(SUM(total_cost), 0) as total_cost, + COALESCE(SUM(actual_cost), 0) as total_actual_cost, + COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms + FROM usage_logs + WHERE %s = $1 AND created_at >= $2 AND created_at < $3 + `, 
rawUsageLogModelColumn) + + var stats usagestats.UsageStats + if err := scanSingleRow( + ctx, + r.sql, + query, + []any{modelName, startTime, endTime}, + &stats.TotalRequests, + &stats.TotalInputTokens, + &stats.TotalOutputTokens, + &stats.TotalCacheTokens, + &stats.TotalCost, + &stats.TotalActualCost, + &stats.AverageDurationMs, + ); err != nil { + return nil, err + } + stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens + return &stats, nil +} + +// GetDailyStatsAggregated 使用 SQL 聚合统计用户的每日使用数据 +// 性能优化:使用 GROUP BY 在数据库层按日期分组聚合,避免应用层循环分组统计 +func (r *usageLogRepository) GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (result []map[string]any, err error) { + tzName := resolveUsageStatsTimezone() + query := ` + SELECT + -- 使用应用时区分组,避免数据库会话时区导致日边界偏移。 + TO_CHAR(created_at AT TIME ZONE $4, 'YYYY-MM-DD') as date, + COUNT(*) as total_requests, + COALESCE(SUM(input_tokens), 0) as total_input_tokens, + COALESCE(SUM(output_tokens), 0) as total_output_tokens, + COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens, + COALESCE(SUM(total_cost), 0) as total_cost, + COALESCE(SUM(actual_cost), 0) as total_actual_cost, + COALESCE(AVG(COALESCE(duration_ms, 0)), 0) as avg_duration_ms + FROM usage_logs + WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 + GROUP BY 1 + ORDER BY 1 + ` + + rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime, tzName) + if err != nil { + return nil, err + } + defer func() { + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + result = nil + } + }() + + result = make([]map[string]any, 0) + for rows.Next() { + var ( + date string + totalRequests int64 + totalInputTokens int64 + totalOutputTokens int64 + totalCacheTokens int64 + totalCost float64 + totalActualCost float64 + avgDurationMs float64 + ) + if err = rows.Scan( + &date, + &totalRequests, + &totalInputTokens, + &totalOutputTokens, + 
&totalCacheTokens, + &totalCost, + &totalActualCost, + &avgDurationMs, + ); err != nil { + return nil, err + } + result = append(result, map[string]any{ + "date": date, + "total_requests": totalRequests, + "total_input_tokens": totalInputTokens, + "total_output_tokens": totalOutputTokens, + "total_cache_tokens": totalCacheTokens, + "total_tokens": totalInputTokens + totalOutputTokens + totalCacheTokens, + "total_cost": totalCost, + "total_actual_cost": totalActualCost, + "average_duration_ms": avgDurationMs, + }) + } + + if err = rows.Err(); err != nil { + return nil, err + } + + return result, nil +} + +// GetAPIKeyUsageTrend returns usage trend data grouped by API key and date +func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) { + dateFormat := safeDateFormat(granularity) + + query := fmt.Sprintf(` + WITH top_keys AS ( + SELECT api_key_id + FROM usage_logs + WHERE created_at >= $1 AND created_at < $2 + GROUP BY api_key_id + ORDER BY SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) DESC + LIMIT $3 + ) + SELECT + TO_CHAR(u.created_at, '%s') as date, + u.api_key_id, + COALESCE(k.name, '') as key_name, + COUNT(*) as requests, + COALESCE(SUM(u.input_tokens + u.output_tokens + u.cache_creation_tokens + u.cache_read_tokens), 0) as tokens + FROM usage_logs u + LEFT JOIN api_keys k ON u.api_key_id = k.id + WHERE u.api_key_id IN (SELECT api_key_id FROM top_keys) + AND u.created_at >= $4 AND u.created_at < $5 + GROUP BY date, u.api_key_id, k.name + ORDER BY date ASC, tokens DESC + `, dateFormat) + + rows, err := r.sql.QueryContext(ctx, query, startTime, endTime, limit, startTime, endTime) + if err != nil { + return nil, err + } + defer func() { + // 保持主错误优先;仅在无错误时回传 Close 失败。 + // 同时清空返回值,避免误用不完整结果。 + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + results = nil + } + }() + + results = 
make([]APIKeyUsageTrendPoint, 0) + for rows.Next() { + var row APIKeyUsageTrendPoint + if err = rows.Scan(&row.Date, &row.APIKeyID, &row.KeyName, &row.Requests, &row.Tokens); err != nil { + return nil, err + } + results = append(results, row) + } + if err = rows.Err(); err != nil { + return nil, err + } + + return results, nil +} + +// GetUserUsageTrend returns usage trend data grouped by user and date +func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []UserUsageTrendPoint, err error) { + dateFormat := safeDateFormat(granularity) + + query := fmt.Sprintf(` + WITH top_users AS ( + SELECT user_id + FROM usage_logs + WHERE created_at >= $1 AND created_at < $2 + GROUP BY user_id + ORDER BY SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) DESC + LIMIT $3 + ) + SELECT + TO_CHAR(u.created_at, '%s') as date, + u.user_id, + COALESCE(us.email, '') as email, + COALESCE(us.username, '') as username, + COUNT(*) as requests, + COALESCE(SUM(u.input_tokens + u.output_tokens + u.cache_creation_tokens + u.cache_read_tokens), 0) as tokens, + COALESCE(SUM(u.total_cost), 0) as cost, + COALESCE(SUM(u.actual_cost), 0) as actual_cost + FROM usage_logs u + LEFT JOIN users us ON u.user_id = us.id + WHERE u.user_id IN (SELECT user_id FROM top_users) + AND u.created_at >= $4 AND u.created_at < $5 + GROUP BY date, u.user_id, us.email, us.username + ORDER BY date ASC, tokens DESC + `, dateFormat) + + rows, err := r.sql.QueryContext(ctx, query, startTime, endTime, limit, startTime, endTime) + if err != nil { + return nil, err + } + defer func() { + // 保持主错误优先;仅在无错误时回传 Close 失败。 + // 同时清空返回值,避免误用不完整结果。 + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + results = nil + } + }() + + results = make([]UserUsageTrendPoint, 0) + for rows.Next() { + var row UserUsageTrendPoint + if err = rows.Scan(&row.Date, &row.UserID, &row.Email, &row.Username, 
&row.Requests, &row.Tokens, &row.Cost, &row.ActualCost); err != nil { + return nil, err + } + results = append(results, row) + } + if err = rows.Err(); err != nil { + return nil, err + } + + return results, nil +} + +// GetUserSpendingRanking returns user spending ranking aggregated within the time range. +func (r *usageLogRepository) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (result *UserSpendingRankingResponse, err error) { + if limit <= 0 { + limit = 12 + } + + query := ` + WITH user_spend AS ( + SELECT + u.user_id, + COALESCE(us.email, '') as email, + COALESCE(SUM(u.actual_cost), 0) as actual_cost, + COUNT(*) as requests, + COALESCE(SUM(u.input_tokens + u.output_tokens + u.cache_creation_tokens + u.cache_read_tokens), 0) as tokens + FROM usage_logs u + LEFT JOIN users us ON u.user_id = us.id + WHERE u.created_at >= $1 AND u.created_at < $2 + GROUP BY u.user_id, us.email + ), + ranked AS ( + SELECT + user_id, + email, + actual_cost, + requests, + tokens, + COALESCE(SUM(actual_cost) OVER (), 0) as total_actual_cost, + COALESCE(SUM(requests) OVER (), 0) as total_requests, + COALESCE(SUM(tokens) OVER (), 0) as total_tokens + FROM user_spend + ORDER BY actual_cost DESC, tokens DESC, user_id ASC + LIMIT $3 + ) + SELECT + user_id, + email, + actual_cost, + requests, + tokens, + total_actual_cost, + total_requests, + total_tokens + FROM ranked + ORDER BY actual_cost DESC, tokens DESC, user_id ASC + ` + + rows, err := r.sql.QueryContext(ctx, query, startTime, endTime, limit) + if err != nil { + return nil, err + } + defer func() { + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + result = nil + } + }() + + ranking := make([]UserSpendingRankingItem, 0) + totalActualCost := 0.0 + totalRequests := int64(0) + totalTokens := int64(0) + for rows.Next() { + var row UserSpendingRankingItem + if err = rows.Scan(&row.UserID, &row.Email, &row.ActualCost, &row.Requests, &row.Tokens, &totalActualCost, 
&totalRequests, &totalTokens); err != nil { + return nil, err + } + ranking = append(ranking, row) + } + if err = rows.Err(); err != nil { + return nil, err + } + + return &UserSpendingRankingResponse{ + Ranking: ranking, + TotalActualCost: totalActualCost, + TotalRequests: totalRequests, + TotalTokens: totalTokens, + }, nil +} + +// GetUserDashboardStats 获取用户专属的仪表盘统计 +func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID int64) (*UserDashboardStats, error) { + stats := &UserDashboardStats{} + today := timezone.Today() + + // API Key 统计 + if err := scanSingleRow( + ctx, + r.sql, + "SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND deleted_at IS NULL", + []any{userID}, + &stats.TotalAPIKeys, + ); err != nil { + return nil, err + } + if err := scanSingleRow( + ctx, + r.sql, + "SELECT COUNT(*) FROM api_keys WHERE user_id = $1 AND status = $2 AND deleted_at IS NULL", + []any{userID, service.StatusActive}, + &stats.ActiveAPIKeys, + ); err != nil { + return nil, err + } + + // 累计 Token 统计 + totalStatsQuery := ` + SELECT + COUNT(*) as total_requests, + COALESCE(SUM(input_tokens), 0) as total_input_tokens, + COALESCE(SUM(output_tokens), 0) as total_output_tokens, + COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens, + COALESCE(SUM(total_cost), 0) as total_cost, + COALESCE(SUM(actual_cost), 0) as total_actual_cost, + COALESCE(AVG(duration_ms), 0) as avg_duration_ms + FROM usage_logs + WHERE user_id = $1 + ` + if err := scanSingleRow( + ctx, + r.sql, + totalStatsQuery, + []any{userID}, + &stats.TotalRequests, + &stats.TotalInputTokens, + &stats.TotalOutputTokens, + &stats.TotalCacheCreationTokens, + &stats.TotalCacheReadTokens, + &stats.TotalCost, + &stats.TotalActualCost, + &stats.AverageDurationMs, + ); err != nil { + return nil, err + } + stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + 
stats.TotalCacheReadTokens + + // 今日 Token 统计 + todayStatsQuery := ` + SELECT + COUNT(*) as today_requests, + COALESCE(SUM(input_tokens), 0) as today_input_tokens, + COALESCE(SUM(output_tokens), 0) as today_output_tokens, + COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens, + COALESCE(SUM(total_cost), 0) as today_cost, + COALESCE(SUM(actual_cost), 0) as today_actual_cost + FROM usage_logs + WHERE user_id = $1 AND created_at >= $2 + ` + if err := scanSingleRow( + ctx, + r.sql, + todayStatsQuery, + []any{userID, today}, + &stats.TodayRequests, + &stats.TodayInputTokens, + &stats.TodayOutputTokens, + &stats.TodayCacheCreationTokens, + &stats.TodayCacheReadTokens, + &stats.TodayCost, + &stats.TodayActualCost, + ); err != nil { + return nil, err + } + stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens + + // 性能指标:RPM 和 TPM(最近1分钟,仅统计该用户的请求) + rpm, tpm, err := r.getPerformanceStats(ctx, userID) + if err != nil { + return nil, err + } + stats.Rpm = rpm + stats.Tpm = tpm + + return stats, nil +} + +// GetAPIKeyDashboardStats 获取指定 API Key 的仪表盘统计(按 api_key_id 过滤) +func (r *usageLogRepository) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*UserDashboardStats, error) { + stats := &UserDashboardStats{} + today := timezone.Today() + + // API Key 维度不需要统计 key 数量,设为 1 + stats.TotalAPIKeys = 1 + stats.ActiveAPIKeys = 1 + + // 累计 Token 统计 + totalStatsQuery := ` + SELECT + COUNT(*) as total_requests, + COALESCE(SUM(input_tokens), 0) as total_input_tokens, + COALESCE(SUM(output_tokens), 0) as total_output_tokens, + COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens, + COALESCE(SUM(total_cost), 0) as total_cost, + COALESCE(SUM(actual_cost), 0) as total_actual_cost, + COALESCE(AVG(duration_ms), 0) as avg_duration_ms + FROM 
usage_logs + WHERE api_key_id = $1 + ` + if err := scanSingleRow( + ctx, + r.sql, + totalStatsQuery, + []any{apiKeyID}, + &stats.TotalRequests, + &stats.TotalInputTokens, + &stats.TotalOutputTokens, + &stats.TotalCacheCreationTokens, + &stats.TotalCacheReadTokens, + &stats.TotalCost, + &stats.TotalActualCost, + &stats.AverageDurationMs, + ); err != nil { + return nil, err + } + stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens + + // 今日 Token 统计 + todayStatsQuery := ` + SELECT + COUNT(*) as today_requests, + COALESCE(SUM(input_tokens), 0) as today_input_tokens, + COALESCE(SUM(output_tokens), 0) as today_output_tokens, + COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens, + COALESCE(SUM(total_cost), 0) as today_cost, + COALESCE(SUM(actual_cost), 0) as today_actual_cost + FROM usage_logs + WHERE api_key_id = $1 AND created_at >= $2 + ` + if err := scanSingleRow( + ctx, + r.sql, + todayStatsQuery, + []any{apiKeyID, today}, + &stats.TodayRequests, + &stats.TodayInputTokens, + &stats.TodayOutputTokens, + &stats.TodayCacheCreationTokens, + &stats.TodayCacheReadTokens, + &stats.TodayCost, + &stats.TodayActualCost, + ); err != nil { + return nil, err + } + stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens + + // 性能指标:RPM 和 TPM(最近5分钟,按 API Key 过滤) + rpm, tpm, err := r.getPerformanceStatsByAPIKey(ctx, apiKeyID) + if err != nil { + return nil, err + } + stats.Rpm = rpm + stats.Tpm = tpm + + return stats, nil +} + +// GetUserUsageTrendByUserID 获取指定用户的使用趋势 +func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) { + dateFormat := safeDateFormat(granularity) + + query := fmt.Sprintf(` + SELECT + TO_CHAR(created_at, 
'%s') as date, + COUNT(*) as requests, + COALESCE(SUM(input_tokens), 0) as input_tokens, + COALESCE(SUM(output_tokens), 0) as output_tokens, + COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens, + COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens, + COALESCE(SUM(total_cost), 0) as cost, + COALESCE(SUM(actual_cost), 0) as actual_cost + FROM usage_logs + WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 + GROUP BY date + ORDER BY date ASC + `, dateFormat) + + rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime) + if err != nil { + return nil, err + } + defer func() { + // 保持主错误优先;仅在无错误时回传 Close 失败。 + // 同时清空返回值,避免误用不完整结果。 + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + results = nil + } + }() + + results, err = scanTrendRows(rows) + if err != nil { + return nil, err + } + return results, nil +} + +// GetUserModelStats 获取指定用户的模型统计 +func (r *usageLogRepository) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) (results []ModelStat, err error) { + query := ` + SELECT + model, + COUNT(*) as requests, + COALESCE(SUM(input_tokens), 0) as input_tokens, + COALESCE(SUM(output_tokens), 0) as output_tokens, + COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens, + COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens, + COALESCE(SUM(total_cost), 0) as cost, + COALESCE(SUM(actual_cost), 0) as actual_cost, + COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as account_cost + FROM usage_logs + WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 + GROUP BY model + ORDER BY total_tokens DESC + ` + + rows, err := r.sql.QueryContext(ctx, query, userID, startTime, endTime) + if err != 
nil { + return nil, err + } + defer func() { + // 保持主错误优先;仅在无错误时回传 Close 失败。 + // 同时清空返回值,避免误用不完整结果。 + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + results = nil + } + }() + + results, err = scanModelStatsRows(rows) + if err != nil { + return nil, err + } + return results, nil +} + +// GetBatchUserUsageStats gets today and total actual_cost for multiple users within a time range. +// If startTime is zero, defaults to 30 days ago. +func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*BatchUserUsageStats, error) { + result := make(map[int64]*BatchUserUsageStats) + normalizedUserIDs := normalizePositiveInt64IDs(userIDs) + if len(normalizedUserIDs) == 0 { + return result, nil + } + + // 默认最近 30 天 + if startTime.IsZero() { + startTime = time.Now().AddDate(0, 0, -30) + } + if endTime.IsZero() { + endTime = time.Now() + } + + for _, id := range normalizedUserIDs { + result[id] = &BatchUserUsageStats{UserID: id} + } + + query := ` + SELECT + user_id, + COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $2 AND created_at < $3), 0) as total_cost, + COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $4), 0) as today_cost + FROM usage_logs + WHERE user_id = ANY($1) + AND created_at >= LEAST($2, $4) + GROUP BY user_id + ` + today := timezone.Today() + rows, err := r.sql.QueryContext(ctx, query, pq.Array(normalizedUserIDs), startTime, endTime, today) + if err != nil { + return nil, err + } + for rows.Next() { + var userID int64 + var total float64 + var todayTotal float64 + if err := rows.Scan(&userID, &total, &todayTotal); err != nil { + _ = rows.Close() + return nil, err + } + if stats, ok := result[userID]; ok { + stats.TotalActualCost = total + stats.TodayActualCost = todayTotal + } + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + + return result, nil +} + +// 
GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys within a time range. +// If startTime is zero, defaults to 30 days ago. +func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*BatchAPIKeyUsageStats, error) { + result := make(map[int64]*BatchAPIKeyUsageStats) + normalizedAPIKeyIDs := normalizePositiveInt64IDs(apiKeyIDs) + if len(normalizedAPIKeyIDs) == 0 { + return result, nil + } + + // 默认最近 30 天 + if startTime.IsZero() { + startTime = time.Now().AddDate(0, 0, -30) + } + if endTime.IsZero() { + endTime = time.Now() + } + + for _, id := range normalizedAPIKeyIDs { + result[id] = &BatchAPIKeyUsageStats{APIKeyID: id} + } + + query := ` + SELECT + api_key_id, + COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $2 AND created_at < $3), 0) as total_cost, + COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $4), 0) as today_cost + FROM usage_logs + WHERE api_key_id = ANY($1) + AND created_at >= LEAST($2, $4) + GROUP BY api_key_id + ` + today := timezone.Today() + rows, err := r.sql.QueryContext(ctx, query, pq.Array(normalizedAPIKeyIDs), startTime, endTime, today) + if err != nil { + return nil, err + } + for rows.Next() { + var apiKeyID int64 + var total float64 + var todayTotal float64 + if err := rows.Scan(&apiKeyID, &total, &todayTotal); err != nil { + _ = rows.Close() + return nil, err + } + if stats, ok := result[apiKeyID]; ok { + stats.TotalActualCost = total + stats.TodayActualCost = todayTotal + } + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + + return result, nil +} + +// GetUsageTrendWithFilters returns usage trend data with optional filters +func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType 
*int8) (results []TrendDataPoint, err error) { + if shouldUsePreaggregatedTrend(granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType) { + aggregated, aggregatedErr := r.getUsageTrendFromAggregates(ctx, startTime, endTime, granularity) + if aggregatedErr == nil && len(aggregated) > 0 { + return aggregated, nil + } + } + + dateFormat := safeDateFormat(granularity) + + query := fmt.Sprintf(` + SELECT + TO_CHAR(created_at, '%s') as date, + COUNT(*) as requests, + COALESCE(SUM(input_tokens), 0) as input_tokens, + COALESCE(SUM(output_tokens), 0) as output_tokens, + COALESCE(SUM(cache_creation_tokens), 0) as cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) as cache_read_tokens, + COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens, + COALESCE(SUM(total_cost), 0) as cost, + COALESCE(SUM(actual_cost), 0) as actual_cost + FROM usage_logs + WHERE created_at >= $1 AND created_at < $2 + `, dateFormat) + + args := []any{startTime, endTime} + if userID > 0 { + query += fmt.Sprintf(" AND user_id = $%d", len(args)+1) + args = append(args, userID) + } + if apiKeyID > 0 { + query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1) + args = append(args, apiKeyID) + } + if accountID > 0 { + query += fmt.Sprintf(" AND account_id = $%d", len(args)+1) + args = append(args, accountID) + } + if groupID > 0 { + query += fmt.Sprintf(" AND group_id = $%d", len(args)+1) + args = append(args, groupID) + } + query, args = appendRawUsageLogModelQueryFilter(query, args, model) + query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream) + if billingType != nil { + query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) + args = append(args, int16(*billingType)) + } + query += " GROUP BY date ORDER BY date ASC" + + rows, err := r.sql.QueryContext(ctx, query, args...) 
+ if err != nil { + return nil, err + } + defer func() { + // 保持主错误优先;仅在无错误时回传 Close 失败。 + // 同时清空返回值,避免误用不完整结果。 + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + results = nil + } + }() + + results, err = scanTrendRows(rows) + if err != nil { + return nil, err + } + return results, nil +} + +// GetModelStatsWithFilters returns model statistics with optional filters +func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) { + return r.getModelStatsWithFiltersBySource(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType, usagestats.ModelSourceRequested) +} + +// GetGroupStatsWithFilters returns group usage statistics with optional filters +func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []usagestats.GroupStat, err error) { + query := ` + SELECT + COALESCE(ul.group_id, 0) as group_id, + COALESCE(g.name, '') as group_name, + COUNT(*) as requests, + COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens, + COALESCE(SUM(ul.total_cost), 0) as cost, + COALESCE(SUM(ul.actual_cost), 0) as actual_cost, + COALESCE(SUM(COALESCE(ul.account_stats_cost, ul.total_cost) * COALESCE(ul.account_rate_multiplier, 1)), 0) as account_cost + FROM usage_logs ul + LEFT JOIN groups g ON g.id = ul.group_id + WHERE ul.created_at >= $1 AND ul.created_at < $2 + ` + + args := []any{startTime, endTime} + if userID > 0 { + query += fmt.Sprintf(" AND ul.user_id = $%d", len(args)+1) + args = append(args, userID) + } + if apiKeyID > 0 { + query += fmt.Sprintf(" AND ul.api_key_id = $%d", len(args)+1) + args = append(args, apiKeyID) 
+ } + if accountID > 0 { + query += fmt.Sprintf(" AND ul.account_id = $%d", len(args)+1) + args = append(args, accountID) + } + if groupID > 0 { + query += fmt.Sprintf(" AND ul.group_id = $%d", len(args)+1) + args = append(args, groupID) + } + query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream) + if billingType != nil { + query += fmt.Sprintf(" AND ul.billing_type = $%d", len(args)+1) + args = append(args, int16(*billingType)) + } + query += " GROUP BY ul.group_id, g.name ORDER BY total_tokens DESC" + + rows, err := r.sql.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer func() { + if closeErr := rows.Close(); closeErr != nil && err == nil { + err = closeErr + results = nil + } + }() + + results = make([]usagestats.GroupStat, 0) + for rows.Next() { + var row usagestats.GroupStat + if err := rows.Scan( + &row.GroupID, + &row.GroupName, + &row.Requests, + &row.TotalTokens, + &row.Cost, + &row.ActualCost, + &row.AccountCost, + ); err != nil { + return nil, err + } + results = append(results, row) + } + if err := rows.Err(); err != nil { + return nil, err + } + return results, nil +} + +// GetUserBreakdownStats returns per-user usage breakdown within a specific dimension. 
// GetUserBreakdownStats returns per-user usage aggregates (request count, total
// tokens, cost, actual_cost and account_cost) within the half-open window
// [startTime, endTime), optionally narrowed by the dimension filters in dim.
// Rows are ordered by actual_cost descending; limit > 0 caps the row count.
//
// NOTE(review): dim.Model / dim.Endpoint are translated into SQL column
// expressions by resolveModelDimensionExpression / resolveEndpointColumn
// (defined elsewhere in this package); presumably those return fixed, trusted
// expressions rather than user input — verify, since they are interpolated
// directly into the query text (the filter *values* themselves are bound via
// positional placeholders and are safe).
func (r *usageLogRepository) GetUserBreakdownStats(ctx context.Context, startTime, endTime time.Time, dim usagestats.UserBreakdownDimension, limit int) (results []usagestats.UserBreakdownItem, err error) {
	// Base aggregation. LEFT JOIN users so usage rows with a missing user
	// still appear; COALESCE maps them to user_id = 0 / email = ''.
	query := `
	SELECT
		COALESCE(ul.user_id, 0) as user_id,
		COALESCE(u.email, '') as email,
		COUNT(*) as requests,
		COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens,
		COALESCE(SUM(ul.total_cost), 0) as cost,
		COALESCE(SUM(ul.actual_cost), 0) as actual_cost,
		COALESCE(SUM(COALESCE(ul.account_stats_cost, ul.total_cost) * COALESCE(ul.account_rate_multiplier, 1)), 0) as account_cost
	FROM usage_logs ul
	LEFT JOIN users u ON u.id = ul.user_id
	WHERE ul.created_at >= $1 AND ul.created_at < $2
	`
	args := []any{startTime, endTime}

	// Each optional filter appends one "AND ..." clause using the next
	// positional placeholder; len(args)+1 keeps the $N index and the args
	// slice in lockstep, so the clauses below must always append to both.
	if dim.GroupID > 0 {
		query += fmt.Sprintf(" AND ul.group_id = $%d", len(args)+1)
		args = append(args, dim.GroupID)
	}
	if dim.Model != "" {
		query += fmt.Sprintf(" AND %s = $%d", resolveModelDimensionExpression(dim.ModelType), len(args)+1)
		args = append(args, dim.Model)
	}
	if dim.Endpoint != "" {
		col := resolveEndpointColumn(dim.EndpointType)
		query += fmt.Sprintf(" AND %s = $%d", col, len(args)+1)
		args = append(args, dim.Endpoint)
	}
	if dim.UserID > 0 {
		query += fmt.Sprintf(" AND ul.user_id = $%d", len(args)+1)
		args = append(args, dim.UserID)
	}
	if dim.APIKeyID > 0 {
		query += fmt.Sprintf(" AND ul.api_key_id = $%d", len(args)+1)
		args = append(args, dim.APIKeyID)
	}
	if dim.AccountID > 0 {
		query += fmt.Sprintf(" AND ul.account_id = $%d", len(args)+1)
		args = append(args, dim.AccountID)
	}
	// Pointer filters distinguish "not provided" (nil) from a zero value.
	if dim.RequestType != nil {
		query += fmt.Sprintf(" AND ul.request_type = $%d", len(args)+1)
		args = append(args, *dim.RequestType)
	}
	if dim.Stream != nil {
		query += fmt.Sprintf(" AND ul.stream = $%d", len(args)+1)
		args = append(args, *dim.Stream)
	}
	if dim.BillingType != nil {
		query += fmt.Sprintf(" AND ul.billing_type = $%d", len(args)+1)
		args = append(args, *dim.BillingType)
	}

	query += " GROUP BY ul.user_id, u.email ORDER BY actual_cost DESC"
	if limit > 0 {
		// limit is an int formatted as a literal — no injection risk, and
		// LIMIT cannot be parameterized uniformly across drivers anyway.
		query += fmt.Sprintf(" LIMIT %d", limit)
	}

	rows, err := r.sql.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, err
	}
	defer func() {
		// Propagate Close failure only when no earlier error exists, and
		// clear results so callers never consume a possibly-truncated set.
		if closeErr := rows.Close(); closeErr != nil && err == nil {
			err = closeErr
			results = nil
		}
	}()

	// Non-nil empty slice so an empty result serializes as a list, not null.
	results = make([]usagestats.UserBreakdownItem, 0)
	for rows.Next() {
		var row usagestats.UserBreakdownItem
		// Scan order must mirror the SELECT column order above.
		if err := rows.Scan(
			&row.UserID,
			&row.Email,
			&row.Requests,
			&row.TotalTokens,
			&row.Cost,
			&row.ActualCost,
			&row.AccountCost,
		); err != nil {
			return nil, err
		}
		results = append(results, row)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return results, nil
}

// GetAllGroupUsageSummary returns today's and cumulative actual_cost for every group.
// todayStart is the start-of-day in the caller's timezone (UTC-based).
// TODO(perf): This query scans ALL usage_logs rows for total_cost aggregation.
// When usage_logs exceeds ~1M rows, consider adding a short-lived cache (30s)
// or a materialized view / pre-aggregation table for cumulative costs.
+func (r *usageLogRepository) GetAllGroupUsageSummary(ctx context.Context, todayStart time.Time) ([]usagestats.GroupUsageSummary, error) { + query := ` + SELECT + g.id AS group_id, + COALESCE(SUM(ul.actual_cost), 0) AS total_cost, + COALESCE(SUM(CASE WHEN ul.created_at >= $1 THEN ul.actual_cost ELSE 0 END), 0) AS today_cost + FROM groups g + LEFT JOIN usage_logs ul ON ul.group_id = g.id + GROUP BY g.id + ` + + rows, err := r.sql.QueryContext(ctx, query, todayStart) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + var results []usagestats.GroupUsageSummary + for rows.Next() { + var row usagestats.GroupUsageSummary + if err := rows.Scan(&row.GroupID, &row.TotalCost, &row.TodayCost); err != nil { + return nil, err + } + results = append(results, row) + } + if err := rows.Err(); err != nil { + return nil, err + } + return results, nil +} diff --git a/backend/internal/repository/usage_log_write.go b/backend/internal/repository/usage_log_write.go new file mode 100644 index 00000000000..efe1f42e802 --- /dev/null +++ b/backend/internal/repository/usage_log_write.go @@ -0,0 +1,1074 @@ +package repository + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "strconv" + "strings" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +type usageLogCreateRequest struct { + log *service.UsageLog + prepared usageLogInsertPrepared + shared *usageLogCreateShared + resultCh chan usageLogCreateResult +} + +type usageLogCreateResult struct { + inserted bool + err error +} + +type usageLogBestEffortRequest struct { + prepared usageLogInsertPrepared + apiKeyID int64 + resultCh chan error +} + +func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) (bool, error) { + if log == nil { + return false, nil + } + + if tx := dbent.TxFromContext(ctx); tx != nil { + return r.createSingle(ctx, tx.Client(), log) + 
} + requestID := strings.TrimSpace(log.RequestID) + if requestID == "" { + return r.createSingle(ctx, r.sql, log) + } + log.RequestID = requestID + return r.createBatched(ctx, log) +} + +func (r *usageLogRepository) CreateBestEffort(ctx context.Context, log *service.UsageLog) error { + if log == nil { + return nil + } + + if tx := dbent.TxFromContext(ctx); tx != nil { + _, err := r.createSingle(ctx, tx.Client(), log) + return err + } + if r.db == nil { + _, err := r.createSingle(ctx, r.sql, log) + return err + } + + r.ensureBestEffortBatcher() + if r.bestEffortBatchCh == nil { + _, err := r.createSingle(ctx, r.sql, log) + return err + } + + req := usageLogBestEffortRequest{ + prepared: prepareUsageLogInsert(log), + apiKeyID: log.APIKeyID, + resultCh: make(chan error, 1), + } + if key, ok := r.bestEffortRecentKey(req.prepared.requestID, req.apiKeyID); ok { + if _, exists := r.bestEffortRecent.Get(key); exists { + return nil + } + } + + select { + case r.bestEffortBatchCh <- req: + case <-ctx.Done(): + return service.MarkUsageLogCreateDropped(ctx.Err()) + default: + return service.MarkUsageLogCreateDropped(errors.New("usage log best-effort queue full")) + } + + select { + case err := <-req.resultCh: + return err + case <-ctx.Done(): + return service.MarkUsageLogCreateDropped(ctx.Err()) + } +} + +func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, log *service.UsageLog) (bool, error) { + prepared := prepareUsageLogInsert(log) + if sqlq == nil { + sqlq = r.sql + } + if ctx != nil && ctx.Err() != nil { + return false, service.MarkUsageLogCreateNotPersisted(ctx.Err()) + } + + query := ` + INSERT INTO usage_logs ( + user_id, + api_key_id, + account_id, + request_id, + model, + requested_model, + upstream_model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + image_output_tokens, + image_output_cost, + input_cost, + 
output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + service_tier, + reasoning_effort, + inbound_endpoint, + upstream_endpoint, + cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, + billing_mode, + account_stats_cost, + created_at + ) VALUES ( + $1, $2, $3, $4, $5, $6, $7, + $8, $9, + $10, $11, $12, $13, + $14, $15, $16, $17, + $18, $19, $20, $21, $22, $23, + $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46 + ) + ON CONFLICT (request_id, api_key_id) DO NOTHING + RETURNING id, created_at + ` + + if err := scanSingleRow(ctx, sqlq, query, prepared.args, &log.ID, &log.CreatedAt); err != nil { + if errors.Is(err, sql.ErrNoRows) && prepared.requestID != "" { + selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2" + if err := scanSingleRow(ctx, sqlq, selectQuery, []any{prepared.requestID, log.APIKeyID}, &log.ID, &log.CreatedAt); err != nil { + return false, err + } + log.RateMultiplier = prepared.rateMultiplier + return false, nil + } else { + return false, err + } + } + log.RateMultiplier = prepared.rateMultiplier + return true, nil +} + +func (r *usageLogRepository) createBatched(ctx context.Context, log *service.UsageLog) (bool, error) { + if r.db == nil { + return r.createSingle(ctx, r.sql, log) + } + r.ensureCreateBatcher() + if r.createBatchCh == nil { + return r.createSingle(ctx, r.sql, log) + } + + req := usageLogCreateRequest{ + log: log, + prepared: prepareUsageLogInsert(log), + shared: &usageLogCreateShared{}, + resultCh: make(chan usageLogCreateResult, 1), + } + + select { + case r.createBatchCh <- req: + case <-ctx.Done(): + return false, service.MarkUsageLogCreateNotPersisted(ctx.Err()) + default: + 
return false, service.MarkUsageLogCreateNotPersisted(errors.New("usage log create batch queue full")) + } + + select { + case res := <-req.resultCh: + return res.inserted, res.err + case <-ctx.Done(): + if req.shared != nil && req.shared.state.CompareAndSwap(usageLogCreateStateQueued, usageLogCreateStateCanceled) { + return false, service.MarkUsageLogCreateNotPersisted(ctx.Err()) + } + timer := time.NewTimer(usageLogCreateCancelWait) + defer timer.Stop() + select { + case res := <-req.resultCh: + return res.inserted, res.err + case <-timer.C: + return false, ctx.Err() + } + } +} + +func (r *usageLogRepository) ensureCreateBatcher() { + if r == nil || r.db == nil || r.createBatchCh != nil { + return + } + r.createBatchOnce.Do(func() { + r.createBatchCh = make(chan usageLogCreateRequest, usageLogCreateBatchQueueCap) + go r.runCreateBatcher(r.db) + }) +} + +func (r *usageLogRepository) ensureBestEffortBatcher() { + if r == nil || r.db == nil || r.bestEffortBatchCh != nil { + return + } + r.bestEffortBatchOnce.Do(func() { + r.bestEffortBatchCh = make(chan usageLogBestEffortRequest, usageLogBestEffortBatchQueueCap) + go r.runBestEffortBatcher(r.db) + }) +} + +func (r *usageLogRepository) runCreateBatcher(db *sql.DB) { + for { + first, ok := <-r.createBatchCh + if !ok { + return + } + + batch := make([]usageLogCreateRequest, 0, usageLogCreateBatchMaxSize) + batch = append(batch, first) + + timer := time.NewTimer(usageLogCreateBatchWindow) + batchLoop: + for len(batch) < usageLogCreateBatchMaxSize { + select { + case req, ok := <-r.createBatchCh: + if !ok { + break batchLoop + } + batch = append(batch, req) + case <-timer.C: + break batchLoop + } + } + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + + r.flushCreateBatch(db, batch) + } +} + +func (r *usageLogRepository) runBestEffortBatcher(db *sql.DB) { + for { + first, ok := <-r.bestEffortBatchCh + if !ok { + return + } + + batch := make([]usageLogBestEffortRequest, 0, usageLogBestEffortBatchMaxSize) 
+ batch = append(batch, first) + + timer := time.NewTimer(usageLogBestEffortBatchWindow) + bestEffortLoop: + for len(batch) < usageLogBestEffortBatchMaxSize { + select { + case req, ok := <-r.bestEffortBatchCh: + if !ok { + break bestEffortLoop + } + batch = append(batch, req) + case <-timer.C: + break bestEffortLoop + } + } + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + + r.flushBestEffortBatch(db, batch) + } +} + +func (r *usageLogRepository) flushCreateBatch(db *sql.DB, batch []usageLogCreateRequest) { + if len(batch) == 0 { + return + } + + uniqueOrder := make([]string, 0, len(batch)) + preparedByKey := make(map[string]usageLogInsertPrepared, len(batch)) + requestsByKey := make(map[string][]usageLogCreateRequest, len(batch)) + fallback := make([]usageLogCreateRequest, 0) + + for _, req := range batch { + if req.log == nil { + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil}) + continue + } + if req.shared != nil && !req.shared.state.CompareAndSwap(usageLogCreateStateQueued, usageLogCreateStateProcessing) { + if req.shared.state.Load() == usageLogCreateStateCanceled { + completeUsageLogCreateRequest(req, usageLogCreateResult{ + inserted: false, + err: service.MarkUsageLogCreateNotPersisted(context.Canceled), + }) + continue + } + } + prepared := req.prepared + if prepared.requestID == "" { + fallback = append(fallback, req) + continue + } + key := usageLogBatchKey(prepared.requestID, req.log.APIKeyID) + if _, exists := requestsByKey[key]; !exists { + uniqueOrder = append(uniqueOrder, key) + preparedByKey[key] = prepared + } + requestsByKey[key] = append(requestsByKey[key], req) + } + + if len(uniqueOrder) > 0 { + insertedMap, stateMap, safeFallback, err := r.batchInsertUsageLogs(db, uniqueOrder, preparedByKey) + if err != nil { + if safeFallback { + for _, key := range uniqueOrder { + fallback = append(fallback, requestsByKey[key]...) 
+ } + } else { + for _, key := range uniqueOrder { + reqs := requestsByKey[key] + state, hasState := stateMap[key] + inserted := insertedMap[key] + for idx, req := range reqs { + req.log.RateMultiplier = preparedByKey[key].rateMultiplier + if hasState { + req.log.ID = state.ID + req.log.CreatedAt = state.CreatedAt + } + switch { + case inserted && idx == 0: + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: true, err: nil}) + case inserted: + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil}) + case hasState: + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil}) + case idx == 0: + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: err}) + default: + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: false, err: nil}) + } + } + } + } + } else { + for _, key := range uniqueOrder { + reqs := requestsByKey[key] + state, ok := stateMap[key] + if !ok { + for _, req := range reqs { + completeUsageLogCreateRequest(req, usageLogCreateResult{ + inserted: false, + err: fmt.Errorf("usage log batch state missing for key=%s", key), + }) + } + continue + } + for idx, req := range reqs { + req.log.ID = state.ID + req.log.CreatedAt = state.CreatedAt + req.log.RateMultiplier = preparedByKey[key].rateMultiplier + completeUsageLogCreateRequest(req, usageLogCreateResult{ + inserted: idx == 0 && insertedMap[key], + err: nil, + }) + } + } + } + } + + if len(fallback) == 0 { + return + } + + for _, req := range fallback { + fallbackCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + inserted, err := r.createSingle(fallbackCtx, db, req.log) + cancel() + completeUsageLogCreateRequest(req, usageLogCreateResult{inserted: inserted, err: err}) + } +} + +func (r *usageLogRepository) flushBestEffortBatch(db *sql.DB, batch []usageLogBestEffortRequest) { + if len(batch) == 0 { + return + } + + type bestEffortGroup struct { + prepared 
usageLogInsertPrepared + apiKeyID int64 + key string + reqs []usageLogBestEffortRequest + } + + groupsByKey := make(map[string]*bestEffortGroup, len(batch)) + groupOrder := make([]*bestEffortGroup, 0, len(batch)) + preparedList := make([]usageLogInsertPrepared, 0, len(batch)) + + for idx, req := range batch { + prepared := req.prepared + key := fmt.Sprintf("__best_effort_%d", idx) + if prepared.requestID != "" { + key = usageLogBatchKey(prepared.requestID, req.apiKeyID) + } + group, exists := groupsByKey[key] + if !exists { + group = &bestEffortGroup{ + prepared: prepared, + apiKeyID: req.apiKeyID, + key: key, + } + groupsByKey[key] = group + groupOrder = append(groupOrder, group) + preparedList = append(preparedList, prepared) + } + group.reqs = append(group.reqs, req) + } + + if len(preparedList) == 0 { + for _, req := range batch { + sendUsageLogBestEffortResult(req.resultCh, nil) + } + return + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + query, args := buildUsageLogBestEffortInsertQuery(preparedList) + if _, err := db.ExecContext(ctx, query, args...); err != nil { + logger.LegacyPrintf("repository.usage_log", "best-effort batch insert failed: %v", err) + for _, group := range groupOrder { + singleErr := execUsageLogInsertNoResult(ctx, db, group.prepared) + if singleErr != nil { + logger.LegacyPrintf("repository.usage_log", "best-effort single fallback insert failed: %v", singleErr) + } else if group.prepared.requestID != "" && r != nil && r.bestEffortRecent != nil { + r.bestEffortRecent.SetDefault(group.key, struct{}{}) + } + for _, req := range group.reqs { + sendUsageLogBestEffortResult(req.resultCh, singleErr) + } + } + return + } + for _, group := range groupOrder { + if group.prepared.requestID != "" && r != nil && r.bestEffortRecent != nil { + r.bestEffortRecent.SetDefault(group.key, struct{}{}) + } + for _, req := range group.reqs { + sendUsageLogBestEffortResult(req.resultCh, nil) + } + } +} + 
+func (r *usageLogRepository) batchInsertUsageLogs(db *sql.DB, keys []string, preparedByKey map[string]usageLogInsertPrepared) (map[string]bool, map[string]usageLogBatchState, bool, error) { + if len(keys) == 0 { + return map[string]bool{}, map[string]usageLogBatchState{}, false, nil + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + query, args := buildUsageLogBatchInsertQuery(keys, preparedByKey) + var payload []byte + if err := db.QueryRowContext(ctx, query, args...).Scan(&payload); err != nil { + return nil, nil, true, err + } + var rows []usageLogBatchRow + if err := json.Unmarshal(payload, &rows); err != nil { + return nil, nil, false, err + } + insertedMap := make(map[string]bool, len(keys)) + stateMap := make(map[string]usageLogBatchState, len(keys)) + for _, row := range rows { + key := usageLogBatchKey(row.RequestID, row.APIKeyID) + insertedMap[key] = row.Inserted + stateMap[key] = usageLogBatchState{ + ID: row.ID, + CreatedAt: row.CreatedAt, + } + } + if len(stateMap) != len(keys) { + return insertedMap, stateMap, false, fmt.Errorf("usage log batch state count mismatch: got=%d want=%d", len(stateMap), len(keys)) + } + return insertedMap, stateMap, false, nil +} + +func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usageLogInsertPrepared) (string, []any) { + var query strings.Builder + _, _ = query.WriteString(` + WITH input ( + input_idx, + user_id, + api_key_id, + account_id, + request_id, + model, + requested_model, + upstream_model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + image_output_tokens, + image_output_cost, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + 
user_agent, + ip_address, + image_count, + image_size, + service_tier, + reasoning_effort, + inbound_endpoint, + upstream_endpoint, + cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, + billing_mode, + account_stats_cost, + created_at + ) AS (VALUES `) + + args := make([]any, 0, len(keys)*46) + argPos := 1 + for idx, key := range keys { + if idx > 0 { + _, _ = query.WriteString(",") + } + _, _ = query.WriteString("(") + _, _ = query.WriteString("$") + _, _ = query.WriteString(strconv.Itoa(argPos)) + args = append(args, idx) + argPos++ + prepared := preparedByKey[key] + for i := 0; i < len(prepared.args); i++ { + _, _ = query.WriteString(",") + _, _ = query.WriteString("$") + _, _ = query.WriteString(strconv.Itoa(argPos)) + if i < len(usageLogInsertArgTypes) { + _, _ = query.WriteString("::") + _, _ = query.WriteString(usageLogInsertArgTypes[i]) + } + argPos++ + } + _, _ = query.WriteString(")") + args = append(args, prepared.args...) + } + _, _ = query.WriteString(` + ), + inserted AS ( + INSERT INTO usage_logs ( + user_id, + api_key_id, + account_id, + request_id, + model, + requested_model, + upstream_model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + image_output_tokens, + image_output_cost, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + service_tier, + reasoning_effort, + inbound_endpoint, + upstream_endpoint, + cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, + billing_mode, + account_stats_cost, + created_at + ) + SELECT + user_id, + api_key_id, + account_id, + request_id, + model, + requested_model, + upstream_model, + group_id, + 
subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + image_output_tokens, + image_output_cost, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + service_tier, + reasoning_effort, + inbound_endpoint, + upstream_endpoint, + cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, + billing_mode, + account_stats_cost, + created_at + FROM input + ON CONFLICT (request_id, api_key_id) DO NOTHING + RETURNING request_id, api_key_id, id, created_at + ), + resolved AS ( + SELECT + input.input_idx, + input.request_id, + input.api_key_id, + COALESCE(inserted.id, existing.id) AS id, + COALESCE(inserted.created_at, existing.created_at) AS created_at, + (inserted.id IS NOT NULL) AS inserted + FROM input + LEFT JOIN inserted + ON inserted.request_id = input.request_id + AND inserted.api_key_id = input.api_key_id + LEFT JOIN usage_logs existing + ON existing.request_id = input.request_id + AND existing.api_key_id = input.api_key_id + ) + SELECT COALESCE( + json_agg( + json_build_object( + 'request_id', resolved.request_id, + 'api_key_id', resolved.api_key_id, + 'id', resolved.id, + 'created_at', resolved.created_at, + 'inserted', resolved.inserted + ) + ORDER BY resolved.input_idx + ), + '[]'::json + ) + FROM resolved + `) + return query.String(), args +} + +func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) (string, []any) { + var query strings.Builder + _, _ = query.WriteString(` + WITH input ( + user_id, + api_key_id, + account_id, + request_id, + model, + requested_model, + upstream_model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + 
cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + image_output_tokens, + image_output_cost, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + service_tier, + reasoning_effort, + inbound_endpoint, + upstream_endpoint, + cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, + billing_mode, + account_stats_cost, + created_at + ) AS (VALUES `) + + args := make([]any, 0, len(preparedList)*46) + argPos := 1 + for idx, prepared := range preparedList { + if idx > 0 { + _, _ = query.WriteString(",") + } + _, _ = query.WriteString("(") + for i := 0; i < len(prepared.args); i++ { + if i > 0 { + _, _ = query.WriteString(",") + } + _, _ = query.WriteString("$") + _, _ = query.WriteString(strconv.Itoa(argPos)) + if i < len(usageLogInsertArgTypes) { + _, _ = query.WriteString("::") + _, _ = query.WriteString(usageLogInsertArgTypes[i]) + } + argPos++ + } + _, _ = query.WriteString(")") + args = append(args, prepared.args...) 
+ } + + _, _ = query.WriteString(` + ) + INSERT INTO usage_logs ( + user_id, + api_key_id, + account_id, + request_id, + model, + requested_model, + upstream_model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + image_output_tokens, + image_output_cost, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + service_tier, + reasoning_effort, + inbound_endpoint, + upstream_endpoint, + cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, + billing_mode, + account_stats_cost, + created_at + ) + SELECT + user_id, + api_key_id, + account_id, + request_id, + model, + requested_model, + upstream_model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + image_output_tokens, + image_output_cost, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + service_tier, + reasoning_effort, + inbound_endpoint, + upstream_endpoint, + cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, + billing_mode, + account_stats_cost, + created_at + FROM input + ON CONFLICT (request_id, api_key_id) DO NOTHING + `) + + return query.String(), args +} + +func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared usageLogInsertPrepared) error { + _, err := sqlq.ExecContext(ctx, ` + INSERT INTO usage_logs ( + user_id, + 
api_key_id, + account_id, + request_id, + model, + requested_model, + upstream_model, + group_id, + subscription_id, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + cache_creation_5m_tokens, + cache_creation_1h_tokens, + image_output_tokens, + image_output_cost, + input_cost, + output_cost, + cache_creation_cost, + cache_read_cost, + total_cost, + actual_cost, + rate_multiplier, + account_rate_multiplier, + billing_type, + request_type, + stream, + openai_ws_mode, + duration_ms, + first_token_ms, + user_agent, + ip_address, + image_count, + image_size, + service_tier, + reasoning_effort, + inbound_endpoint, + upstream_endpoint, + cache_ttl_overridden, + channel_id, + model_mapping_chain, + billing_tier, + billing_mode, + account_stats_cost, + created_at + ) VALUES ( + $1, $2, $3, $4, $5, $6, $7, + $8, $9, + $10, $11, $12, $13, + $14, $15, $16, $17, + $18, $19, $20, $21, $22, $23, + $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46 + ) + ON CONFLICT (request_id, api_key_id) DO NOTHING + `, prepared.args...) 
+ return err +} + +func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { + createdAt := log.CreatedAt + if createdAt.IsZero() { + createdAt = time.Now() + } + + requestID := strings.TrimSpace(log.RequestID) + log.RequestID = requestID + + rateMultiplier := log.RateMultiplier + log.SyncRequestTypeAndLegacyFields() + requestType := int16(log.RequestType) + + groupID := nullInt64(log.GroupID) + subscriptionID := nullInt64(log.SubscriptionID) + duration := nullInt(log.DurationMs) + firstToken := nullInt(log.FirstTokenMs) + userAgent := nullString(log.UserAgent) + ipAddress := nullString(log.IPAddress) + imageSize := nullString(log.ImageSize) + serviceTier := nullString(log.ServiceTier) + reasoningEffort := nullString(log.ReasoningEffort) + inboundEndpoint := nullString(log.InboundEndpoint) + upstreamEndpoint := nullString(log.UpstreamEndpoint) + channelID := nullInt64(log.ChannelID) + modelMappingChain := nullString(log.ModelMappingChain) + billingTier := nullString(log.BillingTier) + billingMode := nullString(log.BillingMode) + requestedModel := strings.TrimSpace(log.RequestedModel) + if requestedModel == "" { + requestedModel = strings.TrimSpace(log.Model) + } + upstreamModel := nullString(log.UpstreamModel) + + var requestIDArg any + if requestID != "" { + requestIDArg = requestID + } + + return usageLogInsertPrepared{ + createdAt: createdAt, + requestID: requestID, + rateMultiplier: rateMultiplier, + requestType: requestType, + args: []any{ + log.UserID, + log.APIKeyID, + log.AccountID, + requestIDArg, + log.Model, + nullString(&requestedModel), + upstreamModel, + groupID, + subscriptionID, + log.InputTokens, + log.OutputTokens, + log.CacheCreationTokens, + log.CacheReadTokens, + log.CacheCreation5mTokens, + log.CacheCreation1hTokens, + log.ImageOutputTokens, + log.ImageOutputCost, + log.InputCost, + log.OutputCost, + log.CacheCreationCost, + log.CacheReadCost, + log.TotalCost, + log.ActualCost, + rateMultiplier, + log.AccountRateMultiplier, + 
log.BillingType, + requestType, + log.Stream, + log.OpenAIWSMode, + duration, + firstToken, + userAgent, + ipAddress, + log.ImageCount, + imageSize, + serviceTier, + reasoningEffort, + inboundEndpoint, + upstreamEndpoint, + log.CacheTTLOverridden, + channelID, + modelMappingChain, + billingTier, + billingMode, + log.AccountStatsCost, // account_stats_cost + createdAt, + }, + } +} diff --git a/backend/internal/service/admin_dashboard.go b/backend/internal/service/admin_dashboard.go new file mode 100644 index 00000000000..aa6f812801e --- /dev/null +++ b/backend/internal/service/admin_dashboard.go @@ -0,0 +1,26 @@ +package service + +// UserRPMStatus describes a user's current per-minute RPM usage. +type UserRPMStatus struct { + UserRPMUsed int `json:"user_rpm_used"` + UserRPMLimit int `json:"user_rpm_limit"` + PerGroup []UserGroupRPMStatus `json:"per_group"` +} + +// UserGroupRPMStatus describes current per-minute RPM usage for one user/group pair. +type UserGroupRPMStatus struct { + GroupID int64 `json:"group_id"` + GroupName string `json:"group_name"` + Used int `json:"used"` + Limit int `json:"limit"` + Source string `json:"source"` // "group" | "override" +} + +// BulkUpdateAccountsResult is the aggregated response for bulk updates. 
+type BulkUpdateAccountsResult struct { + Success int `json:"success"` + Failed int `json:"failed"` + SuccessIDs []int64 `json:"success_ids"` + FailedIDs []int64 `json:"failed_ids"` + Results []BulkUpdateAccountResult `json:"results"` +} diff --git a/backend/internal/service/admin_group.go b/backend/internal/service/admin_group.go new file mode 100644 index 00000000000..2d041c456d9 --- /dev/null +++ b/backend/internal/service/admin_group.go @@ -0,0 +1,729 @@ +package service + +import ( + "context" + "errors" + "fmt" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" +) + +type CreateGroupInput struct { + Name string + Description string + Platform string + RateMultiplier float64 + IsExclusive bool + SubscriptionType string // standard/subscription + DailyLimitUSD *float64 // 日限额 (USD) + WeeklyLimitUSD *float64 // 周限额 (USD) + MonthlyLimitUSD *float64 // 月限额 (USD) + // 图片生成计费配置(仅 antigravity 平台使用) + AllowImageGeneration bool + ImageRateIndependent bool + ImageRateMultiplier *float64 + ImagePrice1K *float64 + ImagePrice2K *float64 + ImagePrice4K *float64 + ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 + FallbackGroupID *int64 // 降级分组 ID + // 无效请求兜底分组 ID(仅 anthropic 平台使用) + FallbackGroupIDOnInvalidRequest *int64 + // 模型路由配置(仅 anthropic 平台使用) + ModelRouting map[string][]int64 + ModelRoutingEnabled bool // 是否启用模型路由 + MCPXMLInject *bool + // 支持的模型系列(仅 antigravity 平台使用) + SupportedModelScopes []string + // OpenAI Messages 调度配置(仅 openai 平台使用) + AllowMessagesDispatch bool + DefaultMappedModel string + RequireOAuthOnly bool + RequirePrivacySet bool + MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig + // RPMLimit 分组 RPM 上限(0 = 不限制) + RPMLimit int + // 从指定分组复制账号(创建分组后在同一事务内绑定) + CopyAccountsFromGroupIDs []int64 +} + +type UpdateGroupInput struct { + Name string + Description string + Platform 
string + RateMultiplier *float64 // 使用指针以支持设置为0 + IsExclusive *bool + Status string + SubscriptionType string // standard/subscription + DailyLimitUSD *float64 // 日限额 (USD) + WeeklyLimitUSD *float64 // 周限额 (USD) + MonthlyLimitUSD *float64 // 月限额 (USD) + // 图片生成计费配置(仅 antigravity 平台使用) + AllowImageGeneration *bool + ImageRateIndependent *bool + ImageRateMultiplier *float64 + ImagePrice1K *float64 + ImagePrice2K *float64 + ImagePrice4K *float64 + ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 + FallbackGroupID *int64 // 降级分组 ID + // 无效请求兜底分组 ID(仅 anthropic 平台使用) + FallbackGroupIDOnInvalidRequest *int64 + // 模型路由配置(仅 anthropic 平台使用) + ModelRouting map[string][]int64 + ModelRoutingEnabled *bool // 是否启用模型路由 + MCPXMLInject *bool + // 支持的模型系列(仅 antigravity 平台使用) + SupportedModelScopes *[]string + // OpenAI Messages 调度配置(仅 openai 平台使用) + AllowMessagesDispatch *bool + DefaultMappedModel *string + RequireOAuthOnly *bool + RequirePrivacySet *bool + MessagesDispatchModelConfig *OpenAIMessagesDispatchModelConfig + // RPMLimit 分组 RPM 上限(0 = 不限制),nil 表示未提供不改动。 + RPMLimit *int + // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) + CopyAccountsFromGroupIDs []int64 +} + +// Group management implementations +func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder} + groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, search, isExclusive) + if err != nil { + return nil, 0, err + } + return groups, result.Total, nil +} + +func (s *adminServiceImpl) GetAllGroups(ctx context.Context) ([]Group, error) { + return s.groupRepo.ListActive(ctx) +} + +func (s *adminServiceImpl) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) { + return s.groupRepo.ListActiveByPlatform(ctx, platform) +} + +func (s *adminServiceImpl) 
GetGroup(ctx context.Context, id int64) (*Group, error) { + return s.groupRepo.GetByID(ctx, id) +} + +func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) { + if input.RateMultiplier <= 0 { + return nil, errors.New("rate_multiplier must be > 0") + } + + platform := input.Platform + if platform == "" { + platform = PlatformAnthropic + } + + subscriptionType := input.SubscriptionType + if subscriptionType == "" { + subscriptionType = SubscriptionTypeStandard + } + + // 限额字段:nil/负数 表示"无限制",0 表示"不允许用量",正数表示具体限额 + dailyLimit := normalizeLimit(input.DailyLimitUSD) + weeklyLimit := normalizeLimit(input.WeeklyLimitUSD) + monthlyLimit := normalizeLimit(input.MonthlyLimitUSD) + + // 图片价格:负数表示清除(使用默认价格),0 保留(表示免费) + imagePrice1K := normalizePrice(input.ImagePrice1K) + imagePrice2K := normalizePrice(input.ImagePrice2K) + imagePrice4K := normalizePrice(input.ImagePrice4K) + imageRateMultiplier := 1.0 + if input.ImageRateMultiplier != nil { + if *input.ImageRateMultiplier < 0 { + return nil, errors.New("image_rate_multiplier must be >= 0") + } + imageRateMultiplier = *input.ImageRateMultiplier + } + + // 校验降级分组 + if input.FallbackGroupID != nil { + if err := s.validateFallbackGroup(ctx, 0, *input.FallbackGroupID); err != nil { + return nil, err + } + } + fallbackOnInvalidRequest := input.FallbackGroupIDOnInvalidRequest + if fallbackOnInvalidRequest != nil && *fallbackOnInvalidRequest <= 0 { + fallbackOnInvalidRequest = nil + } + // 校验无效请求兜底分组 + if fallbackOnInvalidRequest != nil { + if err := s.validateFallbackGroupOnInvalidRequest(ctx, 0, platform, subscriptionType, *fallbackOnInvalidRequest); err != nil { + return nil, err + } + } + + // MCPXMLInject:默认为 true,仅当显式传入 false 时关闭 + mcpXMLInject := true + if input.MCPXMLInject != nil { + mcpXMLInject = *input.MCPXMLInject + } + + // 如果指定了复制账号的源分组,先获取账号 ID 列表 + var accountIDsToCopy []int64 + if len(input.CopyAccountsFromGroupIDs) > 0 { + // 去重源分组 IDs + seen := make(map[int64]struct{}) 
+ uniqueSourceGroupIDs := make([]int64, 0, len(input.CopyAccountsFromGroupIDs)) + for _, srcGroupID := range input.CopyAccountsFromGroupIDs { + if _, exists := seen[srcGroupID]; !exists { + seen[srcGroupID] = struct{}{} + uniqueSourceGroupIDs = append(uniqueSourceGroupIDs, srcGroupID) + } + } + + // 校验源分组的平台是否与新分组一致 + for _, srcGroupID := range uniqueSourceGroupIDs { + srcGroup, err := s.groupRepo.GetByIDLite(ctx, srcGroupID) + if err != nil { + return nil, fmt.Errorf("source group %d not found: %w", srcGroupID, err) + } + if srcGroup.Platform != platform { + return nil, fmt.Errorf("source group %d platform mismatch: expected %s, got %s", srcGroupID, platform, srcGroup.Platform) + } + } + + // 获取所有源分组的账号(去重) + var err error + accountIDsToCopy, err = s.groupRepo.GetAccountIDsByGroupIDs(ctx, uniqueSourceGroupIDs) + if err != nil { + return nil, fmt.Errorf("failed to get accounts from source groups: %w", err) + } + } + + group := &Group{ + Name: input.Name, + Description: input.Description, + Platform: platform, + RateMultiplier: input.RateMultiplier, + IsExclusive: input.IsExclusive, + Status: StatusActive, + SubscriptionType: subscriptionType, + DailyLimitUSD: dailyLimit, + WeeklyLimitUSD: weeklyLimit, + MonthlyLimitUSD: monthlyLimit, + AllowImageGeneration: input.AllowImageGeneration, + ImageRateIndependent: input.ImageRateIndependent, + ImageRateMultiplier: imageRateMultiplier, + ImagePrice1K: imagePrice1K, + ImagePrice2K: imagePrice2K, + ImagePrice4K: imagePrice4K, + ClaudeCodeOnly: input.ClaudeCodeOnly, + FallbackGroupID: input.FallbackGroupID, + FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest, + ModelRouting: input.ModelRouting, + ModelRoutingEnabled: input.ModelRoutingEnabled, // fix: propagate the flag — CreateGroupInput declares it and UpdateGroup honors it, but it was dropped here + MCPXMLInject: mcpXMLInject, + SupportedModelScopes: input.SupportedModelScopes, + AllowMessagesDispatch: input.AllowMessagesDispatch, + RequireOAuthOnly: input.RequireOAuthOnly, + RequirePrivacySet: input.RequirePrivacySet, + DefaultMappedModel: input.DefaultMappedModel, + MessagesDispatchModelConfig:
normalizeOpenAIMessagesDispatchModelConfig(input.MessagesDispatchModelConfig), + RPMLimit: input.RPMLimit, + } + sanitizeGroupMessagesDispatchFields(group) + if err := s.groupRepo.Create(ctx, group); err != nil { + return nil, err + } + + // require_oauth_only: 过滤掉 apikey 类型账号 + if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini) && len(accountIDsToCopy) > 0 { + accounts, err := s.accountRepo.GetByIDs(ctx, accountIDsToCopy) + if err != nil { + return nil, fmt.Errorf("failed to fetch accounts for oauth filter: %w", err) + } + oauthIDs := make(map[int64]struct{}, len(accounts)) + for _, acc := range accounts { + if acc.Type != AccountTypeAPIKey { + oauthIDs[acc.ID] = struct{}{} + } + } + var filtered []int64 + for _, aid := range accountIDsToCopy { + if _, ok := oauthIDs[aid]; ok { + filtered = append(filtered, aid) + } + } + accountIDsToCopy = filtered + } + + // 如果有需要复制的账号,绑定到新分组 + if len(accountIDsToCopy) > 0 { + if err := s.groupRepo.BindAccountsToGroup(ctx, group.ID, accountIDsToCopy); err != nil { + return nil, fmt.Errorf("failed to bind accounts to new group: %w", err) + } + group.AccountCount = int64(len(accountIDsToCopy)) + } + + return group, nil +} + +func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) { + group, err := s.groupRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + + if input.Name != "" { + group.Name = input.Name + } + if input.Description != "" { + group.Description = input.Description + } + if input.Platform != "" { + group.Platform = input.Platform + } + if input.RateMultiplier != nil { + if *input.RateMultiplier <= 0 { + return nil, errors.New("rate_multiplier must be > 0") + } + group.RateMultiplier = *input.RateMultiplier + } + if input.IsExclusive != nil { + group.IsExclusive = *input.IsExclusive + } + if input.Status != "" { + group.Status 
= input.Status + } + + // 订阅相关字段 + if input.SubscriptionType != "" { + group.SubscriptionType = input.SubscriptionType + } + // 限额字段:nil/负数 表示"无限制",0 表示"不允许用量",正数表示具体限额 + // 前端始终发送这三个字段,无需 nil 守卫 + group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD) + group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD) + group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD) + // 图片生成计费配置:负数表示清除(使用默认价格) + if input.AllowImageGeneration != nil { + group.AllowImageGeneration = *input.AllowImageGeneration + } + if input.ImageRateIndependent != nil { + group.ImageRateIndependent = *input.ImageRateIndependent + } + if input.ImageRateMultiplier != nil { + if *input.ImageRateMultiplier < 0 { + return nil, errors.New("image_rate_multiplier must be >= 0") + } + group.ImageRateMultiplier = *input.ImageRateMultiplier + } + if input.ImagePrice1K != nil { + group.ImagePrice1K = normalizePrice(input.ImagePrice1K) + } + if input.ImagePrice2K != nil { + group.ImagePrice2K = normalizePrice(input.ImagePrice2K) + } + if input.ImagePrice4K != nil { + group.ImagePrice4K = normalizePrice(input.ImagePrice4K) + } + + // Claude Code 客户端限制 + if input.ClaudeCodeOnly != nil { + group.ClaudeCodeOnly = *input.ClaudeCodeOnly + } + if input.FallbackGroupID != nil { + // 校验降级分组 + if *input.FallbackGroupID > 0 { + if err := s.validateFallbackGroup(ctx, id, *input.FallbackGroupID); err != nil { + return nil, err + } + group.FallbackGroupID = input.FallbackGroupID + } else { + // 传入 0 或负数表示清除降级分组 + group.FallbackGroupID = nil + } + } + fallbackOnInvalidRequest := group.FallbackGroupIDOnInvalidRequest + if input.FallbackGroupIDOnInvalidRequest != nil { + if *input.FallbackGroupIDOnInvalidRequest > 0 { + fallbackOnInvalidRequest = input.FallbackGroupIDOnInvalidRequest + } else { + fallbackOnInvalidRequest = nil + } + } + if fallbackOnInvalidRequest != nil { + if err := s.validateFallbackGroupOnInvalidRequest(ctx, id, group.Platform, group.SubscriptionType, *fallbackOnInvalidRequest); err != nil { + 
return nil, err + } + } + group.FallbackGroupIDOnInvalidRequest = fallbackOnInvalidRequest + + // 模型路由配置 + if input.ModelRouting != nil { + group.ModelRouting = input.ModelRouting + } + if input.ModelRoutingEnabled != nil { + group.ModelRoutingEnabled = *input.ModelRoutingEnabled + } + if input.MCPXMLInject != nil { + group.MCPXMLInject = *input.MCPXMLInject + } + + // 支持的模型系列(仅 antigravity 平台使用) + if input.SupportedModelScopes != nil { + group.SupportedModelScopes = *input.SupportedModelScopes + } + + // OpenAI Messages 调度配置 + if input.AllowMessagesDispatch != nil { + group.AllowMessagesDispatch = *input.AllowMessagesDispatch + } + if input.RequireOAuthOnly != nil { + group.RequireOAuthOnly = *input.RequireOAuthOnly + } + if input.RequirePrivacySet != nil { + group.RequirePrivacySet = *input.RequirePrivacySet + } + if input.DefaultMappedModel != nil { + group.DefaultMappedModel = *input.DefaultMappedModel + } + if input.MessagesDispatchModelConfig != nil { + group.MessagesDispatchModelConfig = normalizeOpenAIMessagesDispatchModelConfig(*input.MessagesDispatchModelConfig) + } + if input.RPMLimit != nil { + group.RPMLimit = *input.RPMLimit + } + sanitizeGroupMessagesDispatchFields(group) + + if err := s.groupRepo.Update(ctx, group); err != nil { + return nil, err + } + + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id) + } + + // 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号) + if len(input.CopyAccountsFromGroupIDs) > 0 { + // 去重源分组 IDs + seen := make(map[int64]struct{}) + uniqueSourceGroupIDs := make([]int64, 0, len(input.CopyAccountsFromGroupIDs)) + for _, srcGroupID := range input.CopyAccountsFromGroupIDs { + // 校验:源分组不能是自身 + if srcGroupID == id { + return nil, fmt.Errorf("cannot copy accounts from self") + } + // 去重 + if _, exists := seen[srcGroupID]; !exists { + seen[srcGroupID] = struct{}{} + uniqueSourceGroupIDs = append(uniqueSourceGroupIDs, srcGroupID) + } + } + + // 校验源分组的平台是否与当前分组一致 + for _, srcGroupID := range 
uniqueSourceGroupIDs { + srcGroup, err := s.groupRepo.GetByIDLite(ctx, srcGroupID) + if err != nil { + return nil, fmt.Errorf("source group %d not found: %w", srcGroupID, err) + } + if srcGroup.Platform != group.Platform { + return nil, fmt.Errorf("source group %d platform mismatch: expected %s, got %s", srcGroupID, group.Platform, srcGroup.Platform) + } + } + + // 获取所有源分组的账号(去重) + accountIDsToCopy, err := s.groupRepo.GetAccountIDsByGroupIDs(ctx, uniqueSourceGroupIDs) + if err != nil { + return nil, fmt.Errorf("failed to get accounts from source groups: %w", err) + } + + // 先清空当前分组的所有账号绑定 + if _, err := s.groupRepo.DeleteAccountGroupsByGroupID(ctx, id); err != nil { + return nil, fmt.Errorf("failed to clear existing account bindings: %w", err) + } + + // require_oauth_only: 过滤掉 apikey 类型账号 + if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini) && len(accountIDsToCopy) > 0 { + accounts, err := s.accountRepo.GetByIDs(ctx, accountIDsToCopy) + if err != nil { + return nil, fmt.Errorf("failed to fetch accounts for oauth filter: %w", err) + } + oauthIDs := make(map[int64]struct{}, len(accounts)) + for _, acc := range accounts { + if acc.Type != AccountTypeAPIKey { + oauthIDs[acc.ID] = struct{}{} + } + } + var filtered []int64 + for _, aid := range accountIDsToCopy { + if _, ok := oauthIDs[aid]; ok { + filtered = append(filtered, aid) + } + } + accountIDsToCopy = filtered + } + + // 再绑定源分组的账号 + if len(accountIDsToCopy) > 0 { + if err := s.groupRepo.BindAccountsToGroup(ctx, id, accountIDsToCopy); err != nil { + return nil, fmt.Errorf("failed to bind accounts to group: %w", err) + } + } + } + + return group, nil +} + +func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { + var groupKeys []string + if s.authCacheInvalidator != nil { + keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, id) + if err == nil { + groupKeys = keys + } 
+ } + + affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id) + if err != nil { + return err + } + // 注意:user_group_rate_multipliers 表通过外键 ON DELETE CASCADE 自动清理 + + // 事务成功后,异步失效受影响用户的订阅缓存 + if len(affectedUserIDs) > 0 && s.billingCacheService != nil { + groupID := id + go func() { + cacheCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + for _, userID := range affectedUserIDs { + if err := s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID); err != nil { + logger.LegacyPrintf("service.admin", "invalidate subscription cache failed: user_id=%d group_id=%d err=%v", userID, groupID, err) + } + } + }() + } + if s.authCacheInvalidator != nil { + for _, key := range groupKeys { + s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, key) + } + } + + return nil +} + +func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize} + keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params) + if err != nil { + return nil, 0, err + } + return keys, result.Total, nil +} + +func (s *adminServiceImpl) GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error) { + if s.userGroupRateRepo == nil { + return nil, nil + } + return s.userGroupRateRepo.GetByGroupID(ctx, groupID) +} + +func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error { + if s.userGroupRateRepo == nil { + return nil + } + for _, e := range entries { + if e.RateMultiplier <= 0 { + return fmt.Errorf("rate_multiplier must be > 0 (user_id=%d)", e.UserID) + } + } + return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries) +} + +func (s *adminServiceImpl) BatchSetGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error { + if s.userGroupRateRepo == nil 
{ + return nil + } + for _, e := range entries { + if e.RPMOverride != nil && *e.RPMOverride < 0 { + return infraerrors.BadRequest("INVALID_RPM_OVERRIDE", fmt.Sprintf("rpm_override must be >= 0 (user_id=%d)", e.UserID)) + } + } + if err := s.userGroupRateRepo.SyncGroupRPMOverrides(ctx, groupID, entries); err != nil { + return err + } + // RPM override 已嵌入 auth cache snapshot (v7),变更后必须失效相关缓存。 + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, groupID) + } + return nil +} + +// AdminUpdateAPIKeyGroupID 管理员修改 API Key 分组绑定 +// groupID: nil=不修改, 指向0=解绑, 指向正整数=绑定到目标分组 +func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error) { + apiKey, err := s.apiKeyRepo.GetByID(ctx, keyID) + if err != nil { + return nil, err + } + + if groupID == nil { + // nil 表示不修改,直接返回 + return &AdminUpdateAPIKeyGroupIDResult{APIKey: apiKey}, nil + } + + if *groupID < 0 { + return nil, infraerrors.BadRequest("INVALID_GROUP_ID", "group_id must be non-negative") + } + + result := &AdminUpdateAPIKeyGroupIDResult{} + + if *groupID == 0 { + // 0 表示解绑分组(不修改 user_allowed_groups,避免影响用户其他 Key) + apiKey.GroupID = nil + apiKey.Group = nil + } else { + // 验证目标分组存在且状态为 active + group, err := s.groupRepo.GetByID(ctx, *groupID) + if err != nil { + return nil, err + } + if group.Status != StatusActive { + return nil, infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active") + } + // 订阅类型分组:用户须持有该分组的有效订阅才可绑定 + if group.IsSubscriptionType() { + if s.userSubRepo == nil { + return nil, infraerrors.InternalServer("SUBSCRIPTION_REPOSITORY_UNAVAILABLE", "subscription repository is not configured") + } + if _, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, apiKey.UserID, *groupID); err != nil { + if errors.Is(err, ErrSubscriptionNotFound) { + return nil, infraerrors.BadRequest("SUBSCRIPTION_REQUIRED", "user does not have an active subscription for this group") + } + 
return nil, err + } + } + + gid := *groupID + apiKey.GroupID = &gid + apiKey.Group = group + + // 专属标准分组:使用事务保证「添加分组权限」与「更新 API Key」的原子性 + if group.IsExclusive && !group.IsSubscriptionType() { + opCtx := ctx + var tx *dbent.Tx + if s.entClient == nil { + logger.LegacyPrintf("service.admin", "Warning: entClient is nil, skipping transaction protection for exclusive group binding") + } else { + var txErr error + tx, txErr = s.entClient.Tx(ctx) + if txErr != nil { + return nil, fmt.Errorf("begin transaction: %w", txErr) + } + defer func() { _ = tx.Rollback() }() + opCtx = dbent.NewTxContext(ctx, tx) + } + + if addErr := s.userRepo.AddGroupToAllowedGroups(opCtx, apiKey.UserID, gid); addErr != nil { + return nil, fmt.Errorf("add group to user allowed groups: %w", addErr) + } + if err := s.apiKeyRepo.Update(opCtx, apiKey); err != nil { + return nil, fmt.Errorf("update api key: %w", err) + } + if tx != nil { + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("commit transaction: %w", err) + } + } + + result.AutoGrantedGroupAccess = true + result.GrantedGroupID = &gid + result.GrantedGroupName = group.Name + + // 失效认证缓存(在事务提交后执行) + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key) + } + + result.APIKey = apiKey + return result, nil + } + } + + // 非专属分组 / 解绑:无需事务,单步更新即可 + if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil { + return nil, fmt.Errorf("update api key: %w", err) + } + + // 失效认证缓存 + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key) + } + + result.APIKey = apiKey + return result, nil +} + +// ReplaceUserGroup 替换用户的专属分组 +func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) { + if oldGroupID == newGroupID { + return nil, infraerrors.BadRequest("SAME_GROUP", "old and new group must be different") + } + + // 验证新分组存在且为活跃的专属标准分组 + newGroup, err := 
s.groupRepo.GetByID(ctx, newGroupID) + if err != nil { + return nil, err + } + if newGroup.Status != StatusActive { + return nil, infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active") + } + if !newGroup.IsExclusive { + return nil, infraerrors.BadRequest("GROUP_NOT_EXCLUSIVE", "target group is not exclusive") + } + if newGroup.IsSubscriptionType() { + return nil, infraerrors.BadRequest("GROUP_IS_SUBSCRIPTION", "subscription groups are not supported for replacement") + } + + // 事务保证原子性 + if s.entClient == nil { + return nil, fmt.Errorf("entClient is nil, cannot perform group replacement") + } + tx, err := s.entClient.Tx(ctx) + if err != nil { + return nil, fmt.Errorf("begin transaction: %w", err) + } + defer func() { _ = tx.Rollback() }() + opCtx := dbent.NewTxContext(ctx, tx) + + // 1. 授予新分组权限 + if err := s.userRepo.AddGroupToAllowedGroups(opCtx, userID, newGroupID); err != nil { + return nil, fmt.Errorf("add new group to allowed groups: %w", err) + } + + // 2. 迁移绑定旧分组的 Key 到新分组 + migrated, err := s.apiKeyRepo.UpdateGroupIDByUserAndGroup(opCtx, userID, oldGroupID, newGroupID) + if err != nil { + return nil, fmt.Errorf("migrate api keys: %w", err) + } + + // 3. 
移除旧分组权限 + if err := s.userRepo.RemoveGroupFromUserAllowedGroups(opCtx, userID, oldGroupID); err != nil { + return nil, fmt.Errorf("remove old group from allowed groups: %w", err) + } + + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("commit transaction: %w", err) + } + + // 失效该用户所有 Key 的认证缓存 + if s.authCacheInvalidator != nil { + keys, keyErr := s.apiKeyRepo.ListKeysByUserID(ctx, userID) + if keyErr == nil { + for _, k := range keys { + s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, k) + } + } + } + + return &ReplaceUserGroupResult{MigratedKeys: migrated}, nil +} diff --git a/backend/internal/service/admin_payment_ops.go b/backend/internal/service/admin_payment_ops.go new file mode 100644 index 00000000000..da17d639688 --- /dev/null +++ b/backend/internal/service/admin_payment_ops.go @@ -0,0 +1,101 @@ +package service + +import ( + "context" + "errors" + "fmt" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" +) + +type GenerateRedeemCodesInput struct { + Count int + Type string + Value float64 + GroupID *int64 // 订阅类型专用:关联的分组ID + ValidityDays int // 订阅类型专用:有效天数 +} + +// Redeem code management implementations +func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string, sortBy, sortOrder string) ([]RedeemCode, int64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder} + codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search) + if err != nil { + return nil, 0, err + } + return codes, result.Total, nil +} + +func (s *adminServiceImpl) GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) { + return s.redeemCodeRepo.GetByID(ctx, id) +} + +func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) { + // 如果是订阅类型,验证必须有 GroupID + if input.Type == RedeemTypeSubscription { + if input.GroupID == nil { + return 
nil, errors.New("group_id is required for subscription type") + } + // 验证分组存在且为订阅类型 + group, err := s.groupRepo.GetByID(ctx, *input.GroupID) + if err != nil { + return nil, fmt.Errorf("group not found: %w", err) + } + if !group.IsSubscriptionType() { + return nil, errors.New("group must be subscription type") + } + } + + codes := make([]RedeemCode, 0, input.Count) + for i := 0; i < input.Count; i++ { + codeValue, err := GenerateRedeemCode() + if err != nil { + return nil, err + } + code := RedeemCode{ + Code: codeValue, + Type: input.Type, + Value: input.Value, + Status: StatusUnused, + } + // 订阅类型专用字段 + if input.Type == RedeemTypeSubscription { + code.GroupID = input.GroupID + code.ValidityDays = input.ValidityDays + if code.ValidityDays <= 0 { + code.ValidityDays = 30 // 默认30天 + } + } + if err := s.redeemCodeRepo.Create(ctx, &code); err != nil { + return nil, err + } + codes = append(codes, code) + } + return codes, nil +} + +func (s *adminServiceImpl) DeleteRedeemCode(ctx context.Context, id int64) error { + return s.redeemCodeRepo.Delete(ctx, id) +} + +func (s *adminServiceImpl) BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) { + var deleted int64 + for _, id := range ids { + if err := s.redeemCodeRepo.Delete(ctx, id); err == nil { + deleted++ + } + } + return deleted, nil +} + +func (s *adminServiceImpl) ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) { + code, err := s.redeemCodeRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + code.Status = StatusExpired + if err := s.redeemCodeRepo.Update(ctx, code); err != nil { + return nil, err + } + return code, nil +} diff --git a/backend/internal/service/admin_probe.go b/backend/internal/service/admin_probe.go new file mode 100644 index 00000000000..0a34689031b --- /dev/null +++ b/backend/internal/service/admin_probe.go @@ -0,0 +1,339 @@ +package service + +import ( + "context" + "fmt" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" + 
"github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +type CreateProxyInput struct { + Name string + Protocol string + Host string + Port int + Username string + Password string +} + +type UpdateProxyInput struct { + Name string + Protocol string + Host string + Port int + Username string + Password string + Status string +} + +// ProxyTestResult represents the result of testing a proxy +type ProxyTestResult struct { + Success bool `json:"success"` + Message string `json:"message"` + LatencyMs int64 `json:"latency_ms,omitempty"` + IPAddress string `json:"ip_address,omitempty"` + City string `json:"city,omitempty"` + Region string `json:"region,omitempty"` + Country string `json:"country,omitempty"` + CountryCode string `json:"country_code,omitempty"` +} + +type ProxyQualityCheckResult struct { + ProxyID int64 `json:"proxy_id"` + Score int `json:"score"` + Grade string `json:"grade"` + Summary string `json:"summary"` + ExitIP string `json:"exit_ip,omitempty"` + Country string `json:"country,omitempty"` + CountryCode string `json:"country_code,omitempty"` + BaseLatencyMs int64 `json:"base_latency_ms,omitempty"` + PassedCount int `json:"passed_count"` + WarnCount int `json:"warn_count"` + FailedCount int `json:"failed_count"` + ChallengeCount int `json:"challenge_count"` + CheckedAt int64 `json:"checked_at"` + Items []ProxyQualityCheckItem `json:"items"` +} + +type ProxyQualityCheckItem struct { + Target string `json:"target"` + Status string `json:"status"` // pass/warn/fail/challenge + HTTPStatus int `json:"http_status,omitempty"` + LatencyMs int64 `json:"latency_ms,omitempty"` + Message string `json:"message,omitempty"` + CFRay string `json:"cf_ray,omitempty"` +} + +// ProxyExitInfo represents proxy exit information from ip-api.com +type ProxyExitInfo struct { + IP string + City string + Region string + Country string + CountryCode string +} + +func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error) { + proxy, err := 
s.proxyRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + + proxyURL := proxy.URL() + exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL) + if err != nil { + s.saveProxyLatency(ctx, id, &ProxyLatencyInfo{ + Success: false, + Message: err.Error(), + UpdatedAt: time.Now(), + }) + return &ProxyTestResult{ + Success: false, + Message: err.Error(), + }, nil + } + + latency := latencyMs + s.saveProxyLatency(ctx, id, &ProxyLatencyInfo{ + Success: true, + LatencyMs: &latency, + Message: "Proxy is accessible", + IPAddress: exitInfo.IP, + Country: exitInfo.Country, + CountryCode: exitInfo.CountryCode, + Region: exitInfo.Region, + City: exitInfo.City, + UpdatedAt: time.Now(), + }) + return &ProxyTestResult{ + Success: true, + Message: "Proxy is accessible", + LatencyMs: latencyMs, + IPAddress: exitInfo.IP, + City: exitInfo.City, + Region: exitInfo.Region, + Country: exitInfo.Country, + CountryCode: exitInfo.CountryCode, + }, nil +} + +func (s *adminServiceImpl) CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error) { + proxy, err := s.proxyRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + + result := &ProxyQualityCheckResult{ + ProxyID: id, + Score: 100, + Grade: "A", + CheckedAt: time.Now().Unix(), + Items: make([]ProxyQualityCheckItem, 0, len(proxyQualityTargets)+1), + } + + proxyURL := proxy.URL() + if s.proxyProber == nil { + result.Items = append(result.Items, ProxyQualityCheckItem{ + Target: "base_connectivity", + Status: "fail", + Message: "代理探测服务未配置", + }) + result.FailedCount++ + finalizeProxyQualityResult(result) + s.saveProxyQualitySnapshot(ctx, id, result, nil) + return result, nil + } + + exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL) + if err != nil { + result.Items = append(result.Items, ProxyQualityCheckItem{ + Target: "base_connectivity", + Status: "fail", + LatencyMs: latencyMs, + Message: err.Error(), + }) + result.FailedCount++ + finalizeProxyQualityResult(result) + 
s.saveProxyQualitySnapshot(ctx, id, result, nil) + return result, nil + } + + result.ExitIP = exitInfo.IP + result.Country = exitInfo.Country + result.CountryCode = exitInfo.CountryCode + result.BaseLatencyMs = latencyMs + result.Items = append(result.Items, ProxyQualityCheckItem{ + Target: "base_connectivity", + Status: "pass", + LatencyMs: latencyMs, + Message: "代理出口连通正常", + }) + result.PassedCount++ + + client, err := httpclient.GetClient(httpclient.Options{ + ProxyURL: proxyURL, + Timeout: proxyQualityRequestTimeout, + ResponseHeaderTimeout: proxyQualityResponseHeaderTimeout, + }) + if err != nil { + result.Items = append(result.Items, ProxyQualityCheckItem{ + Target: "http_client", + Status: "fail", + Message: fmt.Sprintf("创建检测客户端失败: %v", err), + }) + result.FailedCount++ + finalizeProxyQualityResult(result) + s.saveProxyQualitySnapshot(ctx, id, result, exitInfo) + return result, nil + } + + for _, target := range proxyQualityTargets { + item := runProxyQualityTarget(ctx, client, target) + result.Items = append(result.Items, item) + switch item.Status { + case "pass": + result.PassedCount++ + case "warn": + result.WarnCount++ + case "challenge": + result.ChallengeCount++ + default: + result.FailedCount++ + } + } + + finalizeProxyQualityResult(result) + s.saveProxyQualitySnapshot(ctx, id, result, exitInfo) + return result, nil +} + +func (s *adminServiceImpl) saveProxyQualitySnapshot(ctx context.Context, proxyID int64, result *ProxyQualityCheckResult, exitInfo *ProxyExitInfo) { + if result == nil { + return + } + score := result.Score + checkedAt := result.CheckedAt + info := &ProxyLatencyInfo{ + Success: proxyQualityBaseConnectivityPass(result), + Message: result.Summary, + QualityStatus: proxyQualityOverallStatus(result), + QualityScore: &score, + QualityGrade: result.Grade, + QualitySummary: result.Summary, + QualityCheckedAt: &checkedAt, + QualityCFRay: proxyQualityFirstCFRay(result), + UpdatedAt: time.Now(), + } + if result.BaseLatencyMs > 0 { + latency := 
result.BaseLatencyMs + info.LatencyMs = &latency + } + if exitInfo != nil { + info.IPAddress = exitInfo.IP + info.Country = exitInfo.Country + info.CountryCode = exitInfo.CountryCode + info.Region = exitInfo.Region + info.City = exitInfo.City + } + s.saveProxyLatency(ctx, proxyID, info) +} + +func (s *adminServiceImpl) probeProxyLatency(ctx context.Context, proxy *Proxy) { + if s.proxyProber == nil || proxy == nil { + return + } + exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxy.URL()) + if err != nil { + s.saveProxyLatency(ctx, proxy.ID, &ProxyLatencyInfo{ + Success: false, + Message: err.Error(), + UpdatedAt: time.Now(), + }) + return + } + + latency := latencyMs + s.saveProxyLatency(ctx, proxy.ID, &ProxyLatencyInfo{ + Success: true, + LatencyMs: &latency, + Message: "Proxy is accessible", + IPAddress: exitInfo.IP, + Country: exitInfo.Country, + CountryCode: exitInfo.CountryCode, + Region: exitInfo.Region, + City: exitInfo.City, + UpdatedAt: time.Now(), + }) +} + +func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []ProxyWithAccountCount) { + if s.proxyLatencyCache == nil || len(proxies) == 0 { + return + } + + ids := make([]int64, 0, len(proxies)) + for i := range proxies { + ids = append(ids, proxies[i].ID) + } + + latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, ids) + if err != nil { + logger.LegacyPrintf("service.admin", "Warning: load proxy latency cache failed: %v", err) + return + } + + for i := range proxies { + info := latencies[proxies[i].ID] + if info == nil { + continue + } + if info.Success { + proxies[i].LatencyStatus = "success" + proxies[i].LatencyMs = info.LatencyMs + } else { + proxies[i].LatencyStatus = "failed" + } + proxies[i].LatencyMessage = info.Message + proxies[i].IPAddress = info.IPAddress + proxies[i].Country = info.Country + proxies[i].CountryCode = info.CountryCode + proxies[i].Region = info.Region + proxies[i].City = info.City + proxies[i].QualityStatus = info.QualityStatus + 
proxies[i].QualityScore = info.QualityScore + proxies[i].QualityGrade = info.QualityGrade + proxies[i].QualitySummary = info.QualitySummary + proxies[i].QualityChecked = info.QualityCheckedAt + } +} + +func (s *adminServiceImpl) saveProxyLatency(ctx context.Context, proxyID int64, info *ProxyLatencyInfo) { + if s.proxyLatencyCache == nil || info == nil { + return + } + + merged := *info + if latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, []int64{proxyID}); err == nil { + if existing := latencies[proxyID]; existing != nil { + if merged.QualityCheckedAt == nil && + merged.QualityScore == nil && + merged.QualityGrade == "" && + merged.QualityStatus == "" && + merged.QualitySummary == "" && + merged.QualityCFRay == "" { + merged.QualityStatus = existing.QualityStatus + merged.QualityScore = existing.QualityScore + merged.QualityGrade = existing.QualityGrade + merged.QualitySummary = existing.QualitySummary + merged.QualityCheckedAt = existing.QualityCheckedAt + merged.QualityCFRay = existing.QualityCFRay + } + } + } + + if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, &merged); err != nil { + logger.LegacyPrintf("service.admin", "Warning: store proxy latency cache failed: %v", err) + } +} diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index eb5994d5498..224b7bd2588 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -15,10 +15,7 @@ import ( "time" dbent "github.com/Wei-Shaw/sub2api/ent" - "github.com/Wei-Shaw/sub2api/ent/authidentity" - "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" - "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/util/httputil" @@ -179,85 +176,6 @@ type AdminBoundAuthIdentityChannel struct { UpdatedAt 
time.Time `json:"updated_at"` } -type CreateGroupInput struct { - Name string - Description string - Platform string - RateMultiplier float64 - IsExclusive bool - SubscriptionType string // standard/subscription - DailyLimitUSD *float64 // 日限额 (USD) - WeeklyLimitUSD *float64 // 周限额 (USD) - MonthlyLimitUSD *float64 // 月限额 (USD) - // 图片生成计费配置(仅 antigravity 平台使用) - AllowImageGeneration bool - ImageRateIndependent bool - ImageRateMultiplier *float64 - ImagePrice1K *float64 - ImagePrice2K *float64 - ImagePrice4K *float64 - ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 - FallbackGroupID *int64 // 降级分组 ID - // 无效请求兜底分组 ID(仅 anthropic 平台使用) - FallbackGroupIDOnInvalidRequest *int64 - // 模型路由配置(仅 anthropic 平台使用) - ModelRouting map[string][]int64 - ModelRoutingEnabled bool // 是否启用模型路由 - MCPXMLInject *bool - // 支持的模型系列(仅 antigravity 平台使用) - SupportedModelScopes []string - // OpenAI Messages 调度配置(仅 openai 平台使用) - AllowMessagesDispatch bool - DefaultMappedModel string - RequireOAuthOnly bool - RequirePrivacySet bool - MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig - // RPMLimit 分组 RPM 上限(0 = 不限制) - RPMLimit int - // 从指定分组复制账号(创建分组后在同一事务内绑定) - CopyAccountsFromGroupIDs []int64 -} - -type UpdateGroupInput struct { - Name string - Description string - Platform string - RateMultiplier *float64 // 使用指针以支持设置为0 - IsExclusive *bool - Status string - SubscriptionType string // standard/subscription - DailyLimitUSD *float64 // 日限额 (USD) - WeeklyLimitUSD *float64 // 周限额 (USD) - MonthlyLimitUSD *float64 // 月限额 (USD) - // 图片生成计费配置(仅 antigravity 平台使用) - AllowImageGeneration *bool - ImageRateIndependent *bool - ImageRateMultiplier *float64 - ImagePrice1K *float64 - ImagePrice2K *float64 - ImagePrice4K *float64 - ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 - FallbackGroupID *int64 // 降级分组 ID - // 无效请求兜底分组 ID(仅 anthropic 平台使用) - FallbackGroupIDOnInvalidRequest *int64 - // 模型路由配置(仅 anthropic 平台使用) - ModelRouting map[string][]int64 - ModelRoutingEnabled *bool // 是否启用模型路由 - MCPXMLInject 
*bool - // 支持的模型系列(仅 antigravity 平台使用) - SupportedModelScopes *[]string - // OpenAI Messages 调度配置(仅 openai 平台使用) - AllowMessagesDispatch *bool - DefaultMappedModel *string - RequireOAuthOnly *bool - RequirePrivacySet *bool - MessagesDispatchModelConfig *OpenAIMessagesDispatchModelConfig - // RPMLimit 分组 RPM 上限(0 = 不限制),nil 表示未提供不改动。 - RPMLimit *int - // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) - CopyAccountsFromGroupIDs []int64 -} - type CreateAccountInput struct { Name string Notes *string @@ -347,58 +265,6 @@ type ReplaceUserGroupResult struct { MigratedKeys int64 // 迁移的 Key 数量 } -// UserRPMStatus describes a user's current per-minute RPM usage. -type UserRPMStatus struct { - UserRPMUsed int `json:"user_rpm_used"` - UserRPMLimit int `json:"user_rpm_limit"` - PerGroup []UserGroupRPMStatus `json:"per_group"` -} - -// UserGroupRPMStatus describes current per-minute RPM usage for one user/group pair. -type UserGroupRPMStatus struct { - GroupID int64 `json:"group_id"` - GroupName string `json:"group_name"` - Used int `json:"used"` - Limit int `json:"limit"` - Source string `json:"source"` // "group" | "override" -} - -// BulkUpdateAccountsResult is the aggregated response for bulk updates. 
-type BulkUpdateAccountsResult struct { - Success int `json:"success"` - Failed int `json:"failed"` - SuccessIDs []int64 `json:"success_ids"` - FailedIDs []int64 `json:"failed_ids"` - Results []BulkUpdateAccountResult `json:"results"` -} - -type CreateProxyInput struct { - Name string - Protocol string - Host string - Port int - Username string - Password string -} - -type UpdateProxyInput struct { - Name string - Protocol string - Host string - Port int - Username string - Password string - Status string -} - -type GenerateRedeemCodesInput struct { - Count int - Type string - Value float64 - GroupID *int64 // 订阅类型专用:关联的分组ID - ValidityDays int // 订阅类型专用:有效天数 -} - type ProxyBatchDeleteResult struct { DeletedIDs []int64 `json:"deleted_ids"` Skipped []ProxyBatchDeleteSkipped `json:"skipped"` @@ -409,53 +275,6 @@ type ProxyBatchDeleteSkipped struct { Reason string `json:"reason"` } -// ProxyTestResult represents the result of testing a proxy -type ProxyTestResult struct { - Success bool `json:"success"` - Message string `json:"message"` - LatencyMs int64 `json:"latency_ms,omitempty"` - IPAddress string `json:"ip_address,omitempty"` - City string `json:"city,omitempty"` - Region string `json:"region,omitempty"` - Country string `json:"country,omitempty"` - CountryCode string `json:"country_code,omitempty"` -} - -type ProxyQualityCheckResult struct { - ProxyID int64 `json:"proxy_id"` - Score int `json:"score"` - Grade string `json:"grade"` - Summary string `json:"summary"` - ExitIP string `json:"exit_ip,omitempty"` - Country string `json:"country,omitempty"` - CountryCode string `json:"country_code,omitempty"` - BaseLatencyMs int64 `json:"base_latency_ms,omitempty"` - PassedCount int `json:"passed_count"` - WarnCount int `json:"warn_count"` - FailedCount int `json:"failed_count"` - ChallengeCount int `json:"challenge_count"` - CheckedAt int64 `json:"checked_at"` - Items []ProxyQualityCheckItem `json:"items"` -} - -type ProxyQualityCheckItem struct { - Target string 
`json:"target"` - Status string `json:"status"` // pass/warn/fail/challenge - HTTPStatus int `json:"http_status,omitempty"` - LatencyMs int64 `json:"latency_ms,omitempty"` - Message string `json:"message,omitempty"` - CFRay string `json:"cf_ray,omitempty"` -} - -// ProxyExitInfo represents proxy exit information from ip-api.com -type ProxyExitInfo struct { - IP string - City string - Region string - Country string - CountryCode string -} - // ProxyExitInfoProber tests proxy connectivity and retrieves exit information type ProxyExitInfoProber interface { ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error) @@ -577,52 +396,6 @@ func NewAdminService( } } -// User management implementations -func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters, sortBy, sortOrder string) ([]User, int64, error) { - params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder} - users, result, err := s.userRepo.ListWithFilters(ctx, params, filters) - if err != nil { - return nil, 0, err - } - if len(users) > 0 { - userIDs := make([]int64, 0, len(users)) - for i := range users { - userIDs = append(userIDs, users[i].ID) - } - lastUsedByUserID, latestErr := s.userRepo.GetLatestUsedAtByUserIDs(ctx, userIDs) - if latestErr != nil { - logger.LegacyPrintf("service.admin", "failed to load user last_used_at in batch: err=%v", latestErr) - } else { - for i := range users { - users[i].LastUsedAt = lastUsedByUserID[users[i].ID] - } - } - } - // 批量加载用户专属分组倍率 - if s.userGroupRateRepo != nil && len(users) > 0 { - if batchRepo, ok := s.userGroupRateRepo.(userGroupRateBatchReader); ok { - userIDs := make([]int64, 0, len(users)) - for i := range users { - userIDs = append(userIDs, users[i].ID) - } - ratesByUser, err := batchRepo.GetByUserIDs(ctx, userIDs) - if err != nil { - logger.LegacyPrintf("service.admin", "failed to load user group rates in batch: err=%v", err) - 
s.loadUserGroupRatesOneByOne(ctx, users) - } else { - for i := range users { - if rates, ok := ratesByUser[users[i].ID]; ok { - users[i].GroupRates = rates - } - } - } - } else { - s.loadUserGroupRatesOneByOne(ctx, users) - } - } - return users, result.Total, nil -} - func (s *adminServiceImpl) loadUserGroupRatesOneByOne(ctx context.Context, users []User) { if s.userGroupRateRepo == nil { return @@ -637,51 +410,6 @@ func (s *adminServiceImpl) loadUserGroupRatesOneByOne(ctx context.Context, users } } -func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) { - user, err := s.userRepo.GetByID(ctx, id) - if err != nil { - return nil, err - } - lastUsedAt, latestErr := s.userRepo.GetLatestUsedAtByUserID(ctx, id) - if latestErr != nil { - logger.LegacyPrintf("service.admin", "failed to load user last_used_at: user_id=%d err=%v", id, latestErr) - } else { - user.LastUsedAt = lastUsedAt - } - // 加载用户专属分组倍率 - if s.userGroupRateRepo != nil { - rates, err := s.userGroupRateRepo.GetByUserID(ctx, id) - if err != nil { - logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", id, err) - } else { - user.GroupRates = rates - } - } - return user, nil -} - -func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) { - user := &User{ - Email: input.Email, - Username: input.Username, - Notes: input.Notes, - Role: RoleUser, // Always create as regular user, never admin - Balance: input.Balance, - Concurrency: input.Concurrency, - RPMLimit: input.RPMLimit, - Status: StatusActive, - AllowedGroups: input.AllowedGroups, - } - if err := user.SetPassword(input.Password); err != nil { - return nil, err - } - if err := s.userRepo.Create(ctx, user); err != nil { - return nil, err - } - s.assignDefaultSubscriptions(ctx, user.ID) - return user, nil -} - func (s *adminServiceImpl) assignDefaultSubscriptions(ctx context.Context, userID int64) { if s.settingService == nil || s.defaultSubAssigner == 
nil || userID <= 0 { return @@ -699,125 +427,6 @@ func (s *adminServiceImpl) assignDefaultSubscriptions(ctx context.Context, userI } } -func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) { - // 校验用户专属分组倍率:必须 > 0(nil 合法,表示清除专属倍率) - if input.GroupRates != nil { - for groupID, rate := range input.GroupRates { - if rate != nil && *rate <= 0 { - return nil, fmt.Errorf("rate_multiplier must be > 0 (group_id=%d)", groupID) - } - } - } - - user, err := s.userRepo.GetByID(ctx, id) - if err != nil { - return nil, err - } - - // Protect admin users: cannot disable admin accounts - if user.Role == "admin" && input.Status == "disabled" { - return nil, errors.New("cannot disable admin user") - } - - oldConcurrency := user.Concurrency - oldStatus := user.Status - oldRole := user.Role - oldRPMLimit := user.RPMLimit - - if input.Email != "" { - user.Email = input.Email - } - if input.Password != "" { - if err := user.SetPassword(input.Password); err != nil { - return nil, err - } - } - - if input.Username != nil { - user.Username = *input.Username - } - if input.Notes != nil { - user.Notes = *input.Notes - } - - if input.Status != "" { - user.Status = input.Status - } - - if input.Concurrency != nil { - user.Concurrency = *input.Concurrency - } - - if input.RPMLimit != nil { - user.RPMLimit = *input.RPMLimit - } - - if input.AllowedGroups != nil { - user.AllowedGroups = *input.AllowedGroups - } - - if err := s.userRepo.Update(ctx, user); err != nil { - return nil, err - } - - // 同步用户专属分组倍率 - if input.GroupRates != nil && s.userGroupRateRepo != nil { - if err := s.userGroupRateRepo.SyncUserGroupRates(ctx, user.ID, input.GroupRates); err != nil { - logger.LegacyPrintf("service.admin", "failed to sync user group rates: user_id=%d err=%v", user.ID, err) - } - } - - if s.authCacheInvalidator != nil { - // RPMLimit 直接参与 billing_cache_service.checkRPM 的三级级联, - // 不失效缓存会让修改在一个 L2 TTL 内失去效果。 - if user.Concurrency != oldConcurrency || 
user.Status != oldStatus || user.Role != oldRole || user.RPMLimit != oldRPMLimit { - s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID) - } - } - - concurrencyDiff := user.Concurrency - oldConcurrency - if concurrencyDiff != 0 { - code, err := GenerateRedeemCode() - if err != nil { - logger.LegacyPrintf("service.admin", "failed to generate adjustment redeem code: %v", err) - return user, nil - } - adjustmentRecord := &RedeemCode{ - Code: code, - Type: AdjustmentTypeAdminConcurrency, - Value: float64(concurrencyDiff), - Status: StatusUsed, - UsedBy: &user.ID, - } - now := time.Now() - adjustmentRecord.UsedAt = &now - if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil { - logger.LegacyPrintf("service.admin", "failed to create concurrency adjustment redeem code: %v", err) - } - } - - return user, nil -} - -func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error { - // Protect admin users: cannot delete admin accounts - user, err := s.userRepo.GetByID(ctx, id) - if err != nil { - return err - } - if user.Role == "admin" { - return errors.New("cannot delete admin user") - } - if err := s.userRepo.Delete(ctx, id); err != nil { - logger.LegacyPrintf("service.admin", "delete user failed: user_id=%d err=%v", id, err) - return err - } - if s.authCacheInvalidator != nil { - s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, id) - } - return nil -} - func (s *adminServiceImpl) BatchUpdateConcurrency(ctx context.Context, userIDs []int64, value int, mode string) (int, error) { cleaned := make([]int64, 0, len(userIDs)) for _, uid := range userIDs { @@ -916,133 +525,6 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, return user, nil } -func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error) { - params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder} - 
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, APIKeyListFilters{}) - if err != nil { - return nil, 0, err - } - return keys, result.Total, nil -} - -func (s *adminServiceImpl) GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error) { - if s.userRPMCache == nil { - return nil, ErrRPMStatusUnavailable - } - - user, err := s.userRepo.GetByID(ctx, userID) - if err != nil { - return nil, err - } - - userRPMUsed, err := s.userRPMCache.GetUserRPM(ctx, userID) - if err != nil { - logger.LegacyPrintf("service.admin", "failed to get user rpm: user_id=%d err=%v", userID, err) - } - - keys, _, err := s.GetUserAPIKeys(ctx, userID, 1, 1000, "", "") - if err != nil { - return nil, err - } - - groupIDSet := make(map[int64]struct{}) - for _, key := range keys { - if key.GroupID != nil && *key.GroupID > 0 { - groupIDSet[*key.GroupID] = struct{}{} - } - } - - groupIDs := make([]int64, 0, len(groupIDSet)) - for groupID := range groupIDSet { - groupIDs = append(groupIDs, groupID) - } - sort.Slice(groupIDs, func(i, j int) bool { return groupIDs[i] < groupIDs[j] }) - - var perGroup []UserGroupRPMStatus - for _, groupID := range groupIDs { - used, getErr := s.userRPMCache.GetUserGroupRPM(ctx, userID, groupID) - if getErr != nil { - logger.LegacyPrintf("service.admin", "failed to get user group rpm: user_id=%d group_id=%d err=%v", userID, groupID, getErr) - } - - entry := UserGroupRPMStatus{ - GroupID: groupID, - Used: used, - } - - if s.groupRepo != nil { - if group, groupErr := s.groupRepo.GetByIDLite(ctx, groupID); groupErr == nil && group != nil { - entry.GroupName = group.Name - entry.Limit = group.RPMLimit - entry.Source = "group" - } else if groupErr != nil { - logger.LegacyPrintf("service.admin", "failed to get group rpm status metadata: group_id=%d err=%v", groupID, groupErr) - } - } - - if s.userGroupRateRepo != nil { - override, overrideErr := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, userID, groupID) - if overrideErr != nil 
{ - logger.LegacyPrintf("service.admin", "failed to get rpm override: user_id=%d group_id=%d err=%v", userID, groupID, overrideErr) - } else if override != nil { - entry.Limit = *override - entry.Source = "override" - } - } - - perGroup = append(perGroup, entry) - } - - return &UserRPMStatus{ - UserRPMUsed: userRPMUsed, - UserRPMLimit: user.RPMLimit, - PerGroup: perGroup, - }, nil -} - -func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) { - // Return mock data for now - return map[string]any{ - "period": period, - "total_requests": 0, - "total_cost": 0.0, - "total_tokens": 0, - "avg_duration_ms": 0, - }, nil -} - -// GetUserBalanceHistory returns paginated balance/concurrency change records for a user. -func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) { - params := pagination.PaginationParams{Page: page, PageSize: pageSize} - if codeType == RedeemTypeAffiliateBalance { - codes, total, err := s.listAffiliateBalanceHistory(ctx, userID, params) - if err != nil { - return nil, 0, 0, err - } - totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID) - if err != nil { - return nil, 0, 0, err - } - return codes, total, totalRecharged, nil - } - - if codeType == "" { - return s.getAllUserBalanceHistory(ctx, userID, params) - } - - codes, result, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, params, codeType) - if err != nil { - return nil, 0, 0, err - } - total := result.Total - // Aggregate total recharged amount (only once, regardless of type filter) - totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID) - if err != nil { - return nil, 0, 0, err - } - return codes, total, totalRecharged, nil -} - func (s *adminServiceImpl) getAllUserBalanceHistory(ctx context.Context, userID int64, params pagination.PaginationParams) ([]RedeemCode, int64, float64, 
error) { needed := params.Offset() + params.Limit() if needed < params.Limit() { @@ -1223,151 +705,6 @@ func redeemCodeHistoryTime(code RedeemCode) time.Time { return code.CreatedAt } -func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) { - if userID <= 0 { - return nil, infraerrors.BadRequest("INVALID_INPUT", "user_id must be greater than 0") - } - if s == nil || s.entClient == nil || s.userRepo == nil { - return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_UNAVAILABLE", "auth identity binding service is unavailable") - } - if _, err := s.userRepo.GetByID(ctx, userID); err != nil { - return nil, err - } - - providerType := normalizeAdminAuthIdentityProviderType(input.ProviderType) - providerKey := strings.TrimSpace(input.ProviderKey) - providerSubject := strings.TrimSpace(input.ProviderSubject) - if providerType == "" { - return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type must be one of email, linuxdo, oidc, or wechat") - } - if providerKey == "" || providerSubject == "" { - return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type, provider_key, and provider_subject are required") - } - canonicalProviderKey := canonicalAdminAuthIdentityProviderKey(providerType, "", providerKey) - compatibleProviderKeys := compatibleAdminAuthIdentityProviderKeys(providerType, providerKey) - - var issuer *string - if input.Issuer != nil { - trimmed := strings.TrimSpace(*input.Issuer) - if trimmed != "" { - issuer = &trimmed - } - } - - channelInput := normalizeAdminBindChannelInput(input.Channel) - if input.Channel != nil && channelInput == nil { - return nil, infraerrors.BadRequest("INVALID_INPUT", "channel, channel_app_id, and channel_subject are required when channel binding is provided") - } - - verifiedAt := time.Now().UTC() - tx, err := s.entClient.Tx(ctx) - if err != nil { - return nil, 
infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_TX_FAILED", "failed to start auth identity bind transaction").WithCause(err) - } - defer func() { _ = tx.Rollback() }() - - identityRecords, err := tx.AuthIdentity.Query(). - Where( - authidentity.ProviderTypeEQ(providerType), - authidentity.ProviderKeyIn(compatibleProviderKeys...), - authidentity.ProviderSubjectEQ(providerSubject), - ). - All(ctx) - if err != nil { - return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) - } - if hasAdminAuthIdentityOwnershipConflict(identityRecords, userID) { - return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") - } - identity := selectOwnedAdminAuthIdentity(identityRecords, userID) - - if identity == nil { - create := tx.AuthIdentity.Create(). - SetUserID(userID). - SetProviderType(providerType). - SetProviderKey(canonicalProviderKey). - SetProviderSubject(providerSubject). - SetVerifiedAt(verifiedAt) - if issuer != nil { - create = create.SetIssuer(*issuer) - } - if input.Metadata != nil { - create = create.SetMetadata(cloneAdminAuthIdentityMetadata(input.Metadata)) - } - identity, err = create.Save(ctx) - if err != nil { - return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err) - } - } else { - update := tx.AuthIdentity.UpdateOneID(identity.ID). - SetVerifiedAt(verifiedAt). 
- SetProviderKey(canonicalProviderKey) - if issuer != nil { - update = update.SetIssuer(*issuer) - } - if input.Metadata != nil { - update = update.SetMetadata(cloneAdminAuthIdentityMetadata(input.Metadata)) - } - identity, err = update.Save(ctx) - if err != nil { - return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err) - } - } - - var channel *dbent.AuthIdentityChannel - if channelInput != nil { - channelRecords, err := tx.AuthIdentityChannel.Query(). - Where( - authidentitychannel.ProviderTypeEQ(providerType), - authidentitychannel.ProviderKeyIn(compatibleProviderKeys...), - authidentitychannel.ChannelEQ(channelInput.Channel), - authidentitychannel.ChannelAppIDEQ(channelInput.ChannelAppID), - authidentitychannel.ChannelSubjectEQ(channelInput.ChannelSubject), - ). - WithIdentity(). - All(ctx) - if err != nil { - return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err) - } - if hasAdminAuthIdentityChannelOwnershipConflict(channelRecords, userID) { - return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") - } - channel = selectOwnedAdminAuthIdentityChannel(channelRecords, userID) - if channel == nil { - create := tx.AuthIdentityChannel.Create(). - SetIdentityID(identity.ID). - SetProviderType(providerType). - SetProviderKey(canonicalProviderKey). - SetChannel(channelInput.Channel). - SetChannelAppID(channelInput.ChannelAppID). 
- SetChannelSubject(channelInput.ChannelSubject) - if channelInput.Metadata != nil { - create = create.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata)) - } - channel, err = create.Save(ctx) - if err != nil { - return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err) - } - } else { - update := tx.AuthIdentityChannel.UpdateOneID(channel.ID). - SetIdentityID(identity.ID). - SetProviderKey(canonicalProviderKey) - if channelInput.Metadata != nil { - update = update.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata)) - } - channel, err = update.Save(ctx) - if err != nil { - return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err) - } - } - } - - if err := tx.Commit(); err != nil { - return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_COMMIT_FAILED", "failed to commit auth identity bind").WithCause(err) - } - return buildAdminBoundAuthIdentity(identity, channel), nil -} - func compatibleAdminAuthIdentityProviderKeys(providerType, providerKey string) []string { providerType = strings.TrimSpace(strings.ToLower(providerType)) providerKey = strings.TrimSpace(providerKey) @@ -1551,182 +888,6 @@ func cloneAdminAuthIdentityMetadata(input map[string]any) map[string]any { return out } -// Group management implementations -func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool, sortBy, sortOrder string) ([]Group, int64, error) { - params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder} - groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, search, isExclusive) - if err != nil { - return nil, 0, err - } - return groups, result.Total, nil -} - -func (s *adminServiceImpl) GetAllGroups(ctx context.Context) ([]Group, error) { - 
return s.groupRepo.ListActive(ctx) -} - -func (s *adminServiceImpl) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) { - return s.groupRepo.ListActiveByPlatform(ctx, platform) -} - -func (s *adminServiceImpl) GetGroup(ctx context.Context, id int64) (*Group, error) { - return s.groupRepo.GetByID(ctx, id) -} - -func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error) { - if input.RateMultiplier <= 0 { - return nil, errors.New("rate_multiplier must be > 0") - } - - platform := input.Platform - if platform == "" { - platform = PlatformAnthropic - } - - subscriptionType := input.SubscriptionType - if subscriptionType == "" { - subscriptionType = SubscriptionTypeStandard - } - - // 限额字段:nil/负数 表示"无限制",0 表示"不允许用量",正数表示具体限额 - dailyLimit := normalizeLimit(input.DailyLimitUSD) - weeklyLimit := normalizeLimit(input.WeeklyLimitUSD) - monthlyLimit := normalizeLimit(input.MonthlyLimitUSD) - - // 图片价格:负数表示清除(使用默认价格),0 保留(表示免费) - imagePrice1K := normalizePrice(input.ImagePrice1K) - imagePrice2K := normalizePrice(input.ImagePrice2K) - imagePrice4K := normalizePrice(input.ImagePrice4K) - imageRateMultiplier := 1.0 - if input.ImageRateMultiplier != nil { - if *input.ImageRateMultiplier < 0 { - return nil, errors.New("image_rate_multiplier must be >= 0") - } - imageRateMultiplier = *input.ImageRateMultiplier - } - - // 校验降级分组 - if input.FallbackGroupID != nil { - if err := s.validateFallbackGroup(ctx, 0, *input.FallbackGroupID); err != nil { - return nil, err - } - } - fallbackOnInvalidRequest := input.FallbackGroupIDOnInvalidRequest - if fallbackOnInvalidRequest != nil && *fallbackOnInvalidRequest <= 0 { - fallbackOnInvalidRequest = nil - } - // 校验无效请求兜底分组 - if fallbackOnInvalidRequest != nil { - if err := s.validateFallbackGroupOnInvalidRequest(ctx, 0, platform, subscriptionType, *fallbackOnInvalidRequest); err != nil { - return nil, err - } - } - - // MCPXMLInject:默认为 true,仅当显式传入 false 时关闭 - mcpXMLInject := 
true - if input.MCPXMLInject != nil { - mcpXMLInject = *input.MCPXMLInject - } - - // 如果指定了复制账号的源分组,先获取账号 ID 列表 - var accountIDsToCopy []int64 - if len(input.CopyAccountsFromGroupIDs) > 0 { - // 去重源分组 IDs - seen := make(map[int64]struct{}) - uniqueSourceGroupIDs := make([]int64, 0, len(input.CopyAccountsFromGroupIDs)) - for _, srcGroupID := range input.CopyAccountsFromGroupIDs { - if _, exists := seen[srcGroupID]; !exists { - seen[srcGroupID] = struct{}{} - uniqueSourceGroupIDs = append(uniqueSourceGroupIDs, srcGroupID) - } - } - - // 校验源分组的平台是否与新分组一致 - for _, srcGroupID := range uniqueSourceGroupIDs { - srcGroup, err := s.groupRepo.GetByIDLite(ctx, srcGroupID) - if err != nil { - return nil, fmt.Errorf("source group %d not found: %w", srcGroupID, err) - } - if srcGroup.Platform != platform { - return nil, fmt.Errorf("source group %d platform mismatch: expected %s, got %s", srcGroupID, platform, srcGroup.Platform) - } - } - - // 获取所有源分组的账号(去重) - var err error - accountIDsToCopy, err = s.groupRepo.GetAccountIDsByGroupIDs(ctx, uniqueSourceGroupIDs) - if err != nil { - return nil, fmt.Errorf("failed to get accounts from source groups: %w", err) - } - } - - group := &Group{ - Name: input.Name, - Description: input.Description, - Platform: platform, - RateMultiplier: input.RateMultiplier, - IsExclusive: input.IsExclusive, - Status: StatusActive, - SubscriptionType: subscriptionType, - DailyLimitUSD: dailyLimit, - WeeklyLimitUSD: weeklyLimit, - MonthlyLimitUSD: monthlyLimit, - AllowImageGeneration: input.AllowImageGeneration, - ImageRateIndependent: input.ImageRateIndependent, - ImageRateMultiplier: imageRateMultiplier, - ImagePrice1K: imagePrice1K, - ImagePrice2K: imagePrice2K, - ImagePrice4K: imagePrice4K, - ClaudeCodeOnly: input.ClaudeCodeOnly, - FallbackGroupID: input.FallbackGroupID, - FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest, - ModelRouting: input.ModelRouting, - MCPXMLInject: mcpXMLInject, - SupportedModelScopes: input.SupportedModelScopes, - 
AllowMessagesDispatch: input.AllowMessagesDispatch, - RequireOAuthOnly: input.RequireOAuthOnly, - RequirePrivacySet: input.RequirePrivacySet, - DefaultMappedModel: input.DefaultMappedModel, - MessagesDispatchModelConfig: normalizeOpenAIMessagesDispatchModelConfig(input.MessagesDispatchModelConfig), - RPMLimit: input.RPMLimit, - } - sanitizeGroupMessagesDispatchFields(group) - if err := s.groupRepo.Create(ctx, group); err != nil { - return nil, err - } - - // require_oauth_only: 过滤掉 apikey 类型账号 - if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini) && len(accountIDsToCopy) > 0 { - accounts, err := s.accountRepo.GetByIDs(ctx, accountIDsToCopy) - if err != nil { - return nil, fmt.Errorf("failed to fetch accounts for oauth filter: %w", err) - } - oauthIDs := make(map[int64]struct{}, len(accounts)) - for _, acc := range accounts { - if acc.Type != AccountTypeAPIKey { - oauthIDs[acc.ID] = struct{}{} - } - } - var filtered []int64 - for _, aid := range accountIDsToCopy { - if _, ok := oauthIDs[aid]; ok { - filtered = append(filtered, aid) - } - } - accountIDsToCopy = filtered - } - - // 如果有需要复制的账号,绑定到新分组 - if len(accountIDsToCopy) > 0 { - if err := s.groupRepo.BindAccountsToGroup(ctx, group.ID, accountIDsToCopy); err != nil { - return nil, fmt.Errorf("failed to bind accounts to new group: %w", err) - } - group.AccountCount = int64(len(accountIDsToCopy)) - } - - return group, nil -} - // normalizeLimit 将负数转换为 nil(表示无限制),0 保留(表示限额为零) func normalizeLimit(limit *float64) *float64 { if limit == nil || *limit < 0 { @@ -1774,302 +935,42 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro return fmt.Errorf("fallback group cannot have claude_code_only enabled") } - if fallbackGroup.FallbackGroupID == nil { - return nil - } - nextID = *fallbackGroup.FallbackGroupID - } -} - -// validateFallbackGroupOnInvalidRequest 
校验无效请求兜底分组的有效性 -// currentGroupID: 当前分组 ID(新建时为 0) -// platform/subscriptionType: 当前分组的有效平台/订阅类型 -// fallbackGroupID: 兜底分组 ID -func (s *adminServiceImpl) validateFallbackGroupOnInvalidRequest(ctx context.Context, currentGroupID int64, platform, subscriptionType string, fallbackGroupID int64) error { - if platform != PlatformAnthropic && platform != PlatformAntigravity { - return fmt.Errorf("invalid request fallback only supported for anthropic or antigravity groups") - } - if subscriptionType == SubscriptionTypeSubscription { - return fmt.Errorf("subscription groups cannot set invalid request fallback") - } - if currentGroupID > 0 && currentGroupID == fallbackGroupID { - return fmt.Errorf("cannot set self as invalid request fallback group") - } - - fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, fallbackGroupID) - if err != nil { - return fmt.Errorf("fallback group not found: %w", err) - } - if fallbackGroup.Platform != PlatformAnthropic { - return fmt.Errorf("fallback group must be anthropic platform") - } - if fallbackGroup.SubscriptionType == SubscriptionTypeSubscription { - return fmt.Errorf("fallback group cannot be subscription type") - } - if fallbackGroup.FallbackGroupIDOnInvalidRequest != nil { - return fmt.Errorf("fallback group cannot have invalid request fallback configured") - } - return nil -} - -func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) { - group, err := s.groupRepo.GetByID(ctx, id) - if err != nil { - return nil, err - } - - if input.Name != "" { - group.Name = input.Name - } - if input.Description != "" { - group.Description = input.Description - } - if input.Platform != "" { - group.Platform = input.Platform - } - if input.RateMultiplier != nil { - if *input.RateMultiplier <= 0 { - return nil, errors.New("rate_multiplier must be > 0") - } - group.RateMultiplier = *input.RateMultiplier - } - if input.IsExclusive != nil { - group.IsExclusive = *input.IsExclusive - } - if 
input.Status != "" { - group.Status = input.Status - } - - // 订阅相关字段 - if input.SubscriptionType != "" { - group.SubscriptionType = input.SubscriptionType - } - // 限额字段:nil/负数 表示"无限制",0 表示"不允许用量",正数表示具体限额 - // 前端始终发送这三个字段,无需 nil 守卫 - group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD) - group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD) - group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD) - // 图片生成计费配置:负数表示清除(使用默认价格) - if input.AllowImageGeneration != nil { - group.AllowImageGeneration = *input.AllowImageGeneration - } - if input.ImageRateIndependent != nil { - group.ImageRateIndependent = *input.ImageRateIndependent - } - if input.ImageRateMultiplier != nil { - if *input.ImageRateMultiplier < 0 { - return nil, errors.New("image_rate_multiplier must be >= 0") - } - group.ImageRateMultiplier = *input.ImageRateMultiplier - } - if input.ImagePrice1K != nil { - group.ImagePrice1K = normalizePrice(input.ImagePrice1K) - } - if input.ImagePrice2K != nil { - group.ImagePrice2K = normalizePrice(input.ImagePrice2K) - } - if input.ImagePrice4K != nil { - group.ImagePrice4K = normalizePrice(input.ImagePrice4K) - } - - // Claude Code 客户端限制 - if input.ClaudeCodeOnly != nil { - group.ClaudeCodeOnly = *input.ClaudeCodeOnly - } - if input.FallbackGroupID != nil { - // 校验降级分组 - if *input.FallbackGroupID > 0 { - if err := s.validateFallbackGroup(ctx, id, *input.FallbackGroupID); err != nil { - return nil, err - } - group.FallbackGroupID = input.FallbackGroupID - } else { - // 传入 0 或负数表示清除降级分组 - group.FallbackGroupID = nil - } - } - fallbackOnInvalidRequest := group.FallbackGroupIDOnInvalidRequest - if input.FallbackGroupIDOnInvalidRequest != nil { - if *input.FallbackGroupIDOnInvalidRequest > 0 { - fallbackOnInvalidRequest = input.FallbackGroupIDOnInvalidRequest - } else { - fallbackOnInvalidRequest = nil - } - } - if fallbackOnInvalidRequest != nil { - if err := s.validateFallbackGroupOnInvalidRequest(ctx, id, group.Platform, group.SubscriptionType, 
*fallbackOnInvalidRequest); err != nil { - return nil, err - } - } - group.FallbackGroupIDOnInvalidRequest = fallbackOnInvalidRequest - - // 模型路由配置 - if input.ModelRouting != nil { - group.ModelRouting = input.ModelRouting - } - if input.ModelRoutingEnabled != nil { - group.ModelRoutingEnabled = *input.ModelRoutingEnabled - } - if input.MCPXMLInject != nil { - group.MCPXMLInject = *input.MCPXMLInject - } - - // 支持的模型系列(仅 antigravity 平台使用) - if input.SupportedModelScopes != nil { - group.SupportedModelScopes = *input.SupportedModelScopes - } - - // OpenAI Messages 调度配置 - if input.AllowMessagesDispatch != nil { - group.AllowMessagesDispatch = *input.AllowMessagesDispatch - } - if input.RequireOAuthOnly != nil { - group.RequireOAuthOnly = *input.RequireOAuthOnly - } - if input.RequirePrivacySet != nil { - group.RequirePrivacySet = *input.RequirePrivacySet - } - if input.DefaultMappedModel != nil { - group.DefaultMappedModel = *input.DefaultMappedModel - } - if input.MessagesDispatchModelConfig != nil { - group.MessagesDispatchModelConfig = normalizeOpenAIMessagesDispatchModelConfig(*input.MessagesDispatchModelConfig) - } - if input.RPMLimit != nil { - group.RPMLimit = *input.RPMLimit - } - sanitizeGroupMessagesDispatchFields(group) - - if err := s.groupRepo.Update(ctx, group); err != nil { - return nil, err - } - - if s.authCacheInvalidator != nil { - s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id) - } - - // 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号) - if len(input.CopyAccountsFromGroupIDs) > 0 { - // 去重源分组 IDs - seen := make(map[int64]struct{}) - uniqueSourceGroupIDs := make([]int64, 0, len(input.CopyAccountsFromGroupIDs)) - for _, srcGroupID := range input.CopyAccountsFromGroupIDs { - // 校验:源分组不能是自身 - if srcGroupID == id { - return nil, fmt.Errorf("cannot copy accounts from self") - } - // 去重 - if _, exists := seen[srcGroupID]; !exists { - seen[srcGroupID] = struct{}{} - uniqueSourceGroupIDs = append(uniqueSourceGroupIDs, srcGroupID) - } - } - - // 
校验源分组的平台是否与当前分组一致 - for _, srcGroupID := range uniqueSourceGroupIDs { - srcGroup, err := s.groupRepo.GetByIDLite(ctx, srcGroupID) - if err != nil { - return nil, fmt.Errorf("source group %d not found: %w", srcGroupID, err) - } - if srcGroup.Platform != group.Platform { - return nil, fmt.Errorf("source group %d platform mismatch: expected %s, got %s", srcGroupID, group.Platform, srcGroup.Platform) - } - } - - // 获取所有源分组的账号(去重) - accountIDsToCopy, err := s.groupRepo.GetAccountIDsByGroupIDs(ctx, uniqueSourceGroupIDs) - if err != nil { - return nil, fmt.Errorf("failed to get accounts from source groups: %w", err) - } - - // 先清空当前分组的所有账号绑定 - if _, err := s.groupRepo.DeleteAccountGroupsByGroupID(ctx, id); err != nil { - return nil, fmt.Errorf("failed to clear existing account bindings: %w", err) - } - - // require_oauth_only: 过滤掉 apikey 类型账号 - if group.RequireOAuthOnly && (group.Platform == PlatformOpenAI || group.Platform == PlatformAntigravity || group.Platform == PlatformAnthropic || group.Platform == PlatformGemini) && len(accountIDsToCopy) > 0 { - accounts, err := s.accountRepo.GetByIDs(ctx, accountIDsToCopy) - if err != nil { - return nil, fmt.Errorf("failed to fetch accounts for oauth filter: %w", err) - } - oauthIDs := make(map[int64]struct{}, len(accounts)) - for _, acc := range accounts { - if acc.Type != AccountTypeAPIKey { - oauthIDs[acc.ID] = struct{}{} - } - } - var filtered []int64 - for _, aid := range accountIDsToCopy { - if _, ok := oauthIDs[aid]; ok { - filtered = append(filtered, aid) - } - } - accountIDsToCopy = filtered - } - - // 再绑定源分组的账号 - if len(accountIDsToCopy) > 0 { - if err := s.groupRepo.BindAccountsToGroup(ctx, id, accountIDsToCopy); err != nil { - return nil, fmt.Errorf("failed to bind accounts to group: %w", err) - } - } - } - - return group, nil -} - -func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { - var groupKeys []string - if s.authCacheInvalidator != nil { - keys, err := 
s.apiKeyRepo.ListKeysByGroupID(ctx, id) - if err == nil { - groupKeys = keys + if fallbackGroup.FallbackGroupID == nil { + return nil } + nextID = *fallbackGroup.FallbackGroupID } +} - affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id) - if err != nil { - return err +// validateFallbackGroupOnInvalidRequest 校验无效请求兜底分组的有效性 +// currentGroupID: 当前分组 ID(新建时为 0) +// platform/subscriptionType: 当前分组的有效平台/订阅类型 +// fallbackGroupID: 兜底分组 ID +func (s *adminServiceImpl) validateFallbackGroupOnInvalidRequest(ctx context.Context, currentGroupID int64, platform, subscriptionType string, fallbackGroupID int64) error { + if platform != PlatformAnthropic && platform != PlatformAntigravity { + return fmt.Errorf("invalid request fallback only supported for anthropic or antigravity groups") } - // 注意:user_group_rate_multipliers 表通过外键 ON DELETE CASCADE 自动清理 - - // 事务成功后,异步失效受影响用户的订阅缓存 - if len(affectedUserIDs) > 0 && s.billingCacheService != nil { - groupID := id - go func() { - cacheCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - for _, userID := range affectedUserIDs { - if err := s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID); err != nil { - logger.LegacyPrintf("service.admin", "invalidate subscription cache failed: user_id=%d group_id=%d err=%v", userID, groupID, err) - } - } - }() + if subscriptionType == SubscriptionTypeSubscription { + return fmt.Errorf("subscription groups cannot set invalid request fallback") } - if s.authCacheInvalidator != nil { - for _, key := range groupKeys { - s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, key) - } + if currentGroupID > 0 && currentGroupID == fallbackGroupID { + return fmt.Errorf("cannot set self as invalid request fallback group") } - return nil -} - -func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error) { - params := pagination.PaginationParams{Page: page, PageSize: pageSize} - keys, 
result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params) + fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, fallbackGroupID) if err != nil { - return nil, 0, err + return fmt.Errorf("fallback group not found: %w", err) } - return keys, result.Total, nil -} - -func (s *adminServiceImpl) GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error) { - if s.userGroupRateRepo == nil { - return nil, nil + if fallbackGroup.Platform != PlatformAnthropic { + return fmt.Errorf("fallback group must be anthropic platform") + } + if fallbackGroup.SubscriptionType == SubscriptionTypeSubscription { + return fmt.Errorf("fallback group cannot be subscription type") + } + if fallbackGroup.FallbackGroupIDOnInvalidRequest != nil { + return fmt.Errorf("fallback group cannot have invalid request fallback configured") } - return s.userGroupRateRepo.GetByGroupID(ctx, groupID) + return nil } func (s *adminServiceImpl) ClearGroupRateMultipliers(ctx context.Context, groupID int64) error { @@ -2079,18 +980,6 @@ func (s *adminServiceImpl) ClearGroupRateMultipliers(ctx context.Context, groupI return s.userGroupRateRepo.DeleteByGroupID(ctx, groupID) } -func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error { - if s.userGroupRateRepo == nil { - return nil - } - for _, e := range entries { - if e.RateMultiplier <= 0 { - return fmt.Errorf("rate_multiplier must be > 0 (user_id=%d)", e.UserID) - } - } - return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries) -} - func (s *adminServiceImpl) ClearGroupRPMOverrides(ctx context.Context, groupID int64) error { if s.userGroupRateRepo == nil { return nil @@ -2105,134 +994,10 @@ func (s *adminServiceImpl) ClearGroupRPMOverrides(ctx context.Context, groupID i return nil } -func (s *adminServiceImpl) BatchSetGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error { - if s.userGroupRateRepo 
== nil { - return nil - } - for _, e := range entries { - if e.RPMOverride != nil && *e.RPMOverride < 0 { - return infraerrors.BadRequest("INVALID_RPM_OVERRIDE", fmt.Sprintf("rpm_override must be >= 0 (user_id=%d)", e.UserID)) - } - } - if err := s.userGroupRateRepo.SyncGroupRPMOverrides(ctx, groupID, entries); err != nil { - return err - } - // RPM override 已嵌入 auth cache snapshot (v7),变更后必须失效相关缓存。 - if s.authCacheInvalidator != nil { - s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, groupID) - } - return nil -} - func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error { return s.groupRepo.UpdateSortOrders(ctx, updates) } -// AdminUpdateAPIKeyGroupID 管理员修改 API Key 分组绑定 -// groupID: nil=不修改, 指向0=解绑, 指向正整数=绑定到目标分组 -func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error) { - apiKey, err := s.apiKeyRepo.GetByID(ctx, keyID) - if err != nil { - return nil, err - } - - if groupID == nil { - // nil 表示不修改,直接返回 - return &AdminUpdateAPIKeyGroupIDResult{APIKey: apiKey}, nil - } - - if *groupID < 0 { - return nil, infraerrors.BadRequest("INVALID_GROUP_ID", "group_id must be non-negative") - } - - result := &AdminUpdateAPIKeyGroupIDResult{} - - if *groupID == 0 { - // 0 表示解绑分组(不修改 user_allowed_groups,避免影响用户其他 Key) - apiKey.GroupID = nil - apiKey.Group = nil - } else { - // 验证目标分组存在且状态为 active - group, err := s.groupRepo.GetByID(ctx, *groupID) - if err != nil { - return nil, err - } - if group.Status != StatusActive { - return nil, infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active") - } - // 订阅类型分组:用户须持有该分组的有效订阅才可绑定 - if group.IsSubscriptionType() { - if s.userSubRepo == nil { - return nil, infraerrors.InternalServer("SUBSCRIPTION_REPOSITORY_UNAVAILABLE", "subscription repository is not configured") - } - if _, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, apiKey.UserID, *groupID); err != nil { - if 
errors.Is(err, ErrSubscriptionNotFound) { - return nil, infraerrors.BadRequest("SUBSCRIPTION_REQUIRED", "user does not have an active subscription for this group") - } - return nil, err - } - } - - gid := *groupID - apiKey.GroupID = &gid - apiKey.Group = group - - // 专属标准分组:使用事务保证「添加分组权限」与「更新 API Key」的原子性 - if group.IsExclusive && !group.IsSubscriptionType() { - opCtx := ctx - var tx *dbent.Tx - if s.entClient == nil { - logger.LegacyPrintf("service.admin", "Warning: entClient is nil, skipping transaction protection for exclusive group binding") - } else { - var txErr error - tx, txErr = s.entClient.Tx(ctx) - if txErr != nil { - return nil, fmt.Errorf("begin transaction: %w", txErr) - } - defer func() { _ = tx.Rollback() }() - opCtx = dbent.NewTxContext(ctx, tx) - } - - if addErr := s.userRepo.AddGroupToAllowedGroups(opCtx, apiKey.UserID, gid); addErr != nil { - return nil, fmt.Errorf("add group to user allowed groups: %w", addErr) - } - if err := s.apiKeyRepo.Update(opCtx, apiKey); err != nil { - return nil, fmt.Errorf("update api key: %w", err) - } - if tx != nil { - if err := tx.Commit(); err != nil { - return nil, fmt.Errorf("commit transaction: %w", err) - } - } - - result.AutoGrantedGroupAccess = true - result.GrantedGroupID = &gid - result.GrantedGroupName = group.Name - - // 失效认证缓存(在事务提交后执行) - if s.authCacheInvalidator != nil { - s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key) - } - - result.APIKey = apiKey - return result, nil - } - } - - // 非专属分组 / 解绑:无需事务,单步更新即可 - if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil { - return nil, fmt.Errorf("update api key: %w", err) - } - - // 失效认证缓存 - if s.authCacheInvalidator != nil { - s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key) - } - - result.APIKey = apiKey - return result, nil -} - // AdminResetAPIKeyRateLimitUsage resets all API key rate-limit usage windows. 
func (s *adminServiceImpl) AdminResetAPIKeyRateLimitUsage(ctx context.Context, keyID int64) (*APIKey, error) { apiKey, err := s.apiKeyRepo.GetByID(ctx, keyID) @@ -2257,71 +1022,6 @@ func (s *adminServiceImpl) AdminResetAPIKeyRateLimitUsage(ctx context.Context, k return apiKey, nil } -// ReplaceUserGroup 替换用户的专属分组 -func (s *adminServiceImpl) ReplaceUserGroup(ctx context.Context, userID, oldGroupID, newGroupID int64) (*ReplaceUserGroupResult, error) { - if oldGroupID == newGroupID { - return nil, infraerrors.BadRequest("SAME_GROUP", "old and new group must be different") - } - - // 验证新分组存在且为活跃的专属标准分组 - newGroup, err := s.groupRepo.GetByID(ctx, newGroupID) - if err != nil { - return nil, err - } - if newGroup.Status != StatusActive { - return nil, infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active") - } - if !newGroup.IsExclusive { - return nil, infraerrors.BadRequest("GROUP_NOT_EXCLUSIVE", "target group is not exclusive") - } - if newGroup.IsSubscriptionType() { - return nil, infraerrors.BadRequest("GROUP_IS_SUBSCRIPTION", "subscription groups are not supported for replacement") - } - - // 事务保证原子性 - if s.entClient == nil { - return nil, fmt.Errorf("entClient is nil, cannot perform group replacement") - } - tx, err := s.entClient.Tx(ctx) - if err != nil { - return nil, fmt.Errorf("begin transaction: %w", err) - } - defer func() { _ = tx.Rollback() }() - opCtx := dbent.NewTxContext(ctx, tx) - - // 1. 授予新分组权限 - if err := s.userRepo.AddGroupToAllowedGroups(opCtx, userID, newGroupID); err != nil { - return nil, fmt.Errorf("add new group to allowed groups: %w", err) - } - - // 2. 迁移绑定旧分组的 Key 到新分组 - migrated, err := s.apiKeyRepo.UpdateGroupIDByUserAndGroup(opCtx, userID, oldGroupID, newGroupID) - if err != nil { - return nil, fmt.Errorf("migrate api keys: %w", err) - } - - // 3. 
移除旧分组权限 - if err := s.userRepo.RemoveGroupFromUserAllowedGroups(opCtx, userID, oldGroupID); err != nil { - return nil, fmt.Errorf("remove old group from allowed groups: %w", err) - } - - if err := tx.Commit(); err != nil { - return nil, fmt.Errorf("commit transaction: %w", err) - } - - // 失效该用户所有 Key 的认证缓存 - if s.authCacheInvalidator != nil { - keys, keyErr := s.apiKeyRepo.ListKeysByUserID(ctx, userID) - if keyErr == nil { - for _, k := range keys { - s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, k) - } - } - } - - return &ReplaceUserGroupResult{MigratedKeys: migrated}, nil -} - // Account management implementations func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]Account, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder} @@ -2951,224 +1651,6 @@ func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, po return s.proxyRepo.ExistsByHostPortAuth(ctx, host, port, username, password) } -// Redeem code management implementations -func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string, sortBy, sortOrder string) ([]RedeemCode, int64, error) { - params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder} - codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search) - if err != nil { - return nil, 0, err - } - return codes, result.Total, nil -} - -func (s *adminServiceImpl) GetRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) { - return s.redeemCodeRepo.GetByID(ctx, id) -} - -func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) { - // 如果是订阅类型,验证必须有 GroupID - if input.Type == RedeemTypeSubscription { - if 
input.GroupID == nil { - return nil, errors.New("group_id is required for subscription type") - } - // 验证分组存在且为订阅类型 - group, err := s.groupRepo.GetByID(ctx, *input.GroupID) - if err != nil { - return nil, fmt.Errorf("group not found: %w", err) - } - if !group.IsSubscriptionType() { - return nil, errors.New("group must be subscription type") - } - } - - codes := make([]RedeemCode, 0, input.Count) - for i := 0; i < input.Count; i++ { - codeValue, err := GenerateRedeemCode() - if err != nil { - return nil, err - } - code := RedeemCode{ - Code: codeValue, - Type: input.Type, - Value: input.Value, - Status: StatusUnused, - } - // 订阅类型专用字段 - if input.Type == RedeemTypeSubscription { - code.GroupID = input.GroupID - code.ValidityDays = input.ValidityDays - if code.ValidityDays <= 0 { - code.ValidityDays = 30 // 默认30天 - } - } - if err := s.redeemCodeRepo.Create(ctx, &code); err != nil { - return nil, err - } - codes = append(codes, code) - } - return codes, nil -} - -func (s *adminServiceImpl) DeleteRedeemCode(ctx context.Context, id int64) error { - return s.redeemCodeRepo.Delete(ctx, id) -} - -func (s *adminServiceImpl) BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) { - var deleted int64 - for _, id := range ids { - if err := s.redeemCodeRepo.Delete(ctx, id); err == nil { - deleted++ - } - } - return deleted, nil -} - -func (s *adminServiceImpl) ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error) { - code, err := s.redeemCodeRepo.GetByID(ctx, id) - if err != nil { - return nil, err - } - code.Status = StatusExpired - if err := s.redeemCodeRepo.Update(ctx, code); err != nil { - return nil, err - } - return code, nil -} - -func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error) { - proxy, err := s.proxyRepo.GetByID(ctx, id) - if err != nil { - return nil, err - } - - proxyURL := proxy.URL() - exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL) - if err != nil { - 
s.saveProxyLatency(ctx, id, &ProxyLatencyInfo{ - Success: false, - Message: err.Error(), - UpdatedAt: time.Now(), - }) - return &ProxyTestResult{ - Success: false, - Message: err.Error(), - }, nil - } - - latency := latencyMs - s.saveProxyLatency(ctx, id, &ProxyLatencyInfo{ - Success: true, - LatencyMs: &latency, - Message: "Proxy is accessible", - IPAddress: exitInfo.IP, - Country: exitInfo.Country, - CountryCode: exitInfo.CountryCode, - Region: exitInfo.Region, - City: exitInfo.City, - UpdatedAt: time.Now(), - }) - return &ProxyTestResult{ - Success: true, - Message: "Proxy is accessible", - LatencyMs: latencyMs, - IPAddress: exitInfo.IP, - City: exitInfo.City, - Region: exitInfo.Region, - Country: exitInfo.Country, - CountryCode: exitInfo.CountryCode, - }, nil -} - -func (s *adminServiceImpl) CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error) { - proxy, err := s.proxyRepo.GetByID(ctx, id) - if err != nil { - return nil, err - } - - result := &ProxyQualityCheckResult{ - ProxyID: id, - Score: 100, - Grade: "A", - CheckedAt: time.Now().Unix(), - Items: make([]ProxyQualityCheckItem, 0, len(proxyQualityTargets)+1), - } - - proxyURL := proxy.URL() - if s.proxyProber == nil { - result.Items = append(result.Items, ProxyQualityCheckItem{ - Target: "base_connectivity", - Status: "fail", - Message: "代理探测服务未配置", - }) - result.FailedCount++ - finalizeProxyQualityResult(result) - s.saveProxyQualitySnapshot(ctx, id, result, nil) - return result, nil - } - - exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL) - if err != nil { - result.Items = append(result.Items, ProxyQualityCheckItem{ - Target: "base_connectivity", - Status: "fail", - LatencyMs: latencyMs, - Message: err.Error(), - }) - result.FailedCount++ - finalizeProxyQualityResult(result) - s.saveProxyQualitySnapshot(ctx, id, result, nil) - return result, nil - } - - result.ExitIP = exitInfo.IP - result.Country = exitInfo.Country - result.CountryCode = exitInfo.CountryCode 
- result.BaseLatencyMs = latencyMs - result.Items = append(result.Items, ProxyQualityCheckItem{ - Target: "base_connectivity", - Status: "pass", - LatencyMs: latencyMs, - Message: "代理出口连通正常", - }) - result.PassedCount++ - - client, err := httpclient.GetClient(httpclient.Options{ - ProxyURL: proxyURL, - Timeout: proxyQualityRequestTimeout, - ResponseHeaderTimeout: proxyQualityResponseHeaderTimeout, - }) - if err != nil { - result.Items = append(result.Items, ProxyQualityCheckItem{ - Target: "http_client", - Status: "fail", - Message: fmt.Sprintf("创建检测客户端失败: %v", err), - }) - result.FailedCount++ - finalizeProxyQualityResult(result) - s.saveProxyQualitySnapshot(ctx, id, result, exitInfo) - return result, nil - } - - for _, target := range proxyQualityTargets { - item := runProxyQualityTarget(ctx, client, target) - result.Items = append(result.Items, item) - switch item.Status { - case "pass": - result.PassedCount++ - case "warn": - result.WarnCount++ - case "challenge": - result.ChallengeCount++ - default: - result.FailedCount++ - } - } - - finalizeProxyQualityResult(result) - s.saveProxyQualitySnapshot(ctx, id, result, exitInfo) - return result, nil -} - func runProxyQualityTarget(ctx context.Context, client *http.Client, target proxyQualityTarget) ProxyQualityCheckItem { item := ProxyQualityCheckItem{ Target: target.Target, @@ -3312,65 +1794,6 @@ func proxyQualityBaseConnectivityPass(result *ProxyQualityCheckResult) bool { return false } -func (s *adminServiceImpl) saveProxyQualitySnapshot(ctx context.Context, proxyID int64, result *ProxyQualityCheckResult, exitInfo *ProxyExitInfo) { - if result == nil { - return - } - score := result.Score - checkedAt := result.CheckedAt - info := &ProxyLatencyInfo{ - Success: proxyQualityBaseConnectivityPass(result), - Message: result.Summary, - QualityStatus: proxyQualityOverallStatus(result), - QualityScore: &score, - QualityGrade: result.Grade, - QualitySummary: result.Summary, - QualityCheckedAt: &checkedAt, - QualityCFRay: 
proxyQualityFirstCFRay(result), - UpdatedAt: time.Now(), - } - if result.BaseLatencyMs > 0 { - latency := result.BaseLatencyMs - info.LatencyMs = &latency - } - if exitInfo != nil { - info.IPAddress = exitInfo.IP - info.Country = exitInfo.Country - info.CountryCode = exitInfo.CountryCode - info.Region = exitInfo.Region - info.City = exitInfo.City - } - s.saveProxyLatency(ctx, proxyID, info) -} - -func (s *adminServiceImpl) probeProxyLatency(ctx context.Context, proxy *Proxy) { - if s.proxyProber == nil || proxy == nil { - return - } - exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxy.URL()) - if err != nil { - s.saveProxyLatency(ctx, proxy.ID, &ProxyLatencyInfo{ - Success: false, - Message: err.Error(), - UpdatedAt: time.Now(), - }) - return - } - - latency := latencyMs - s.saveProxyLatency(ctx, proxy.ID, &ProxyLatencyInfo{ - Success: true, - LatencyMs: &latency, - Message: "Proxy is accessible", - IPAddress: exitInfo.IP, - Country: exitInfo.Country, - CountryCode: exitInfo.CountryCode, - Region: exitInfo.Region, - City: exitInfo.City, - UpdatedAt: time.Now(), - }) -} - // checkMixedChannelRisk 检查分组中是否存在混合渠道(Antigravity + Anthropic) // 如果存在混合,返回错误提示用户确认 func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error { @@ -3454,76 +1877,6 @@ func (s *adminServiceImpl) CheckMixedChannelRisk(ctx context.Context, currentAcc return s.checkMixedChannelRisk(ctx, currentAccountID, currentAccountPlatform, groupIDs) } -func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []ProxyWithAccountCount) { - if s.proxyLatencyCache == nil || len(proxies) == 0 { - return - } - - ids := make([]int64, 0, len(proxies)) - for i := range proxies { - ids = append(ids, proxies[i].ID) - } - - latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, ids) - if err != nil { - logger.LegacyPrintf("service.admin", "Warning: load proxy latency cache failed: %v", err) - return 
- } - - for i := range proxies { - info := latencies[proxies[i].ID] - if info == nil { - continue - } - if info.Success { - proxies[i].LatencyStatus = "success" - proxies[i].LatencyMs = info.LatencyMs - } else { - proxies[i].LatencyStatus = "failed" - } - proxies[i].LatencyMessage = info.Message - proxies[i].IPAddress = info.IPAddress - proxies[i].Country = info.Country - proxies[i].CountryCode = info.CountryCode - proxies[i].Region = info.Region - proxies[i].City = info.City - proxies[i].QualityStatus = info.QualityStatus - proxies[i].QualityScore = info.QualityScore - proxies[i].QualityGrade = info.QualityGrade - proxies[i].QualitySummary = info.QualitySummary - proxies[i].QualityChecked = info.QualityCheckedAt - } -} - -func (s *adminServiceImpl) saveProxyLatency(ctx context.Context, proxyID int64, info *ProxyLatencyInfo) { - if s.proxyLatencyCache == nil || info == nil { - return - } - - merged := *info - if latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, []int64{proxyID}); err == nil { - if existing := latencies[proxyID]; existing != nil { - if merged.QualityCheckedAt == nil && - merged.QualityScore == nil && - merged.QualityGrade == "" && - merged.QualityStatus == "" && - merged.QualitySummary == "" && - merged.QualityCFRay == "" { - merged.QualityStatus = existing.QualityStatus - merged.QualityScore = existing.QualityScore - merged.QualityGrade = existing.QualityGrade - merged.QualitySummary = existing.QualitySummary - merged.QualityCheckedAt = existing.QualityCheckedAt - merged.QualityCFRay = existing.QualityCFRay - } - } - } - - if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, &merged); err != nil { - logger.LegacyPrintf("service.admin", "Warning: store proxy latency cache failed: %v", err) - } -} - // getAccountPlatform 根据账号 platform 判断混合渠道检查用的平台标识 func getAccountPlatform(accountPlatform string) string { switch strings.ToLower(strings.TrimSpace(accountPlatform)) { diff --git a/backend/internal/service/admin_service_apikey_test.go 
b/backend/internal/service/admin_service_apikey_test.go index 3b3dbc21cfb..8c73db2a7ce 100644 --- a/backend/internal/service/admin_service_apikey_test.go +++ b/backend/internal/service/admin_service_apikey_test.go @@ -69,8 +69,12 @@ func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, i panic("unexpected") } -func (s *userRepoStubForGroupUpdate) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } -func (s *userRepoStubForGroupUpdate) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } +func (s *userRepoStubForGroupUpdate) BatchSetConcurrency(context.Context, []int64, int) (int, error) { + return 0, nil +} +func (s *userRepoStubForGroupUpdate) BatchAddConcurrency(context.Context, []int64, int) (int, error) { + return 0, nil +} func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) { panic("unexpected") } diff --git a/backend/internal/service/admin_service_email_identity_sync_test.go b/backend/internal/service/admin_service_email_identity_sync_test.go index c791b747cf7..3036bb1c5f0 100644 --- a/backend/internal/service/admin_service_email_identity_sync_test.go +++ b/backend/internal/service/admin_service_email_identity_sync_test.go @@ -113,8 +113,12 @@ func (s *emailSyncRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) return 0, nil } -func (s *emailSyncRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } -func (s *emailSyncRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } +func (s *emailSyncRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { + return 0, nil +} +func (s *emailSyncRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { + return 0, nil +} func (s *emailSyncRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil } diff --git a/backend/internal/service/admin_user.go 
b/backend/internal/service/admin_user.go new file mode 100644 index 00000000000..4955f86e364 --- /dev/null +++ b/backend/internal/service/admin_user.go @@ -0,0 +1,499 @@ +package service + +import ( + "context" + "errors" + "fmt" + "sort" + "strings" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/authidentity" + "github.com/Wei-Shaw/sub2api/ent/authidentitychannel" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" +) + +// User management implementations +func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters, sortBy, sortOrder string) ([]User, int64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder} + users, result, err := s.userRepo.ListWithFilters(ctx, params, filters) + if err != nil { + return nil, 0, err + } + if len(users) > 0 { + userIDs := make([]int64, 0, len(users)) + for i := range users { + userIDs = append(userIDs, users[i].ID) + } + lastUsedByUserID, latestErr := s.userRepo.GetLatestUsedAtByUserIDs(ctx, userIDs) + if latestErr != nil { + logger.LegacyPrintf("service.admin", "failed to load user last_used_at in batch: err=%v", latestErr) + } else { + for i := range users { + users[i].LastUsedAt = lastUsedByUserID[users[i].ID] + } + } + } + // 批量加载用户专属分组倍率 + if s.userGroupRateRepo != nil && len(users) > 0 { + if batchRepo, ok := s.userGroupRateRepo.(userGroupRateBatchReader); ok { + userIDs := make([]int64, 0, len(users)) + for i := range users { + userIDs = append(userIDs, users[i].ID) + } + ratesByUser, err := batchRepo.GetByUserIDs(ctx, userIDs) + if err != nil { + logger.LegacyPrintf("service.admin", "failed to load user group rates in batch: err=%v", err) + s.loadUserGroupRatesOneByOne(ctx, users) + } else { + for i := range users { + if rates, ok := 
ratesByUser[users[i].ID]; ok { + users[i].GroupRates = rates + } + } + } + } else { + s.loadUserGroupRatesOneByOne(ctx, users) + } + } + return users, result.Total, nil +} + +func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) { + user, err := s.userRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + lastUsedAt, latestErr := s.userRepo.GetLatestUsedAtByUserID(ctx, id) + if latestErr != nil { + logger.LegacyPrintf("service.admin", "failed to load user last_used_at: user_id=%d err=%v", id, latestErr) + } else { + user.LastUsedAt = lastUsedAt + } + // 加载用户专属分组倍率 + if s.userGroupRateRepo != nil { + rates, err := s.userGroupRateRepo.GetByUserID(ctx, id) + if err != nil { + logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", id, err) + } else { + user.GroupRates = rates + } + } + return user, nil +} + +func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) { + user := &User{ + Email: input.Email, + Username: input.Username, + Notes: input.Notes, + Role: RoleUser, // Always create as regular user, never admin + Balance: input.Balance, + Concurrency: input.Concurrency, + RPMLimit: input.RPMLimit, + Status: StatusActive, + AllowedGroups: input.AllowedGroups, + } + if err := user.SetPassword(input.Password); err != nil { + return nil, err + } + if err := s.userRepo.Create(ctx, user); err != nil { + return nil, err + } + s.assignDefaultSubscriptions(ctx, user.ID) + return user, nil +} + +func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) { + // 校验用户专属分组倍率:必须 > 0(nil 合法,表示清除专属倍率) + if input.GroupRates != nil { + for groupID, rate := range input.GroupRates { + if rate != nil && *rate <= 0 { + return nil, fmt.Errorf("rate_multiplier must be > 0 (group_id=%d)", groupID) + } + } + } + + user, err := s.userRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + + // Protect admin users: cannot 
disable admin accounts + if user.Role == "admin" && input.Status == "disabled" { + return nil, errors.New("cannot disable admin user") + } + + oldConcurrency := user.Concurrency + oldStatus := user.Status + oldRole := user.Role + oldRPMLimit := user.RPMLimit + + if input.Email != "" { + user.Email = input.Email + } + if input.Password != "" { + if err := user.SetPassword(input.Password); err != nil { + return nil, err + } + } + + if input.Username != nil { + user.Username = *input.Username + } + if input.Notes != nil { + user.Notes = *input.Notes + } + + if input.Status != "" { + user.Status = input.Status + } + + if input.Concurrency != nil { + user.Concurrency = *input.Concurrency + } + + if input.RPMLimit != nil { + user.RPMLimit = *input.RPMLimit + } + + if input.AllowedGroups != nil { + user.AllowedGroups = *input.AllowedGroups + } + + if err := s.userRepo.Update(ctx, user); err != nil { + return nil, err + } + + // 同步用户专属分组倍率 + if input.GroupRates != nil && s.userGroupRateRepo != nil { + if err := s.userGroupRateRepo.SyncUserGroupRates(ctx, user.ID, input.GroupRates); err != nil { + logger.LegacyPrintf("service.admin", "failed to sync user group rates: user_id=%d err=%v", user.ID, err) + } + } + + if s.authCacheInvalidator != nil { + // RPMLimit 直接参与 billing_cache_service.checkRPM 的三级级联, + // 不失效缓存会让修改在一个 L2 TTL 内失去效果。 + if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole || user.RPMLimit != oldRPMLimit { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID) + } + } + + concurrencyDiff := user.Concurrency - oldConcurrency + if concurrencyDiff != 0 { + code, err := GenerateRedeemCode() + if err != nil { + logger.LegacyPrintf("service.admin", "failed to generate adjustment redeem code: %v", err) + return user, nil + } + adjustmentRecord := &RedeemCode{ + Code: code, + Type: AdjustmentTypeAdminConcurrency, + Value: float64(concurrencyDiff), + Status: StatusUsed, + UsedBy: &user.ID, + } + now := time.Now() + 
adjustmentRecord.UsedAt = &now + if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil { + logger.LegacyPrintf("service.admin", "failed to create concurrency adjustment redeem code: %v", err) + } + } + + return user, nil +} + +func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error { + // Protect admin users: cannot delete admin accounts + user, err := s.userRepo.GetByID(ctx, id) + if err != nil { + return err + } + if user.Role == "admin" { + return errors.New("cannot delete admin user") + } + if err := s.userRepo.Delete(ctx, id); err != nil { + logger.LegacyPrintf("service.admin", "delete user failed: user_id=%d err=%v", id, err) + return err + } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, id) + } + return nil +} + +func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize, SortBy: sortBy, SortOrder: sortOrder} + keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, APIKeyListFilters{}) + if err != nil { + return nil, 0, err + } + return keys, result.Total, nil +} + +func (s *adminServiceImpl) GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error) { + if s.userRPMCache == nil { + return nil, ErrRPMStatusUnavailable + } + + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + return nil, err + } + + userRPMUsed, err := s.userRPMCache.GetUserRPM(ctx, userID) + if err != nil { + logger.LegacyPrintf("service.admin", "failed to get user rpm: user_id=%d err=%v", userID, err) + } + + keys, _, err := s.GetUserAPIKeys(ctx, userID, 1, 1000, "", "") + if err != nil { + return nil, err + } + + groupIDSet := make(map[int64]struct{}) + for _, key := range keys { + if key.GroupID != nil && *key.GroupID > 0 { + groupIDSet[*key.GroupID] = struct{}{} + } + } + + groupIDs := make([]int64, 0, 
len(groupIDSet)) + for groupID := range groupIDSet { + groupIDs = append(groupIDs, groupID) + } + sort.Slice(groupIDs, func(i, j int) bool { return groupIDs[i] < groupIDs[j] }) + + var perGroup []UserGroupRPMStatus + for _, groupID := range groupIDs { + used, getErr := s.userRPMCache.GetUserGroupRPM(ctx, userID, groupID) + if getErr != nil { + logger.LegacyPrintf("service.admin", "failed to get user group rpm: user_id=%d group_id=%d err=%v", userID, groupID, getErr) + } + + entry := UserGroupRPMStatus{ + GroupID: groupID, + Used: used, + } + + if s.groupRepo != nil { + if group, groupErr := s.groupRepo.GetByIDLite(ctx, groupID); groupErr == nil && group != nil { + entry.GroupName = group.Name + entry.Limit = group.RPMLimit + entry.Source = "group" + } else if groupErr != nil { + logger.LegacyPrintf("service.admin", "failed to get group rpm status metadata: group_id=%d err=%v", groupID, groupErr) + } + } + + if s.userGroupRateRepo != nil { + override, overrideErr := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, userID, groupID) + if overrideErr != nil { + logger.LegacyPrintf("service.admin", "failed to get rpm override: user_id=%d group_id=%d err=%v", userID, groupID, overrideErr) + } else if override != nil { + entry.Limit = *override + entry.Source = "override" + } + } + + perGroup = append(perGroup, entry) + } + + return &UserRPMStatus{ + UserRPMUsed: userRPMUsed, + UserRPMLimit: user.RPMLimit, + PerGroup: perGroup, + }, nil +} + +func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) { + // Return mock data for now + return map[string]any{ + "period": period, + "total_requests": 0, + "total_cost": 0.0, + "total_tokens": 0, + "avg_duration_ms": 0, + }, nil +} + +// GetUserBalanceHistory returns paginated balance/concurrency change records for a user. 
+func (s *adminServiceImpl) GetUserBalanceHistory(ctx context.Context, userID int64, page, pageSize int, codeType string) ([]RedeemCode, int64, float64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize} + if codeType == RedeemTypeAffiliateBalance { + codes, total, err := s.listAffiliateBalanceHistory(ctx, userID, params) + if err != nil { + return nil, 0, 0, err + } + totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID) + if err != nil { + return nil, 0, 0, err + } + return codes, total, totalRecharged, nil + } + + if codeType == "" { + return s.getAllUserBalanceHistory(ctx, userID, params) + } + + codes, result, err := s.redeemCodeRepo.ListByUserPaginated(ctx, userID, params, codeType) + if err != nil { + return nil, 0, 0, err + } + total := result.Total + // Aggregate total recharged amount (only once, regardless of type filter) + totalRecharged, err := s.redeemCodeRepo.SumPositiveBalanceByUser(ctx, userID) + if err != nil { + return nil, 0, 0, err + } + return codes, total, totalRecharged, nil +} + +func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int64, input AdminBindAuthIdentityInput) (*AdminBoundAuthIdentity, error) { + if userID <= 0 { + return nil, infraerrors.BadRequest("INVALID_INPUT", "user_id must be greater than 0") + } + if s == nil || s.entClient == nil || s.userRepo == nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_UNAVAILABLE", "auth identity binding service is unavailable") + } + if _, err := s.userRepo.GetByID(ctx, userID); err != nil { + return nil, err + } + + providerType := normalizeAdminAuthIdentityProviderType(input.ProviderType) + providerKey := strings.TrimSpace(input.ProviderKey) + providerSubject := strings.TrimSpace(input.ProviderSubject) + if providerType == "" { + return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type must be one of email, linuxdo, oidc, or wechat") + } + if providerKey == "" || providerSubject == 
"" { + return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type, provider_key, and provider_subject are required") + } + canonicalProviderKey := canonicalAdminAuthIdentityProviderKey(providerType, "", providerKey) + compatibleProviderKeys := compatibleAdminAuthIdentityProviderKeys(providerType, providerKey) + + var issuer *string + if input.Issuer != nil { + trimmed := strings.TrimSpace(*input.Issuer) + if trimmed != "" { + issuer = &trimmed + } + } + + channelInput := normalizeAdminBindChannelInput(input.Channel) + if input.Channel != nil && channelInput == nil { + return nil, infraerrors.BadRequest("INVALID_INPUT", "channel, channel_app_id, and channel_subject are required when channel binding is provided") + } + + verifiedAt := time.Now().UTC() + tx, err := s.entClient.Tx(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_TX_FAILED", "failed to start auth identity bind transaction").WithCause(err) + } + defer func() { _ = tx.Rollback() }() + + identityRecords, err := tx.AuthIdentity.Query(). + Where( + authidentity.ProviderTypeEQ(providerType), + authidentity.ProviderKeyIn(compatibleProviderKeys...), + authidentity.ProviderSubjectEQ(providerSubject), + ). + All(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_LOOKUP_FAILED", "failed to inspect auth identity ownership").WithCause(err) + } + if hasAdminAuthIdentityOwnershipConflict(identityRecords, userID) { + return nil, infraerrors.Conflict("AUTH_IDENTITY_OWNERSHIP_CONFLICT", "auth identity already belongs to another user") + } + identity := selectOwnedAdminAuthIdentity(identityRecords, userID) + + if identity == nil { + create := tx.AuthIdentity.Create(). + SetUserID(userID). + SetProviderType(providerType). + SetProviderKey(canonicalProviderKey). + SetProviderSubject(providerSubject). 
+ SetVerifiedAt(verifiedAt) + if issuer != nil { + create = create.SetIssuer(*issuer) + } + if input.Metadata != nil { + create = create.SetMetadata(cloneAdminAuthIdentityMetadata(input.Metadata)) + } + identity, err = create.Save(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err) + } + } else { + update := tx.AuthIdentity.UpdateOneID(identity.ID). + SetVerifiedAt(verifiedAt). + SetProviderKey(canonicalProviderKey) + if issuer != nil { + update = update.SetIssuer(*issuer) + } + if input.Metadata != nil { + update = update.SetMetadata(cloneAdminAuthIdentityMetadata(input.Metadata)) + } + identity, err = update.Save(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_SAVE_FAILED", "failed to save auth identity").WithCause(err) + } + } + + var channel *dbent.AuthIdentityChannel + if channelInput != nil { + channelRecords, err := tx.AuthIdentityChannel.Query(). + Where( + authidentitychannel.ProviderTypeEQ(providerType), + authidentitychannel.ProviderKeyIn(compatibleProviderKeys...), + authidentitychannel.ChannelEQ(channelInput.Channel), + authidentitychannel.ChannelAppIDEQ(channelInput.ChannelAppID), + authidentitychannel.ChannelSubjectEQ(channelInput.ChannelSubject), + ). + WithIdentity(). + All(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_LOOKUP_FAILED", "failed to inspect auth identity channel ownership").WithCause(err) + } + if hasAdminAuthIdentityChannelOwnershipConflict(channelRecords, userID) { + return nil, infraerrors.Conflict("AUTH_IDENTITY_CHANNEL_OWNERSHIP_CONFLICT", "auth identity channel already belongs to another user") + } + channel = selectOwnedAdminAuthIdentityChannel(channelRecords, userID) + if channel == nil { + create := tx.AuthIdentityChannel.Create(). + SetIdentityID(identity.ID). + SetProviderType(providerType). + SetProviderKey(canonicalProviderKey). 
+ SetChannel(channelInput.Channel). + SetChannelAppID(channelInput.ChannelAppID). + SetChannelSubject(channelInput.ChannelSubject) + if channelInput.Metadata != nil { + create = create.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata)) + } + channel, err = create.Save(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err) + } + } else { + update := tx.AuthIdentityChannel.UpdateOneID(channel.ID). + SetIdentityID(identity.ID). + SetProviderKey(canonicalProviderKey) + if channelInput.Metadata != nil { + update = update.SetMetadata(cloneAdminAuthIdentityMetadata(channelInput.Metadata)) + } + channel, err = update.Save(ctx) + if err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_CHANNEL_SAVE_FAILED", "failed to save auth identity channel").WithCause(err) + } + } + } + + if err := tx.Commit(); err != nil { + return nil, infraerrors.InternalServer("ADMIN_AUTH_IDENTITY_BIND_COMMIT_FAILED", "failed to commit auth identity bind").WithCause(err) + } + return buildAdminBoundAuthIdentity(identity, channel), nil +} diff --git a/backend/internal/service/antigravity_smart_retry_test.go b/backend/internal/service/antigravity_smart_retry_test.go index e3b60a27902..6ac6b8fa72c 100644 --- a/backend/internal/service/antigravity_smart_retry_test.go +++ b/backend/internal/service/antigravity_smart_retry_test.go @@ -5,12 +5,13 @@ package service import ( "bytes" "context" - "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" - "github.com/stretchr/testify/require" "io" "net/http" "strings" "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" + "github.com/stretchr/testify/require" ) // stubSmartRetryCache 用于 handleSmartRetry 测试的 GatewayCache mock diff --git a/backend/internal/service/auth_service_email_bind_test.go b/backend/internal/service/auth_service_email_bind_test.go index 8f03f857efa..1f4827a2295 100644 --- 
a/backend/internal/service/auth_service_email_bind_test.go +++ b/backend/internal/service/auth_service_email_bind_test.go @@ -810,8 +810,8 @@ func (s *emailBindUserRepoStub) UpdateUserLastActiveAt(context.Context, int64, t } func (s *emailBindUserRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil } -func (s *emailBindUserRepoStub) DeductBalance(context.Context, int64, float64) error { return nil } -func (s *emailBindUserRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil } +func (s *emailBindUserRepoStub) DeductBalance(context.Context, int64, float64) error { return nil } +func (s *emailBindUserRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil } func (s *emailBindUserRepoStub) ExistsByEmail(_ context.Context, email string) (bool, error) { s.mu.Lock() @@ -820,8 +820,12 @@ func (s *emailBindUserRepoStub) ExistsByEmail(_ context.Context, email string) ( return ok, nil } -func (s *emailBindUserRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } -func (s *emailBindUserRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } +func (s *emailBindUserRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { + return 0, nil +} +func (s *emailBindUserRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { + return 0, nil +} func (s *emailBindUserRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { return 0, nil diff --git a/backend/internal/service/channel_monitor_template_types.go b/backend/internal/service/channel_monitor_template_types.go index e5bf7568443..06b4f3ab6c3 100644 --- a/backend/internal/service/channel_monitor_template_types.go +++ b/backend/internal/service/channel_monitor_template_types.go @@ -1,8 +1,9 @@ package service import ( - infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "time" + + infraerrors 
"github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) // ChannelMonitorRequestTemplate 请求模板(service 层模型)。 diff --git a/backend/internal/service/channel_test.go b/backend/internal/service/channel_test.go index 164861fb93d..26db59a7aaf 100644 --- a/backend/internal/service/channel_test.go +++ b/backend/internal/service/channel_test.go @@ -482,7 +482,6 @@ func TestSupportedModels_WildcardExpandedFromPricing(t *testing.T) { } } - func TestSupportedModels_MissingPricingKeepsNilPricing(t *testing.T) { ch := &Channel{ ModelMapping: map[string]map[string]string{ diff --git a/backend/internal/service/gateway_account_selection.go b/backend/internal/service/gateway_account_selection.go new file mode 100644 index 00000000000..8aaaebe28a9 --- /dev/null +++ b/backend/internal/service/gateway_account_selection.go @@ -0,0 +1,552 @@ +package service + +import ( + "context" + "fmt" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +// accountWithLoad 账号与负载信息的组合,用于负载感知调度 +type accountWithLoad struct { + account *Account + loadInfo *AccountLoadInfo +} + +func prefetchedStickyGroupIDFromContext(ctx context.Context) (int64, bool) { + return PrefetchedStickyGroupIDFromContext(ctx) +} + +func prefetchedStickyAccountIDFromContext(ctx context.Context, groupID *int64) int64 { + prefetchedGroupID, ok := prefetchedStickyGroupIDFromContext(ctx) + if !ok || prefetchedGroupID != derefGroupID(groupID) { + return 0 + } + if accountID, ok := PrefetchedStickyAccountIDFromContext(ctx); ok && accountID > 0 { + return accountID + } + return 0 +} + +// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。 +// 委托 IsSchedulable() 判断账号级可调度性(状态、配额、过载、限流等), +// 额外检查模型级限流。 +// +// shouldClearStickySession checks if an account is in an unschedulable state +// and the sticky session binding should be cleared. +// Delegates to IsSchedulable() for account-level checks, plus model-level rate limiting. 
+func shouldClearStickySession(account *Account, requestedModel string) bool { + if account == nil { + return false + } + if !account.IsSchedulable() { + return true + } + if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > 0 { + return true + } + return false +} + +// SelectAccount 选择账号(粘性会话+优先级) +func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) { + return s.SelectAccountForModel(ctx, groupID, sessionHash, "") +} + +// SelectAccountForModel 选择支持指定模型的账号(粘性会话+优先级+模型映射) +func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) { + return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil) +} + +// SelectAccountWithLoadAwareness 基于负载感知选择可用账号。 +// 按优先级依次尝试:模型路由 → 粘性会话 → 负载均衡 → 排队等待。 +// 调度流程文档见 docs/ACCOUNT_SCHEDULING_FLOW.md 。 +func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string, sub2apiUserID int64) (*AccountSelectionResult, error) { + // 阶段1: 初始化调度上下文(配置加载、Claude Code 限制、渠道定价、粘性绑定解析) + ctx, sc, err := s.prepareSchedulingContext(ctx, groupID, sessionHash, requestedModel, excludedIDs, metadataUserID, sub2apiUserID) + if err != nil { + return nil, err + } + + // 阶段2: 非负载感知快速路径(并发控制未启用时) + if s.concurrencyService == nil || !sc.cfg.LoadBatchEnabled { + return s.selectWithoutLoadAwareness(ctx, sc) + } + + // 阶段3: 加载可调度账号列表并构建查找结构 + ctx, pool, err := s.loadSchedulableAccountPool(ctx, sc) + if err != nil { + return nil, err + } + + // 阶段4: 模型路由层(Layer 1)— 按模型配置的专属账号选择 + if result, err := s.tryModelRoutingSelection(ctx, sc, pool); result != nil || err != nil { + return result, err + } + + // 阶段5: 粘性会话层(Layer 1.5)— 尝试复用上次绑定的账号 + if result, err := s.tryStickySessionSelection(ctx, sc, pool); result 
!= nil || err != nil { + return result, err + } + + // 阶段6: 负载均衡层(Layer 2)— 从候选池中选最优 + if result, err := s.selectByLoadBalance(ctx, sc, pool); result != nil || err != nil { + return result, err + } + + // 阶段7: 排队等待层(Layer 3)— 所有账号满载时进入等待队列 + return s.selectWithQueueFallback(ctx, sc, pool) +} + +// selectAccountForModelWithPlatform 选择单平台账户(完全隔离) +func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) { + preferOAuth := platform == PlatformGemini + routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform) + + // require_privacy_set: 获取分组信息 + var schedGroup *Group + if groupID != nil && s.groupRepo != nil { + schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID) + } + + var accounts []Account + accountsLoaded := false + + // ============ Model Routing (legacy path): apply before sticky session ============ + // When load-awareness is disabled (e.g. concurrency service not configured), we still honor model routing + // so switching model can switch upstream account within the same sticky session. + if len(routingAccountIDs) > 0 { + if s.debugModelRoutingEnabled() { + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v", + derefGroupID(groupID), requestedModel, platform, shortSessionHash(sessionHash), routingAccountIDs) + } + // 1) Sticky session only applies if the bound account is within the routing set. 
+ if sessionHash != "" && s.cache != nil { + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + if err == nil && accountID > 0 && containsInt64(routingAccountIDs, accountID) { + if _, excluded := excludedIDs[accountID]; !excluded { + account, err := s.getSchedulableAccount(ctx, accountID) + // 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台) + if err == nil { + clearSticky := shouldClearStickySession(account, requestedModel) + if clearSticky { + _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + } + if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) && !s.isStickyAccountUpstreamRestricted(ctx, groupID, account, requestedModel) { + if s.debugModelRoutingEnabled() { + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) + } + return account, nil + } + } + } + } + } + + // 2) Select an account from the routed candidates. 
+ forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) + if hasForcePlatform && forcePlatform == "" { + hasForcePlatform = false + } + var err error + accounts, _, err = s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) + if err != nil { + return nil, fmt.Errorf("query accounts failed: %w", err) + } + accountsLoaded = true + + // 提前预取窗口费用+RPM 计数,确保 routing 段内的调度检查调用能命中缓存 + ctx = s.withWindowCostPrefetch(ctx, accounts) + ctx = s.withRPMPrefetch(ctx, accounts) + + routingSet := make(map[int64]struct{}, len(routingAccountIDs)) + for _, id := range routingAccountIDs { + if id > 0 { + routingSet[id] = struct{}{} + } + } + + var selected *Account + for i := range accounts { + acc := &accounts[i] + if _, ok := routingSet[acc.ID]; !ok { + continue + } + if _, excluded := excludedIDs[acc.ID]; excluded { + continue + } + // Scheduler snapshots can be temporarily stale; re-check schedulability here to + // avoid selecting accounts that were recently rate-limited/overloaded. 
+ if !s.isAccountSchedulableForSelection(acc) { + continue + } + // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 + if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() { + _ = s.accountRepo.SetError(ctx, acc.ID, + fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) + continue + } + if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { + continue + } + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { + continue + } + if !s.isAccountSchedulableForQuota(acc) { + continue + } + if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { + continue + } + if !s.isAccountSchedulableForRPM(ctx, acc, false) { + continue + } + if selected == nil { + selected = acc + } else if isBetterAccountCandidate(acc, selected, preferOAuth, "") { + selected = acc + } + } + + if selected != nil { + if sessionHash != "" && s.cache != nil { + if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil { + logger.LegacyPrintf("service.gateway", "set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) + } + } + if s.debugModelRoutingEnabled() { + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID) + } + return selected, nil + } + logger.LegacyPrintf("service.gateway", "[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel) + } + + // 1. 
查询粘性会话 + if sessionHash != "" && s.cache != nil { + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + if err == nil && accountID > 0 { + if _, excluded := excludedIDs[accountID]; !excluded { + account, err := s.getSchedulableAccount(ctx, accountID) + // 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台) + if err == nil { + clearSticky := shouldClearStickySession(account, requestedModel) + if clearSticky { + _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + } + if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { + return account, nil + } + } + } + } + } + + // 2. 获取可调度账号列表(单平台) + if !accountsLoaded { + forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) + if hasForcePlatform && forcePlatform == "" { + hasForcePlatform = false + } + var err error + accounts, _, err = s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) + if err != nil { + return nil, fmt.Errorf("query accounts failed: %w", err) + } + } + + // 批量预取窗口费用+RPM 计数,避免逐个账号查询(N+1) + ctx = s.withWindowCostPrefetch(ctx, accounts) + ctx = s.withRPMPrefetch(ctx, accounts) + + // 3. 按优先级+最久未用选择(考虑模型支持) + // needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查, + // 因为粘性会话优先保持连接一致性,且 upstream 计费基准极少使用。 + needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) + var selected *Account + for i := range accounts { + acc := &accounts[i] + if _, excluded := excludedIDs[acc.ID]; excluded { + continue + } + // Scheduler snapshots can be temporarily stale; re-check schedulability here to + // avoid selecting accounts that were recently rate-limited/overloaded. 
+ if !s.isAccountSchedulableForSelection(acc) { + continue + } + // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 + if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() { + _ = s.accountRepo.SetError(ctx, acc.ID, + fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) + continue + } + if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { + continue + } + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) { + continue + } + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { + continue + } + if !s.isAccountSchedulableForQuota(acc) { + continue + } + if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { + continue + } + if !s.isAccountSchedulableForRPM(ctx, acc, false) { + continue + } + if selected == nil { + selected = acc + } else if isBetterAccountCandidate(acc, selected, preferOAuth, "") { + selected = acc + } + } + + if selected == nil { + stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, platform, accounts, excludedIDs, false) + if requestedModel != "" { + return nil, fmt.Errorf("%w supporting model: %s (%s)", ErrNoAvailableAccounts, requestedModel, summarizeSelectionFailureStats(stats)) + } + return nil, ErrNoAvailableAccounts + } + + // 4. 
Establish the sticky-session binding.
+	if sessionHash != "" && s.cache != nil {
+		if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
+			logger.LegacyPrintf("service.gateway", "set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
+		}
+	}
+
+	return selected, nil
+}
+
+// selectAccountWithMixedScheduling selects an account with mixed-scheduling
+// support: it considers native-platform accounts plus antigravity accounts
+// that have mixed_scheduling enabled.
+func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) {
+	preferOAuth := nativePlatform == PlatformGemini
+	routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform)
+
+	// require_privacy_set: load the group so the flag can be checked below.
+	var schedGroup *Group
+	if groupID != nil && s.groupRepo != nil {
+		schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID)
+	}
+
+	var accounts []Account
+	accountsLoaded := false
+
+	// ============ Model Routing (legacy path): apply before sticky session ============
+	if len(routingAccountIDs) > 0 {
+		if s.debugModelRoutingEnabled() {
+			logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v",
+				derefGroupID(groupID), requestedModel, nativePlatform, shortSessionHash(sessionHash), routingAccountIDs)
+		}
+		// 1) Sticky session only applies if the bound account is within the routing set.
+		if sessionHash != "" && s.cache != nil {
+			accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
+			if err == nil && accountID > 0 && containsInt64(routingAccountIDs, accountID) {
+				if _, excluded := excludedIDs[accountID]; !excluded {
+					account, err := s.getSchedulableAccount(ctx, accountID)
+					// Validate group membership and platform: native-platform accounts
+					// match directly; antigravity needs mixed scheduling enabled.
+					if err == nil {
+						clearSticky := shouldClearStickySession(account, requestedModel)
+						if clearSticky {
+							_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
+						}
+						if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
+							if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
+								if s.debugModelRoutingEnabled() {
+									logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
+								}
+								return account, nil
+							}
+						}
+					}
+				}
+			}
+		}
+
+		// 2) Select an account from the routed candidates.
+		var err error
+		accounts, _, err = s.listSchedulableAccounts(ctx, groupID, nativePlatform, false)
+		if err != nil {
+			return nil, fmt.Errorf("query accounts failed: %w", err)
+		}
+		accountsLoaded = true
+
+		// Prefetch window-cost and RPM counters up front so the scheduling checks
+		// inside the routing section hit the cache.
+		ctx = s.withWindowCostPrefetch(ctx, accounts)
+		ctx = s.withRPMPrefetch(ctx, accounts)
+
+		routingSet := make(map[int64]struct{}, len(routingAccountIDs))
+		for _, id := range routingAccountIDs {
+			if id > 0 {
+				routingSet[id] = struct{}{}
+			}
+		}
+
+		var selected *Account
+		for i := range accounts {
+			acc := &accounts[i]
+			if _, ok := routingSet[acc.ID]; !ok {
+				continue
+			}
+			if _, excluded := excludedIDs[acc.ID]; excluded {
+				continue
+			}
+			// Scheduler snapshots can be temporarily stale; re-check schedulability here to
+			// avoid selecting accounts that were recently rate-limited/overloaded.
+			if !s.isAccountSchedulableForSelection(acc) {
+				continue
+			}
+			// require_privacy_set: skip accounts without privacy configured and flag them.
+			if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
+				_ = s.accountRepo.SetError(ctx, acc.ID,
+					fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
+				continue
+			}
+			// Filter: native-platform accounts pass directly; antigravity accounts
+			// require mixed scheduling enabled.
+			if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
+				continue
+			}
+			if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
+				continue
+			}
+			if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
+				continue
+			}
+			if !s.isAccountSchedulableForQuota(acc) {
+				continue
+			}
+			if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
+				continue
+			}
+			if !s.isAccountSchedulableForRPM(ctx, acc, false) {
+				continue
+			}
+			if selected == nil {
+				selected = acc
+			} else if isBetterAccountCandidate(acc, selected, preferOAuth, PlatformGemini) {
+				selected = acc
+			}
+		}
+
+		if selected != nil {
+			if sessionHash != "" && s.cache != nil {
+				if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
+					logger.LegacyPrintf("service.gateway", "set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
+				}
+			}
+			if s.debugModelRoutingEnabled() {
+				logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID)
+			}
+			return selected, nil
+		}
+		logger.LegacyPrintf("service.gateway", "[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel)
+	}
+
+	// 1. Check the sticky-session binding.
+	if sessionHash != "" && s.cache != nil {
+		accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
+		if err == nil && accountID > 0 {
+			if _, excluded := excludedIDs[accountID]; !excluded {
+				account, err := s.getSchedulableAccount(ctx, accountID)
+				// Validate group membership and platform: native-platform accounts
+				// match directly; antigravity needs mixed scheduling enabled.
+				if err == nil {
+					clearSticky := shouldClearStickySession(account, requestedModel)
+					if clearSticky {
+						_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
+					}
+					if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) && !s.isStickyAccountUpstreamRestricted(ctx, groupID, account, requestedModel) {
+						if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
+							return account, nil
+						}
+					}
+				}
+			}
+		}
+	}
+
+	// 2. Load the schedulable account list (unless the routing path already did).
+	if !accountsLoaded {
+		var err error
+		accounts, _, err = s.listSchedulableAccounts(ctx, groupID, nativePlatform, false)
+		if err != nil {
+			return nil, fmt.Errorf("query accounts failed: %w", err)
+		}
+	}
+
+	// Batch-prefetch window-cost and RPM counters to avoid per-account queries (N+1).
+	ctx = s.withWindowCostPrefetch(ctx, accounts)
+	ctx = s.withRPMPrefetch(ctx, accounts)
+
+	// 3. Pick by priority + least-recently-used, honouring model support and mixed scheduling.
+	// needsUpstreamCheck is only used in the main selection loop; sticky hits skip it.
+	// NOTE(review): the loop dereferences *groupID when needsUpstreamCheck is true —
+	// presumably needsUpstreamChannelRestrictionCheck returns false for a nil groupID; confirm.
+	needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
+	var selected *Account
+	for i := range accounts {
+		acc := &accounts[i]
+		if _, excluded := excludedIDs[acc.ID]; excluded {
+			continue
+		}
+		// Scheduler snapshots can be temporarily stale; re-check schedulability here to
+		// avoid selecting accounts that were recently rate-limited/overloaded.
+		if !s.isAccountSchedulableForSelection(acc) {
+			continue
+		}
+		// require_privacy_set: skip accounts without privacy configured and flag them.
+		if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
+			_ = s.accountRepo.SetError(ctx, acc.ID,
+				fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
+			continue
+		}
+		// Filter: native-platform accounts pass directly; antigravity accounts
+		// require mixed scheduling enabled.
+		if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
+			continue
+		}
+		if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
+			continue
+		}
+		if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) {
+			continue
+		}
+		if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
+			continue
+		}
+		if !s.isAccountSchedulableForQuota(acc) {
+			continue
+		}
+		if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
+			continue
+		}
+		if !s.isAccountSchedulableForRPM(ctx, acc, false) {
+			continue
+		}
+		if selected == nil {
+			selected = acc
+		} else if isBetterAccountCandidate(acc, selected, preferOAuth, PlatformGemini) {
+			selected = acc
+		}
+	}
+
+	if selected == nil {
+		stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, nativePlatform, accounts, excludedIDs, true)
+		if requestedModel != "" {
+			return nil, fmt.Errorf("%w supporting model: %s (%s)", ErrNoAvailableAccounts, requestedModel, summarizeSelectionFailureStats(stats))
+		}
+		return nil, ErrNoAvailableAccounts
+	}
+
+	// 4. Establish the sticky-session binding.
+	if sessionHash != "" && s.cache != nil {
+		if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
+			logger.LegacyPrintf("service.gateway", "set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
+		}
+	}
+
+	return selected, nil
+}
diff --git a/backend/internal/service/gateway_account_selection_helpers.go b/backend/internal/service/gateway_account_selection_helpers.go
new file mode 100644
index 00000000000..e614f7fe445
--- /dev/null
+++ b/backend/internal/service/gateway_account_selection_helpers.go
@@ -0,0 +1,745 @@
+package service
+
+import (
+	"context"
+	"fmt"
+	"log/slog"
+	"sort"
+
+	"github.com/Wei-Shaw/sub2api/internal/config"
+	"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+)
+
+// ─────────────────────────────────────────────────────────────────────────────
+// Account-scheduling sub-methods: the internals of SelectAccountWithLoadAwareness
+// organized into layers.
+// ─────────────────────────────────────────────────────────────────────────────
+
+// schedulingContext bundles the shared state needed for account scheduling.
+// It is initialized once in prepareSchedulingContext so the sub-methods do not
+// have to thread a long parameter list between each other.
+type schedulingContext struct {
+	cfg             config.GatewaySchedulingConfig // scheduling config (timeouts, queue limits, …)
+	group           *Group                         // the user's group
+	groupID         *int64                         // group-ID pointer (may be replaced by the Claude Code restriction)
+	stickyAccountID int64                          // sticky-bound account ID (0 = no binding)
+	stickySource    string                         // sticky origin: "prefetch" | "cache" | ""
+	requestedModel  string                         // requested model name
+	sessionHash     string                         // session hash (used for sticky binding)
+	excludedIDs     map[int64]struct{}             // accounts already excluded (accumulates after failover)
+	metadataUserID  string                         // client metadata user ID
+	sub2apiUserID   int64                          // system user ID
+}
+
+// accountPool caches the schedulable account list plus fast-lookup structures.
+// Built in loadSchedulableAccountPool and consumed by the selection layers.
+type
accountPool struct {
+	accounts    []Account          // full list of schedulable accounts
+	byID        map[int64]*Account // fast lookup by account ID
+	useMixed    bool               // whether mixed (cross-platform) scheduling is active
+	platform    string             // target platform (anthropic/openai/gemini)
+	preferOAuth bool               // whether OAuth accounts are preferred
+	routingIDs  []int64            // account IDs pinned by model routing
+	isExcluded  func(int64) bool   // exclusion predicate closure
+}
+
+// prepareSchedulingContext initializes the scheduling context: it loads the
+// scheduling config, applies the Claude Code client restriction, pre-checks
+// channel-pricing restrictions, and resolves the sticky-session binding.
+func (s *GatewayService) prepareSchedulingContext(
+	ctx context.Context,
+	groupID *int64,
+	sessionHash string,
+	requestedModel string,
+	excludedIDs map[int64]struct{},
+	metadataUserID string,
+	sub2apiUserID int64,
+) (context.Context, *schedulingContext, error) {
+	excludedIDsList := make([]int64, 0, len(excludedIDs))
+	for id := range excludedIDs {
+		excludedIDsList = append(excludedIDsList, id)
+	}
+	slog.Debug("account_scheduling_starting",
+		"group_id", derefGroupID(groupID),
+		"model", requestedModel,
+		"session", shortSessionHash(sessionHash),
+		"excluded_ids", excludedIDsList)
+
+	cfg := s.schedulingConfig()
+
+	// Claude Code client restriction check (may replace groupID with a fallback group).
+	group, groupID, err := s.checkClaudeCodeRestriction(ctx, groupID)
+	if err != nil {
+		return ctx, nil, err
+	}
+	ctx = s.withGroupContext(ctx, group)
+
+	// Channel-pricing restriction pre-check (must use the resolved group).
+	if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
+		slog.Warn("channel pricing restriction blocked request",
+			"group_id", derefGroupID(groupID),
+			"model", requestedModel)
+		return ctx, nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
+	}
+
+	// Resolve the sticky binding (context prefetch first, then the Redis cache).
+	var stickyAccountID int64
+	var stickySource string
+	if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 {
+		stickyAccountID = prefetch
+		stickySource = "prefetch"
+	} else if sessionHash != "" && s.cache != nil {
+		if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil {
+			stickyAccountID = accountID
+			stickySource = "cache"
+		}
+	}
+
+	slog.Info("sticky.scheduler_entry",
+		"group_id", derefGroupID(groupID),
+		"session_hash", shortSessionHash(sessionHash),
+		"sticky_account_id", stickyAccountID,
+		"sticky_source", stickySource,
+		"model", requestedModel,
+		"load_batch", cfg.LoadBatchEnabled,
+		"has_concurrency_svc", s.concurrencyService != nil,
+		"excluded_count", len(excludedIDs),
+	)
+
+	if s.debugModelRoutingEnabled() && requestedModel != "" {
+		groupPlatform := ""
+		if group != nil {
+			groupPlatform = group.Platform
+		}
+		logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] select entry: group_id=%v group_platform=%s model=%s session=%s sticky_account=%d load_batch=%v concurrency=%v",
+			derefGroupID(groupID), groupPlatform, requestedModel, shortSessionHash(sessionHash), stickyAccountID, cfg.LoadBatchEnabled, s.concurrencyService != nil)
+	}
+
+	sc := &schedulingContext{
+		cfg:             cfg,
+		group:           group,
+		groupID:         groupID,
+		stickyAccountID: stickyAccountID,
+		stickySource:    stickySource,
+		requestedModel:  requestedModel,
+		sessionHash:     sessionHash,
+		excludedIDs:     excludedIDs,
+		metadataUserID:  metadataUserID,
+		sub2apiUserID:   sub2apiUserID,
+	}
+	return ctx, sc, nil
+}
+
+// selectWithoutLoadAwareness is the fast, non-load-aware selection path, used
+// when concurrencyService is nil or LoadBatchEnabled is false.
+func (s *GatewayService) selectWithoutLoadAwareness(ctx context.Context, sc *schedulingContext) (*AccountSelectionResult, error) {
+	// Work on a local copy so failover exclusions do not leak back to the caller.
+	localExcluded := make(map[int64]struct{})
+	for k, v := range sc.excludedIDs {
+		localExcluded[k] = v
+	}
+
+	for {
+		account, err := s.SelectAccountForModelWithExclusions(ctx, sc.groupID, sc.sessionHash, sc.requestedModel, localExcluded)
+		if err != nil {
+			return nil, err
+		}
+
+		result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
+		if err == nil && result.Acquired {
+			if !s.checkAndRegisterSession(ctx, account, sc.sessionHash) {
+				result.ReleaseFunc()
+				localExcluded[account.ID] = struct{}{}
+				continue
+			}
+			return
s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
+		}
+
+		if !s.checkAndRegisterSession(ctx, account, sc.sessionHash) {
+			localExcluded[account.ID] = struct{}{}
+			continue
+		}
+
+		// A sticky account gets the (more generous) sticky wait timeout first.
+		if sc.stickyAccountID > 0 && sc.stickyAccountID == account.ID && s.concurrencyService != nil {
+			waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
+			if waitingCount < sc.cfg.StickySessionMaxWaiting {
+				return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
+					AccountID:      account.ID,
+					MaxConcurrency: account.Concurrency,
+					Timeout:        sc.cfg.StickySessionWaitTimeout,
+					MaxWaiting:     sc.cfg.StickySessionMaxWaiting,
+				})
+			}
+		}
+		return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
+			AccountID:      account.ID,
+			MaxConcurrency: account.Concurrency,
+			Timeout:        sc.cfg.FallbackWaitTimeout,
+			MaxWaiting:     sc.cfg.FallbackMaxWaiting,
+		})
+	}
+}
+
+// loadSchedulableAccountPool loads the schedulable account list and builds the
+// lookup structures (by-ID map, exclusion predicate, model-routing IDs).
+func (s *GatewayService) loadSchedulableAccountPool(ctx context.Context, sc *schedulingContext) (context.Context, *accountPool, error) {
+	platform, hasForcePlatform, err := s.resolvePlatform(ctx, sc.groupID, sc.group)
+	if err != nil {
+		return ctx, nil, err
+	}
+	preferOAuth := platform == PlatformGemini
+
+	if s.debugModelRoutingEnabled() && platform == PlatformAnthropic && sc.requestedModel != "" {
+		logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s",
+			derefGroupID(sc.groupID), sc.requestedModel, shortSessionHash(sc.sessionHash), platform)
+	}
+
+	accounts, useMixed, err := s.listSchedulableAccounts(ctx, sc.groupID, platform, hasForcePlatform)
+	if err != nil {
+		return ctx, nil, err
+	}
+	if len(accounts) == 0 {
+		return ctx, nil, ErrNoAvailableAccounts
+	}
+
+	// Prefetch window-cost and RPM data (batched Redis queries, avoids N+1).
+	ctx = s.withWindowCostPrefetch(ctx, accounts)
+	ctx = s.withRPMPrefetch(ctx, accounts)
+
+	byID := make(map[int64]*Account, len(accounts))
+	for i := range accounts {
+		byID[accounts[i].ID] = &accounts[i]
+	}
+
+	isExcluded := func(accountID int64) bool {
+		if sc.excludedIDs == nil {
+			return false
+		}
+		_, excluded := sc.excludedIDs[accountID]
+		return excluded
+	}
+
+	// Fetch the model-routing configuration (only effective on the anthropic platform).
+	var routingIDs []int64
+	if sc.group != nil && sc.requestedModel != "" && sc.group.Platform == PlatformAnthropic {
+		routingIDs = sc.group.GetRoutingAccountIDs(sc.requestedModel)
+		if s.debugModelRoutingEnabled() {
+			logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] context group routing: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v session=%s sticky_account=%d",
+				sc.group.ID, sc.requestedModel, sc.group.ModelRoutingEnabled, len(sc.group.ModelRouting), routingIDs, shortSessionHash(sc.sessionHash), sc.stickyAccountID)
+			if len(routingIDs) == 0 && sc.group.ModelRoutingEnabled && len(sc.group.ModelRouting) > 0 {
+				// Routing is enabled but nothing matched: log a (capped) sample of
+				// the configured patterns for diagnosis.
+				keys := make([]string, 0, len(sc.group.ModelRouting))
+				for k := range sc.group.ModelRouting {
+					keys = append(keys, k)
+				}
+				sort.Strings(keys)
+				const maxKeys = 20
+				if len(keys) > maxKeys {
+					keys = keys[:maxKeys]
+				}
+				logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] context group routing miss: group_id=%d model=%s patterns(sample)=%v", sc.group.ID, sc.requestedModel, keys)
+			}
+		}
+	}
+
+	pool := &accountPool{
+		accounts:    accounts,
+		byID:        byID,
+		useMixed:    useMixed,
+		platform:    platform,
+		preferOAuth: preferOAuth,
+		routingIDs:  routingIDs,
+		isExcluded:  isExcluded,
+	}
+	return ctx, pool, nil
+}
+
+// tryModelRoutingSelection is the model-routing selection layer (Layer 1).
+// When the group has model-routing rules configured, accounts are picked from
+// the pinned list first. A nil result means this layer did not hit and the
+// caller should continue to the next layer.
+func (s *GatewayService) tryModelRoutingSelection(ctx context.Context, sc *schedulingContext, pool *accountPool) (*AccountSelectionResult, error) {
+	if len(pool.routingIDs) == 0 || s.concurrencyService == nil {
+		return nil, nil
+	}
+
+	// Filter the routing list down to schedulable accounts (per-reason counters
+	// are kept for debug diagnostics).
+	var routingCandidates []*Account
+	var filteredExcluded, filteredMissing, filteredUnsched,
filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost int
+	var modelScopeSkippedIDs []int64
+
+	for _, routingID := range pool.routingIDs {
+		if pool.isExcluded(routingID) {
+			filteredExcluded++
+			continue
+		}
+		account, ok := pool.byID[routingID]
+		if !ok {
+			filteredMissing++
+			continue
+		}
+		if !s.isAccountSchedulableForSelection(account) {
+			filteredUnsched++
+			continue
+		}
+		if !s.isAccountAllowedForPlatform(account, pool.platform, pool.useMixed) {
+			filteredPlatform++
+			continue
+		}
+		if sc.requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, account, sc.requestedModel) {
+			filteredModelMapping++
+			continue
+		}
+		if !s.isAccountSchedulableForModelSelection(ctx, account, sc.requestedModel) {
+			filteredModelScope++
+			modelScopeSkippedIDs = append(modelScopeSkippedIDs, account.ID)
+			continue
+		}
+		if !s.isAccountSchedulableForQuota(account) {
+			continue
+		}
+		if !s.isAccountSchedulableForWindowCost(ctx, account, false) {
+			filteredWindowCost++
+			continue
+		}
+		if !s.isAccountSchedulableForRPM(ctx, account, false) {
+			continue
+		}
+		routingCandidates = append(routingCandidates, account)
+	}
+
+	if s.debugModelRoutingEnabled() {
+		logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d window_cost=%d)",
+			derefGroupID(sc.groupID), sc.requestedModel, len(pool.routingIDs), len(routingCandidates),
+			filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost)
+		if len(modelScopeSkippedIDs) > 0 {
+			logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] model_rate_limited accounts skipped: group_id=%v model=%s account_ids=%v",
+				derefGroupID(sc.groupID), sc.requestedModel, modelScopeSkippedIDs)
+		}
+	}
+
+	if len(routingCandidates) == 0 {
+		logger.LegacyPrintf("service.gateway", "[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", sc.requestedModel)
+		return nil, nil
+	}
+
+	// Layer 1.5: check the sticky session within the routing scope.
+	if result, err := s.tryStickyWithinRouting(ctx, sc, pool, routingCandidates); result != nil || err != nil {
+		return result, err
+	}
+
+	// Fetch load information in one batch.
+	routingLoads := make([]AccountWithConcurrency, 0, len(routingCandidates))
+	for _, acc := range routingCandidates {
+		routingLoads = append(routingLoads, AccountWithConcurrency{
+			ID:             acc.ID,
+			MaxConcurrency: acc.EffectiveLoadFactor(),
+		})
+	}
+	routingLoadMap, _ := s.concurrencyService.GetAccountsLoadBatch(ctx, routingLoads)
+
+	// Keep only accounts whose load rate is below 100%.
+	var routingAvailable []accountWithLoad
+	for _, acc := range routingCandidates {
+		loadInfo := routingLoadMap[acc.ID]
+		if loadInfo == nil {
+			loadInfo = &AccountLoadInfo{AccountID: acc.ID}
+		}
+		if loadInfo.LoadRate < 100 {
+			routingAvailable = append(routingAvailable, accountWithLoad{account: acc, loadInfo: loadInfo})
+		}
+	}
+
+	if len(routingAvailable) == 0 {
+		logger.LegacyPrintf("service.gateway", "[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", sc.requestedModel)
+		return nil, nil
+	}
+
+	sortAccountsWithLoadByPriority(routingAvailable)
+
+	// Try to acquire a concurrency slot.
+	for _, item := range routingAvailable {
+		result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
+		if err == nil && result.Acquired {
+			if !s.checkAndRegisterSession(ctx, item.account, sc.sessionHash) {
+				result.ReleaseFunc()
+				continue
+			}
+			if sc.sessionHash != "" && s.cache != nil {
+				_ = s.cache.SetSessionAccountID(ctx, derefGroupID(sc.groupID), sc.sessionHash, item.account.ID, stickySessionTTL)
+			}
+			if s.debugModelRoutingEnabled() {
+				logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(sc.groupID), sc.requestedModel, shortSessionHash(sc.sessionHash), item.account.ID)
+			}
+			return s.newSelectionResult(ctx, item.account, true, result.ReleaseFunc, nil)
+		}
+	}
+
+	// All routed accounts have full slots: return a wait plan (pick the first
+	// one that passes the session limit).
+	for _, item := range routingAvailable {
+		if !s.checkAndRegisterSession(ctx, item.account, sc.sessionHash) {
+			continue
+		}
+		if s.debugModelRoutingEnabled() {
+			logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(sc.groupID), sc.requestedModel, shortSessionHash(sc.sessionHash), item.account.ID)
+		}
+		return s.newSelectionResult(ctx, item.account, false, nil, &AccountWaitPlan{
+			AccountID:      item.account.ID,
+			MaxConcurrency: item.account.Concurrency,
+			Timeout:        sc.cfg.StickySessionWaitTimeout,
+			MaxWaiting:     sc.cfg.StickySessionMaxWaiting,
+		})
+	}
+	// Every routed account's session limit is saturated: continue to Layer 2.
+	return nil, nil
+}
+
+// tryStickyWithinRouting attempts sticky-session reuse inside the model-routing
+// scope. When the sticky account happens to be in the routing list, it is
+// reused preferentially to keep the prompt-cache hit rate up.
+func (s *GatewayService) tryStickyWithinRouting(ctx context.Context, sc *schedulingContext, pool *accountPool, routingCandidates []*Account) (*AccountSelectionResult, error) {
+	if sc.sessionHash == "" || sc.stickyAccountID <= 0 {
+		return nil, nil
+	}
+
+	slog.Debug("sticky.layer1_5_checking",
+		"sticky_account_id", sc.stickyAccountID,
+		"in_routing_list", containsInt64(pool.routingIDs, sc.stickyAccountID),
+		"is_excluded", pool.isExcluded(sc.stickyAccountID),
+		"in_account_map", func() bool { _, ok := pool.byID[sc.stickyAccountID]; return ok }(),
+		"session", shortSessionHash(sc.sessionHash),
+	)
+
+	if !containsInt64(pool.routingIDs, sc.stickyAccountID) || pool.isExcluded(sc.stickyAccountID) {
+		return nil, nil
+	}
+
+	stickyAccount, ok := pool.byID[sc.stickyAccountID]
+	if !ok {
+		// Bound account vanished from the pool: drop the stale binding.
+		_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(sc.groupID), sc.sessionHash)
+		logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=account_cleared account_id=%d session=%s current_rpm=0 base_rpm=0",
+			sc.stickyAccountID, shortSessionHash(sc.sessionHash))
+		return nil, nil
+	}
+
+	var
missReason string
+
+	// The sticky path uses the full eligibility check (isSticky=true relaxes the
+	// WindowCost/RPM thresholds).
+	gatePass := s.isAccountEligibleForScheduling(ctx, stickyAccount, eligibilityOpts{
+		platform:       pool.platform,
+		useMixed:       pool.useMixed,
+		requestedModel: sc.requestedModel,
+		isSticky:       true,
+	})
+	rpmPass := gatePass // RPM is already included in isAccountEligibleForScheduling
+
+	if rpmPass {
+		result, err := s.tryAcquireAccountSlot(ctx, sc.stickyAccountID, stickyAccount.Concurrency)
+		if err == nil && result.Acquired {
+			if !s.checkAndRegisterSession(ctx, stickyAccount, sc.sessionHash) {
+				result.ReleaseFunc()
+				missReason = "session_limit"
+			} else {
+				slog.Debug("sticky.layer1_5_hit",
+					"account_id", sc.stickyAccountID,
+					"session", shortSessionHash(sc.sessionHash),
+					"result", "slot_acquired",
+				)
+				if s.debugModelRoutingEnabled() {
+					logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(sc.groupID), sc.requestedModel, shortSessionHash(sc.sessionHash), sc.stickyAccountID)
+				}
+				return s.newSelectionResult(ctx, stickyAccount, true, result.ReleaseFunc, nil)
+			}
+		}
+
+		if missReason == "" {
+			waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, sc.stickyAccountID)
+			if waitingCount < sc.cfg.StickySessionMaxWaiting {
+				if !s.checkAndRegisterSession(ctx, stickyAccount, sc.sessionHash) {
+					missReason = "session_limit"
+				} else {
+					return &AccountSelectionResult{
+						Account: stickyAccount,
+						WaitPlan: &AccountWaitPlan{
+							AccountID:      sc.stickyAccountID,
+							MaxConcurrency: stickyAccount.Concurrency,
+							Timeout:        sc.cfg.StickySessionWaitTimeout,
+							MaxWaiting:     sc.cfg.StickySessionMaxWaiting,
+						},
+					}, nil
+				}
+			} else {
+				missReason = "wait_queue_full"
+			}
+		}
+	} else if !gatePass {
+		missReason = "gate_check"
+	} else {
+		// NOTE(review): unreachable — rpmPass is assigned from gatePass above, so
+		// "rpm_red" can never be reported. Looks like a leftover from when RPM was
+		// checked separately; confirm intent.
+		missReason = "rpm_red"
+	}
+
+	if missReason != "" {
+		baseRPM := stickyAccount.GetBaseRPM()
+		var currentRPM int
+		if count, ok := rpmFromPrefetchContext(ctx, stickyAccount.ID); ok {
+			currentRPM = count
+		}
+		logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=%s account_id=%d session=%s current_rpm=%d base_rpm=%d",
+			missReason, sc.stickyAccountID, shortSessionHash(sc.sessionHash), currentRPM, baseRPM)
+	}
+
+	return nil, nil
+}
+
+// tryStickySessionSelection is the sticky-session selection layer (Layer 1.5).
+// It only runs when no model routing is configured, and tries to reuse the
+// previously bound account. A nil result means this layer did not hit.
+func (s *GatewayService) tryStickySessionSelection(ctx context.Context, sc *schedulingContext, pool *accountPool) (*AccountSelectionResult, error) {
+	// Only applies when no model routing is configured.
+	if len(pool.routingIDs) > 0 {
+		return nil, nil
+	}
+
+	if sc.sessionHash == "" || sc.stickyAccountID <= 0 || pool.isExcluded(sc.stickyAccountID) {
+		if sc.sessionHash != "" {
+			slog.Debug("sticky.layer1_5_no_routing_skip",
+				"sticky_account_id", sc.stickyAccountID,
+				"is_excluded", func() bool { return sc.stickyAccountID > 0 && pool.isExcluded(sc.stickyAccountID) }(),
+				"session", shortSessionHash(sc.sessionHash),
+				"reason", func() string {
+					if sc.stickyAccountID == 0 {
+						return "no_sticky_binding"
+					}
+					return "sticky_account_excluded"
+				}(),
+			)
+		}
+		return nil, nil
+	}
+
+	account, ok := pool.byID[sc.stickyAccountID]
+	if !ok {
+		slog.Debug("sticky.layer1_5_no_routing_miss",
+			"account_id", sc.stickyAccountID,
+			"reason", "account_not_in_map",
+			"session", shortSessionHash(sc.sessionHash),
+		)
+		return nil, nil
+	}
+
+	// Check whether the sticky binding must be cleared for this account.
+	clearSticky := shouldClearStickySession(account, sc.requestedModel)
+	if clearSticky {
+		slog.Debug("sticky.layer1_5_no_routing_clear",
+			"account_id", sc.stickyAccountID,
+			"reason", "should_clear_sticky_session",
+			"session", shortSessionHash(sc.sessionHash),
+		)
+		_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(sc.groupID), sc.sessionHash)
+	}
+
+	// Full gate checks.
+	platformOK := s.isAccountAllowedForPlatform(account, pool.platform, pool.useMixed)
+	modelSupported := sc.requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, sc.requestedModel)
+	modelSchedulable := s.isAccountSchedulableForModelSelection(ctx, account, sc.requestedModel)
+	quotaOK := s.isAccountSchedulableForQuota(account)
+	windowCostOK := s.isAccountSchedulableForWindowCost(ctx, account, true)
+	rpmOK := s.isAccountSchedulableForRPM(ctx, account, true)
+	schedulable := s.isAccountSchedulableForSelection(account)
+
+	slog.Debug("sticky.layer1_5_no_routing_checks",
+		"account_id", sc.stickyAccountID,
+		"session", shortSessionHash(sc.sessionHash),
+		"clear_sticky", clearSticky,
+		"schedulable", schedulable,
+		"platform_ok", platformOK,
+		"model_supported", modelSupported,
+		"model_schedulable", modelSchedulable,
+		"quota_ok", quotaOK,
+		"window_cost_ok", windowCostOK,
+		"rpm_ok", rpmOK,
+	)
+
+	if clearSticky || !platformOK || !modelSupported || !modelSchedulable || !quotaOK || !windowCostOK || !rpmOK || !schedulable {
+		if !clearSticky {
+			slog.Debug("sticky.layer1_5_no_routing_miss",
+				"account_id", sc.stickyAccountID,
+				"reason", "gate_check_failed",
+				"session", shortSessionHash(sc.sessionHash),
+			)
+		}
+		return nil, nil
+	}
+
+	// Try to acquire a concurrency slot.
+	result, err := s.tryAcquireAccountSlot(ctx, sc.stickyAccountID, account.Concurrency)
+	if err == nil && result.Acquired {
+		if !s.checkAndRegisterSession(ctx, account, sc.sessionHash) {
+			result.ReleaseFunc()
+			slog.Debug("sticky.layer1_5_no_routing_miss",
+				"account_id", sc.stickyAccountID,
+				"reason", "session_limit",
+				"session", shortSessionHash(sc.sessionHash),
+			)
+		} else {
+			slog.Debug("sticky.layer1_5_no_routing_hit",
+				"account_id", sc.stickyAccountID,
+				"session", shortSessionHash(sc.sessionHash),
+				"result", "slot_acquired",
+			)
+			if s.cache != nil {
+				_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(sc.groupID), sc.sessionHash, stickySessionTTL)
+			}
+			return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
+		}
+	} else {
+		slog.Debug("sticky.layer1_5_no_routing_slot_busy",
+			"account_id", sc.stickyAccountID,
+			"session", shortSessionHash(sc.sessionHash),
+		)
+	}
+
+	// No slot acquired: try a wait plan instead.
+	waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, sc.stickyAccountID)
+	if waitingCount < sc.cfg.StickySessionMaxWaiting {
+		if !s.checkAndRegisterSession(ctx, account, sc.sessionHash) {
+			return nil, nil
+		}
+		slog.Debug("sticky.layer1_5_no_routing_hit",
+			"account_id", sc.stickyAccountID,
+			"session", shortSessionHash(sc.sessionHash),
+			"result", "wait_plan",
+		)
+		return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
+			AccountID:      sc.stickyAccountID,
+			MaxConcurrency: account.Concurrency,
+			Timeout:        sc.cfg.StickySessionWaitTimeout,
+			MaxWaiting:     sc.cfg.StickySessionMaxWaiting,
+		})
+	}
+
+	return nil, nil
+}
+
+// selectByLoadBalance is the load-balancing selection layer (Layer 2).
+// It picks the best account from all schedulable accounts by load rate.
+// A nil result means every account is saturated; continue to Layer 3 queueing.
+func (s *GatewayService) selectByLoadBalance(ctx context.Context, sc *schedulingContext, pool *accountPool) (*AccountSelectionResult, error) {
+	slog.Debug("sticky.layer2_fallback",
+		"session", shortSessionHash(sc.sessionHash),
+		"sticky_account_id", sc.stickyAccountID,
+		"reason", "sticky_not_used_falling_back_to_load_balance",
+		"total_accounts", len(pool.accounts),
+	)
+
+	// Build the candidate list.
+	opts := eligibilityOpts{
+		platform:       pool.platform,
+		useMixed:       pool.useMixed,
+		requestedModel: sc.requestedModel,
+		isSticky:       false,
+	}
+	candidates := make([]*Account, 0, len(pool.accounts))
+	for i := range pool.accounts {
+		acc := &pool.accounts[i]
+		if pool.isExcluded(acc.ID) {
+			continue
+		}
+		if !s.isAccountEligibleForScheduling(ctx, acc, opts) {
+			continue
+		}
+		candidates = append(candidates, acc)
+	}
+
+	if len(candidates) == 0 {
+		return nil, nil
+	}
+
+	// Fetch load information in one batch.
+	accountLoads := make([]AccountWithConcurrency, 0, len(candidates))
+	for _, acc := range candidates {
+		accountLoads = append(accountLoads, AccountWithConcurrency{
+			ID:             acc.ID,
+			MaxConcurrency: acc.EffectiveLoadFactor(),
+		})
+	}
+
+	loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
+	if err != nil {
+		// Load lookup failed: degrade to the legacy ordering.
+		if
result, ok, legacyErr := s.tryAcquireByLegacyOrder(ctx, candidates, sc.groupID, sc.sessionHash, pool.preferOAuth); legacyErr != nil {
+			return nil, legacyErr
+		} else if ok {
+			return result, nil
+		}
+		return nil, nil
+	}
+
+	var available []accountWithLoad
+	for _, acc := range candidates {
+		loadInfo := loadMap[acc.ID]
+		if loadInfo == nil {
+			loadInfo = &AccountLoadInfo{AccountID: acc.ID}
+		}
+		if loadInfo.LoadRate < 100 {
+			available = append(available, accountWithLoad{account: acc, loadInfo: loadInfo})
+		}
+	}
+
+	// Layered filtering: priority → load rate → LRU.
+	for len(available) > 0 {
+		filtered := filterByMinPriority(available)
+		filtered = filterByMinLoadRate(filtered)
+		selected := selectByLRU(filtered, pool.preferOAuth)
+		if selected == nil {
+			break
+		}
+
+		result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency)
+		if err == nil && result.Acquired {
+			if !s.checkAndRegisterSession(ctx, selected.account, sc.sessionHash) {
+				result.ReleaseFunc()
+			} else {
+				if sc.sessionHash != "" && s.cache != nil {
+					_ = s.cache.SetSessionAccountID(ctx, derefGroupID(sc.groupID), sc.sessionHash, selected.account.ID, stickySessionTTL)
+				}
+				return s.newSelectionResult(ctx, selected.account, true, result.ReleaseFunc, nil)
+			}
+		}
+
+		// Remove the account we just tried and loop again.
+		selectedID := selected.account.ID
+		newAvailable := make([]accountWithLoad, 0, len(available)-1)
+		for _, acc := range available {
+			if acc.account.ID != selectedID {
+				newAvailable = append(newAvailable, acc)
+			}
+		}
+		available = newAvailable
+	}
+
+	return nil, nil
+}
+
+// selectWithQueueFallback is the last-resort queueing layer (Layer 3).
+// When every account's load rate is >= 100%, pick one account and join its
+// wait queue.
+func (s *GatewayService) selectWithQueueFallback(ctx context.Context, sc *schedulingContext, pool *accountPool) (*AccountSelectionResult, error) {
+	candidates := make([]*Account, 0, len(pool.accounts))
+	for i := range pool.accounts {
+		acc := &pool.accounts[i]
+		if pool.isExcluded(acc.ID) {
+			continue
+		}
+		if !s.isAccountSchedulableForSelection(acc) {
+			continue
+		}
+		candidates = append(candidates, acc)
+	}
+
+	s.sortCandidatesForFallback(candidates, pool.preferOAuth, sc.cfg.FallbackSelectionMode)
+	for _, acc := range candidates {
+		if !s.checkAndRegisterSession(ctx, acc, sc.sessionHash) {
+			continue
+		}
+		return s.newSelectionResult(ctx, acc, false, nil, &AccountWaitPlan{
+			AccountID:      acc.ID,
+			MaxConcurrency: acc.Concurrency,
+			Timeout:        sc.cfg.FallbackWaitTimeout,
+			MaxWaiting:     sc.cfg.FallbackMaxWaiting,
+		})
+	}
+	return nil, ErrNoAvailableAccounts
+}
diff --git a/backend/internal/service/gateway_beta_policy.go b/backend/internal/service/gateway_beta_policy.go
new file mode 100644
index 00000000000..85f8e7b070d
--- /dev/null
+++ b/backend/internal/service/gateway_beta_policy.go
@@ -0,0 +1,388 @@
+package service
+
+import (
+	"context"
+	"net/http"
+	"strings"
+
+	"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
+	"github.com/tidwall/gjson"
+
+	"github.com/gin-gonic/gin"
+)
+
+// getBetaHeader normalizes the anthropic-beta header.
+// For OAuth accounts the oauth-2025-04-20 beta (claude.BetaOAuth) must be present.
+func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) string {
+	// Client supplied an anthropic-beta header.
+	if clientBetaHeader != "" {
+		// Already contains the oauth beta: pass through unchanged.
+		if strings.Contains(clientBetaHeader, claude.BetaOAuth) {
+			return clientBetaHeader
+		}
+
+		// The oauth beta needs to be added.
+		parts := strings.Split(clientBetaHeader, ",")
+		for i, p := range parts {
+			parts[i] = strings.TrimSpace(p)
+		}
+
+		// Insert the oauth beta right after claude-code-20250219 if present.
+		claudeCodeIdx := -1
+		for i, p := range parts {
+			if p == claude.BetaClaudeCode {
+				claudeCodeIdx = i
+				break
+			}
+		}
+
+		if claudeCodeIdx >= 0 {
+			// Insert after the claude-code entry.
+			newParts := make([]string, 0, len(parts)+1)
+			newParts = append(newParts, parts[:claudeCodeIdx+1]...)
+			newParts = append(newParts, claude.BetaOAuth)
+			newParts = append(newParts, parts[claudeCodeIdx+1:]...)
		return strings.Join(newParts, ",")
	}

	// No claude-code entry found: put the oauth beta first.
	return claude.BetaOAuth + "," + clientBetaHeader
}

// Client did not send anthropic-beta: generate a default based on the model.
// haiku models do not need the claude-code beta.
if strings.Contains(strings.ToLower(modelID), "haiku") {
	return claude.HaikuBetaHeader
}

	return claude.DefaultBetaHeader
}

// requestNeedsBetaFeatures reports whether the request body uses features that
// require beta headers upstream: a non-empty "tools" array, or extended
// thinking ("thinking.type" of "enabled" or "adaptive", case-insensitive).
func requestNeedsBetaFeatures(body []byte) bool {
	tools := gjson.GetBytes(body, "tools")
	if tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 {
		return true
	}
	thinkingType := gjson.GetBytes(body, "thinking.type").String()
	if strings.EqualFold(thinkingType, "enabled") || strings.EqualFold(thinkingType, "adaptive") {
		return true
	}
	return false
}

// defaultAPIKeyBetaHeader picks the default beta header for API-key accounts
// based on the request's model (haiku models get the haiku-specific header).
func defaultAPIKeyBetaHeader(body []byte) string {
	modelID := gjson.GetBytes(body, "model").String()
	if strings.Contains(strings.ToLower(modelID), "haiku") {
		return claude.APIKeyHaikuBetaHeader
	}
	return claude.APIKeyBetaHeader
}

// applyClaudeOAuthHeaderDefaults fills in missing default headers for Claude
// OAuth upstream requests. Existing header values are never overwritten; only
// absent ones are set (Accept first, then each entry of claude.DefaultHeaders).
func applyClaudeOAuthHeaderDefaults(req *http.Request) {
	if req == nil {
		return
	}
	if getHeaderRaw(req.Header, "Accept") == "" {
		setHeaderRaw(req.Header, "Accept", "application/json")
	}
	for key, value := range claude.DefaultHeaders {
		if value == "" {
			// Empty defaults carry no value to inject.
			continue
		}
		if getHeaderRaw(req.Header, key) == "" {
			setHeaderRaw(req.Header, resolveWireCasing(key), value)
		}
	}
}

// mergeAnthropicBeta combines the required beta tokens with the incoming
// comma-separated header value, trimming whitespace and de-duplicating while
// preserving first-seen order (required tokens come first).
func mergeAnthropicBeta(required []string, incoming string) string {
	seen := make(map[string]struct{}, len(required)+8)
	out := make([]string, 0, len(required)+8)

	// add appends v once, ignoring blanks and duplicates.
	add := func(v string) {
		v = strings.TrimSpace(v)
		if v == "" {
			return
		}
		if _, ok := seen[v]; ok {
			return
		}
		seen[v] = struct{}{}
		out = append(out, v)
	}

	for _, r := range required {
		add(r)
	}
	for _, p := range strings.Split(incoming, ",") {
		add(p)
	}
	return strings.Join(out, ",")
}

// mergeAnthropicBetaDropping merges like mergeAnthropicBeta, then removes any
// token present in drop. Tokens are compared after whitespace trimming.
func mergeAnthropicBetaDropping(required []string, incoming string, drop map[string]struct{}) string {
	merged := mergeAnthropicBeta(required, incoming)
	if merged == "" || len(drop) == 0 {
		return merged
	}
	out := make([]string, 0, 8)
	for _, p := range strings.Split(merged, ",") {
		p = strings.TrimSpace(p)
		if p == "" {
			continue
		}
		if _, ok := drop[p]; ok {
			continue
		}
		out = append(out, p)
	}
	return strings.Join(out, ",")
}

// stripBetaTokens removes the given beta tokens from a comma-separated header value.
func stripBetaTokens(header string, tokens []string) string {
	if header == "" || len(tokens) == 0 {
		return header
	}
	return stripBetaTokensWithSet(header, buildBetaTokenSet(tokens))
}

// stripBetaTokensWithSet removes every token present in drop from the
// comma-separated header. If nothing was removed the original header string is
// returned unchanged (avoiding a re-join allocation); otherwise the surviving
// tokens are re-joined with "," (whitespace around tokens is normalized away).
func stripBetaTokensWithSet(header string, drop map[string]struct{}) string {
	if header == "" || len(drop) == 0 {
		return header
	}
	parts := strings.Split(header, ",")
	out := make([]string, 0, len(parts))
	for _, p := range parts {
		p = strings.TrimSpace(p)
		if p == "" {
			continue
		}
		if _, ok := drop[p]; ok {
			continue
		}
		out = append(out, p)
	}
	if len(out) == len(parts) {
		return header // no change, avoid allocation
	}
	return strings.Join(out, ",")
}

// BetaBlockedError indicates a request was blocked by a beta policy rule.
// NOTE(review): presumably satisfies the error interface via an Error() method
// defined elsewhere — callers return it as error; confirm.
type BetaBlockedError struct {
	Message string
}

// betaPolicyResult holds the evaluated result of beta policy rules for a single request.
type betaPolicyResult struct {
	blockErr  *BetaBlockedError   // non-nil if a block rule matched
	filterSet map[string]struct{} // tokens to filter (may be nil)
}

// evaluateBetaPolicy loads settings once and evaluates all rules against the given request.
+func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account, model string) betaPolicyResult { + if s.settingService == nil { + return betaPolicyResult{} + } + settings, err := s.settingService.GetBetaPolicySettings(ctx) + if err != nil || settings == nil { + return betaPolicyResult{} + } + isOAuth := account.IsOAuth() + isBedrock := account.IsBedrock() + var result betaPolicyResult + for _, rule := range settings.Rules { + if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) { + continue + } + effectiveAction, effectiveErrMsg := resolveRuleAction(rule, model) + switch effectiveAction { + case BetaPolicyActionBlock: + if result.blockErr == nil && betaHeader != "" && containsBetaToken(betaHeader, rule.BetaToken) { + msg := effectiveErrMsg + if msg == "" { + msg = "beta feature " + rule.BetaToken + " is not allowed" + } + result.blockErr = &BetaBlockedError{Message: msg} + } + case BetaPolicyActionFilter: + if result.filterSet == nil { + result.filterSet = make(map[string]struct{}) + } + result.filterSet[rule.BetaToken] = struct{}{} + } + } + return result +} + +// betaPolicyFilterSetKey is the gin.Context key for caching the policy filter set within a request. +const betaPolicyFilterSetKey = "betaPolicyFilterSet" + +// getBetaPolicyFilterSet returns the beta policy filter set, using the gin context cache if available. +// In the /v1/messages path, Forward() evaluates the policy first and caches the result; +// buildUpstreamRequest reuses it (zero extra DB calls). In the count_tokens path, this +// evaluates on demand (one DB call). 
+func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account, model string) map[string]struct{} { + if c != nil { + if v, ok := c.Get(betaPolicyFilterSetKey); ok { + if fs, ok := v.(map[string]struct{}); ok { + return fs + } + } + } + return s.evaluateBetaPolicy(ctx, "", account, model).filterSet +} + +// betaPolicyScopeMatches checks whether a rule's scope matches the current account type. +func betaPolicyScopeMatches(scope string, isOAuth bool, isBedrock bool) bool { + switch scope { + case BetaPolicyScopeAll: + return true + case BetaPolicyScopeOAuth: + return isOAuth + case BetaPolicyScopeAPIKey: + return !isOAuth && !isBedrock + case BetaPolicyScopeBedrock: + return isBedrock + default: + return true // unknown scope → match all (fail-open) + } +} + +// matchModelWhitelist checks if a model matches any pattern in the whitelist. +// Reuses matchModelPattern from group.go which supports exact and wildcard prefix matching. +func matchModelWhitelist(model string, whitelist []string) bool { + for _, pattern := range whitelist { + if matchModelPattern(pattern, model) { + return true + } + } + return false +} + +// resolveRuleAction determines the effective action and error message for a rule given the request model. +// When ModelWhitelist is empty, the rule's primary Action/ErrorMessage applies unconditionally. +// When non-empty, Action applies to matching models; FallbackAction/FallbackErrorMessage applies to others. 
+func resolveRuleAction(rule BetaPolicyRule, model string) (action, errorMessage string) { + if len(rule.ModelWhitelist) == 0 { + return rule.Action, rule.ErrorMessage + } + if matchModelWhitelist(model, rule.ModelWhitelist) { + return rule.Action, rule.ErrorMessage + } + if rule.FallbackAction != "" { + return rule.FallbackAction, rule.FallbackErrorMessage + } + return BetaPolicyActionPass, "" // default fallback: pass (fail-open) +} + +// droppedBetaSet returns claude.DroppedBetas as a set, with optional extra tokens. +func droppedBetaSet(extra ...string) map[string]struct{} { + m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(extra)) + for t := range defaultDroppedBetasSet { + m[t] = struct{}{} + } + for _, t := range extra { + m[t] = struct{}{} + } + return m +} + +// containsBetaToken checks if a comma-separated header value contains the given token. +func containsBetaToken(header, token string) bool { + if header == "" || token == "" { + return false + } + for _, p := range strings.Split(header, ",") { + if strings.TrimSpace(p) == token { + return true + } + } + return false +} + +func filterBetaTokens(tokens []string, filterSet map[string]struct{}) []string { + if len(tokens) == 0 || len(filterSet) == 0 { + return tokens + } + kept := make([]string, 0, len(tokens)) + for _, token := range tokens { + if _, filtered := filterSet[token]; !filtered { + kept = append(kept, token) + } + } + return kept +} + +func (s *GatewayService) resolveBedrockBetaTokensForRequest( + ctx context.Context, + account *Account, + betaHeader string, + body []byte, + modelID string, +) ([]string, error) { + // 1. 对原始 header 中的 beta token 做 block 检查(快速失败) + policy := s.evaluateBetaPolicy(ctx, betaHeader, account, modelID) + if policy.blockErr != nil { + return nil, policy.blockErr + } + + // 2. 解析 header + body 自动注入 + Bedrock 转换/过滤 + betaTokens := ResolveBedrockBetaTokens(betaHeader, body, modelID) + + // 3. 
对最终 token 列表再做 block 检查,捕获通过 body 自动注入绕过 header block 的情况。 + // 例如:管理员 block 了 interleaved-thinking,客户端不在 header 中带该 token, + // 但请求体中包含 thinking 字段 → autoInjectBedrockBetaTokens 会自动补齐 → + // 如果不做此检查,block 规则会被绕过。 + if blockErr := s.checkBetaPolicyBlockForTokens(ctx, betaTokens, account, modelID); blockErr != nil { + return nil, blockErr + } + + return filterBetaTokens(betaTokens, policy.filterSet), nil +} + +// checkBetaPolicyBlockForTokens 检查 token 列表中是否有被管理员 block 规则命中的 token。 +// 用于补充 evaluateBetaPolicy 对 header 的检查,覆盖 body 自动注入的 token。 +func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, tokens []string, account *Account, model string) *BetaBlockedError { + if s.settingService == nil || len(tokens) == 0 { + return nil + } + settings, err := s.settingService.GetBetaPolicySettings(ctx) + if err != nil || settings == nil { + return nil + } + isOAuth := account.IsOAuth() + isBedrock := account.IsBedrock() + tokenSet := buildBetaTokenSet(tokens) + for _, rule := range settings.Rules { + effectiveAction, effectiveErrMsg := resolveRuleAction(rule, model) + if effectiveAction != BetaPolicyActionBlock { + continue + } + if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) { + continue + } + if _, present := tokenSet[rule.BetaToken]; present { + msg := effectiveErrMsg + if msg == "" { + msg = "beta feature " + rule.BetaToken + " is not allowed" + } + return &BetaBlockedError{Message: msg} + } + } + return nil +} + +func buildBetaTokenSet(tokens []string) map[string]struct{} { + m := make(map[string]struct{}, len(tokens)) + for _, t := range tokens { + if t == "" { + continue + } + m[t] = struct{}{} + } + return m +} + +var defaultDroppedBetasSet = buildBetaTokenSet(claude.DroppedBetas) diff --git a/backend/internal/service/gateway_beta_policy_extra_test.go b/backend/internal/service/gateway_beta_policy_extra_test.go new file mode 100644 index 00000000000..e6282e950ef --- /dev/null +++ 
b/backend/internal/service/gateway_beta_policy_extra_test.go @@ -0,0 +1,132 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBetaPolicyScopeMatches_All(t *testing.T) { + assert.True(t, betaPolicyScopeMatches("all", true, false)) + assert.True(t, betaPolicyScopeMatches("all", false, false)) + assert.True(t, betaPolicyScopeMatches("all", false, true)) +} + +func TestBetaPolicyScopeMatches_OAuth(t *testing.T) { + assert.True(t, betaPolicyScopeMatches("oauth", true, false)) + assert.False(t, betaPolicyScopeMatches("oauth", false, false)) + assert.False(t, betaPolicyScopeMatches("oauth", false, true)) +} + +func TestBetaPolicyScopeMatches_APIKey(t *testing.T) { + assert.True(t, betaPolicyScopeMatches("apikey", false, false)) + assert.False(t, betaPolicyScopeMatches("apikey", true, false)) + assert.False(t, betaPolicyScopeMatches("apikey", false, true)) +} + +func TestBetaPolicyScopeMatches_Bedrock(t *testing.T) { + assert.True(t, betaPolicyScopeMatches("bedrock", false, true)) + assert.False(t, betaPolicyScopeMatches("bedrock", true, false)) + assert.False(t, betaPolicyScopeMatches("bedrock", false, false)) +} + +func TestBetaPolicyScopeMatches_Unknown_FailOpen(t *testing.T) { + // Unknown scope defaults to match-all (fail-open) + assert.True(t, betaPolicyScopeMatches("unknown", true, true)) + assert.True(t, betaPolicyScopeMatches("", true, false)) +} + +func TestMatchModelWhitelist_ExactMatch(t *testing.T) { + assert.True(t, matchModelWhitelist("claude-sonnet-4-5", []string{"claude-sonnet-4-5"})) + assert.False(t, matchModelWhitelist("claude-opus-4", []string{"claude-sonnet-4-5"})) +} + +func TestMatchModelWhitelist_EmptyWhitelist(t *testing.T) { + assert.False(t, matchModelWhitelist("claude-sonnet-4-5", nil)) + assert.False(t, matchModelWhitelist("claude-sonnet-4-5", []string{})) +} + +func TestMatchModelWhitelist_WildcardPrefix(t *testing.T) { + assert.True(t, 
matchModelWhitelist("claude-sonnet-4-5-20250514", []string{"claude-sonnet-4-5*"})) + assert.False(t, matchModelWhitelist("claude-opus-4", []string{"claude-sonnet*"})) +} + +func TestResolveRuleAction_NoWhitelist(t *testing.T) { + rule := BetaPolicyRule{Action: "block", ErrorMessage: "blocked"} + action, msg := resolveRuleAction(rule, "any-model") + assert.Equal(t, "block", action) + assert.Equal(t, "blocked", msg) +} + +func TestResolveRuleAction_ModelInWhitelist(t *testing.T) { + rule := BetaPolicyRule{ + Action: "block", + ErrorMessage: "blocked", + ModelWhitelist: []string{"claude-sonnet-4-5"}, + } + action, _ := resolveRuleAction(rule, "claude-sonnet-4-5") + assert.Equal(t, "block", action) +} + +func TestResolveRuleAction_ModelNotInWhitelist(t *testing.T) { + rule := BetaPolicyRule{ + Action: "block", + ErrorMessage: "blocked", + ModelWhitelist: []string{"claude-sonnet-4-5"}, + FallbackAction: "filter", + } + action, _ := resolveRuleAction(rule, "claude-opus-4") + assert.Equal(t, "filter", action) +} + +func TestFilterBetaTokens_RemovesFiltered(t *testing.T) { + tokens := []string{"token-a", "token-b", "token-c"} + filterSet := map[string]struct{}{"token-b": {}} + result := filterBetaTokens(tokens, filterSet) + assert.Equal(t, []string{"token-a", "token-c"}, result) +} + +func TestFilterBetaTokens_EmptyFilterSet(t *testing.T) { + tokens := []string{"token-a", "token-b"} + result := filterBetaTokens(tokens, nil) + assert.Equal(t, []string{"token-a", "token-b"}, result) +} + +func TestFilterBetaTokens_AllFiltered(t *testing.T) { + tokens := []string{"token-a", "token-b"} + filterSet := map[string]struct{}{"token-a": {}, "token-b": {}} + result := filterBetaTokens(tokens, filterSet) + assert.Empty(t, result) +} + +func TestRequestNeedsBetaFeatures_WithTools(t *testing.T) { + body := []byte(`{"model":"claude-sonnet-4-5","tools":[{"name":"get_weather"}]}`) + assert.True(t, requestNeedsBetaFeatures(body)) +} + +func TestRequestNeedsBetaFeatures_WithThinking(t 
*testing.T) { + body := []byte(`{"model":"claude-sonnet-4-5","thinking":{"type":"enabled"}}`) + assert.True(t, requestNeedsBetaFeatures(body)) +} + +func TestRequestNeedsBetaFeatures_AdaptiveThinking(t *testing.T) { + body := []byte(`{"model":"claude-sonnet-4-5","thinking":{"type":"adaptive"}}`) + assert.True(t, requestNeedsBetaFeatures(body)) +} + +func TestRequestNeedsBetaFeatures_NoFeatures(t *testing.T) { + body := []byte(`{"model":"claude-sonnet-4-5","messages":[]}`) + assert.False(t, requestNeedsBetaFeatures(body)) +} + +func TestRequestNeedsBetaFeatures_EmptyTools(t *testing.T) { + body := []byte(`{"model":"claude-sonnet-4-5","tools":[]}`) + assert.False(t, requestNeedsBetaFeatures(body)) +} + +func TestRequestNeedsBetaFeatures_DisabledThinking(t *testing.T) { + body := []byte(`{"model":"claude-sonnet-4-5","thinking":{"type":"disabled"}}`) + assert.False(t, requestNeedsBetaFeatures(body)) +} diff --git a/backend/internal/service/gateway_billing.go b/backend/internal/service/gateway_billing.go new file mode 100644 index 00000000000..f12c056b2cb --- /dev/null +++ b/backend/internal/service/gateway_billing.go @@ -0,0 +1,838 @@ +package service + +import ( + "context" + "log/slog" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +// RecordUsageInput 记录使用量的输入参数 +type RecordUsageInput struct { + Result *ForwardResult + ParsedRequest *ParsedRequest + APIKey *APIKey + User *User + Account *Account + Subscription *UserSubscription // 可选:订阅信息 + InboundEndpoint string // 入站端点(客户端请求路径) + UpstreamEndpoint string // 上游端点(标准化后的上游路径) + UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 + RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险 + ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) + APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额 + + ChannelUsageFields // 渠道映射信息(由 handler 在 
Forward 前解析) +} + +// APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage +type APIKeyQuotaUpdater interface { + UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error + UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error +} + +type apiKeyAuthCacheInvalidator interface { + InvalidateAuthCacheByKey(ctx context.Context, key string) +} + +type usageLogBestEffortWriter interface { + CreateBestEffort(ctx context.Context, log *UsageLog) error +} + +// postUsageBillingParams 统一扣费所需的参数 +type postUsageBillingParams struct { + Cost *CostBreakdown + User *User + APIKey *APIKey + Account *Account + Subscription *UserSubscription + RequestPayloadHash string + IsSubscriptionBill bool + AccountRateMultiplier float64 + APIKeyService APIKeyQuotaUpdater +} + +func (p *postUsageBillingParams) shouldDeductAPIKeyQuota() bool { + return p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil +} + +func (p *postUsageBillingParams) shouldUpdateRateLimits() bool { + return p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil +} + +func (p *postUsageBillingParams) shouldUpdateAccountQuota() bool { + return p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() +} + +// postUsageBilling is the legacy fallback billing path used when the unified +// billing repo is unavailable (nil). Production uses applyUsageBilling → repo.Apply +// for atomic billing. This path only runs in tests or degraded mode. +func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) { + billingCtx, cancel := detachedBillingContext(ctx) + defer cancel() + + cost := p.Cost + + if p.IsSubscriptionBill { + // Subscription usage tracked by ActualCost so group rate multiplier + // consumes the quota at the expected speed. 
+ if cost.ActualCost > 0 { + if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.ActualCost); err != nil { + slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err) + } + } + } else { + if cost.ActualCost > 0 { + if err := deps.userRepo.DeductBalance(billingCtx, p.User.ID, cost.ActualCost); err != nil { + slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err) + } + } + } + + if p.shouldDeductAPIKeyQuota() { + if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { + slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err) + } + } + + if p.shouldUpdateRateLimits() { + if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { + slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err) + } + } + + if p.shouldUpdateAccountQuota() { + accountCost := cost.TotalCost * p.AccountRateMultiplier + if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil { + slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err) + } + } + + // NOTE: finalizePostUsageBilling is NOT called here to avoid double-queuing + // cache updates. The legacy path does DB writes directly; the finalize path + // does cache queue + notifications. Notifications are dispatched separately + // by the caller after recording the usage log. 
}

// resolveUsageBillingRequestID picks the dedupe key for a billing command, in
// priority order: client-supplied request ID from the context, locally
// generated context request ID, upstream request ID, and finally a freshly
// generated ID. Prefixes ("client:", "local:", "generated:") keep the
// namespaces from colliding.
func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) string {
	if ctx != nil {
		if clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" {
			return "client:" + strings.TrimSpace(clientRequestID)
		}
		if requestID, _ := ctx.Value(ctxkey.RequestID).(string); strings.TrimSpace(requestID) != "" {
			return "local:" + strings.TrimSpace(requestID)
		}
	}
	if requestID := strings.TrimSpace(upstreamRequestID); requestID != "" {
		return requestID
	}
	return "generated:" + generateRequestID()
}

// resolveUsageBillingPayloadFingerprint returns the payload hash used to guard
// against silent mis-deduplication when request IDs are reused. Falls back to
// the context request IDs when no explicit hash is available; empty string if
// nothing can be resolved.
func resolveUsageBillingPayloadFingerprint(ctx context.Context, requestPayloadHash string) string {
	if payloadHash := strings.TrimSpace(requestPayloadHash); payloadHash != "" {
		return payloadHash
	}
	if ctx != nil {
		if clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" {
			return "client:" + strings.TrimSpace(clientRequestID)
		}
		if requestID, _ := ctx.Value(ctxkey.RequestID).(string); strings.TrimSpace(requestID) != "" {
			return "local:" + strings.TrimSpace(requestID)
		}
	}
	return ""
}

// buildUsageBillingCommand assembles the atomic billing command from the usage
// log and billing params. Returns nil when any required participant (cost,
// api key, user, account) is missing, which makes the caller fall back to the
// legacy billing path.
func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsageBillingParams) *UsageBillingCommand {
	if p == nil || p.Cost == nil || p.APIKey == nil || p.User == nil || p.Account == nil {
		return nil
	}

	cmd := &UsageBillingCommand{
		RequestID:          requestID,
		APIKeyID:           p.APIKey.ID,
		UserID:             p.User.ID,
		AccountID:          p.Account.ID,
		AccountType:        p.Account.Type,
		RequestPayloadHash: strings.TrimSpace(p.RequestPayloadHash),
	}
	if usageLog != nil {
		// Copy token/usage detail from the log; optional pointer fields are
		// dereferenced only when present.
		cmd.Model = usageLog.Model
		cmd.BillingType = usageLog.BillingType
		cmd.InputTokens = usageLog.InputTokens
		cmd.OutputTokens = usageLog.OutputTokens
		cmd.CacheCreationTokens = usageLog.CacheCreationTokens
		cmd.CacheReadTokens = usageLog.CacheReadTokens
		cmd.ImageCount = usageLog.ImageCount
		if usageLog.ServiceTier != nil {
			cmd.ServiceTier = *usageLog.ServiceTier
		}
		if usageLog.ReasoningEffort != nil {
			cmd.ReasoningEffort = *usageLog.ReasoningEffort
		}
		if usageLog.SubscriptionID != nil {
			cmd.SubscriptionID = usageLog.SubscriptionID
		}
	}

	// Record subscription / balance cost using ActualCost so the group (and any
	// user-specific) rate multiplier consumes subscription quota at the expected
	// speed. TotalCost remains the raw (pre-multiplier) value; downstream guards
	// on "> 0" still correctly skip free subscriptions (RateMultiplier == 0).
	if p.IsSubscriptionBill && p.Subscription != nil && p.Cost.TotalCost > 0 {
		cmd.SubscriptionID = &p.Subscription.ID
		cmd.SubscriptionCost = p.Cost.ActualCost
	} else if p.Cost.ActualCost > 0 {
		cmd.BalanceCost = p.Cost.ActualCost
	}

	if p.shouldDeductAPIKeyQuota() {
		cmd.APIKeyQuotaCost = p.Cost.ActualCost
	}
	if p.shouldUpdateRateLimits() {
		cmd.APIKeyRateLimitCost = p.Cost.ActualCost
	}
	if p.shouldUpdateAccountQuota() {
		cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier
	}

	cmd.Normalize()
	return cmd
}

// applyUsageBilling runs the unified (atomic) billing path via repo.Apply,
// falling back to the legacy postUsageBilling when the command cannot be built
// or no repo is configured. Returns (applied, err); applied == false with a
// nil error means the repo deduplicated or declined the command.
func applyUsageBilling(ctx context.Context, requestID string, usageLog *UsageLog, p *postUsageBillingParams, deps *billingDeps, repo UsageBillingRepository) (bool, error) {
	if p == nil || deps == nil {
		return false, nil
	}

	cmd := buildUsageBillingCommand(requestID, usageLog, p)
	if cmd == nil || cmd.RequestID == "" || repo == nil {
		// Degraded mode: no atomic command possible — use the legacy path.
		postUsageBilling(ctx, p, deps)
		return true, nil
	}

	// Detach from the request context so billing survives client disconnects.
	billingCtx, cancel := detachedBillingContext(ctx)
	defer cancel()

	result, err := repo.Apply(billingCtx, cmd)
	if err != nil {
		return false, err
	}

	if result == nil || !result.Applied {
		// Not applied (e.g. deduplicated): still refresh last-used bookkeeping.
		deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)
		return false, nil
	}

	if result.APIKeyQuotaExhausted {
		// Quota just ran out: evict the cached auth entry so the key is
		// rejected on subsequent requests without waiting for TTL expiry.
		if invalidator, ok := p.APIKeyService.(apiKeyAuthCacheInvalidator); ok && p.APIKey != nil && p.APIKey.Key != "" {
			invalidator.InvalidateAuthCacheByKey(billingCtx, p.APIKey.Key)
		}
	}

	finalizePostUsageBilling(p, deps, result)
	return true, nil
}

// finalizePostUsageBilling performs the post-DB-commit work of the unified
// billing path: queue cache-side balance/subscription/rate-limit updates,
// schedule the account's last-used refresh, and dispatch notification checks
// asynchronously.
func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) {
	if p == nil || p.Cost == nil || deps == nil {
		return
	}

	if p.IsSubscriptionBill {
		if p.Cost.ActualCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil {
			deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.ActualCost)
		}
	} else if p.Cost.ActualCost > 0 && p.User != nil {
		deps.billingCacheService.QueueDeductBalance(p.User.ID, p.Cost.ActualCost)
	}

	if p.Cost.ActualCost > 0 && p.APIKey != nil && p.APIKey.HasRateLimits() {
		deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, p.Cost.ActualCost)
	}

	deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)

	// Notification checks run async — all parameters are already captured,
	// no dependency on the request context or upstream connection.
	go notifyBalanceLow(p, deps, result)
	go notifyAccountQuota(p, deps, result)
}

// notifyBalanceLow sends balance low notification after deduction.
// When result.NewBalance is available (from DB transaction RETURNING), it is used directly
// to reconstruct oldBalance, avoiding stale Redis reads and concurrent-deduction races.
+func notifyBalanceLow(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) { + defer func() { + if r := recover(); r != nil { + slog.Error("panic in notifyBalanceLow", "recover", r) + } + }() + if p.IsSubscriptionBill || p.Cost.ActualCost <= 0 || p.User == nil || deps.balanceNotifyService == nil { + slog.Debug("notifyBalanceLow: skipped", + "is_subscription", p.IsSubscriptionBill, + "actual_cost", p.Cost.ActualCost, + "user_nil", p.User == nil, + "service_nil", deps.balanceNotifyService == nil, + ) + return + } + + oldBalance := resolveOldBalance(p, result) + slog.Debug("notifyBalanceLow: calling CheckBalanceAfterDeduction", + "user_id", p.User.ID, + "old_balance", oldBalance, + "cost", p.Cost.ActualCost, + "notify_enabled", p.User.BalanceNotifyEnabled, + "threshold", p.User.BalanceNotifyThreshold, + "result_has_new_balance", result != nil && result.NewBalance != nil, + ) + deps.balanceNotifyService.CheckBalanceAfterDeduction(context.Background(), p.User, oldBalance, p.Cost.ActualCost) +} + +// resolveOldBalance returns the pre-deduction balance. +// Prefers the DB transaction result (newBalance + cost) over snapshot. +func resolveOldBalance(p *postUsageBillingParams, result *UsageBillingApplyResult) float64 { + if result != nil && result.NewBalance != nil { + return *result.NewBalance + p.Cost.ActualCost + } + // Legacy fallback: snapshot balance from request context + return p.User.Balance +} + +// notifyAccountQuota sends account quota threshold notification after increment. +// When result.QuotaState is available (from DB transaction RETURNING), it is passed directly +// to avoid a separate DB read that may see stale or concurrently-modified data. 
+func notifyAccountQuota(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) { + defer func() { + if r := recover(); r != nil { + slog.Error("panic in notifyAccountQuota", "recover", r) + } + }() + if p.Cost.TotalCost <= 0 || p.Account == nil || !p.Account.IsAPIKeyOrBedrock() || deps.balanceNotifyService == nil { + slog.Debug("notifyAccountQuota: skipped", + "total_cost", p.Cost.TotalCost, + "account_nil", p.Account == nil, + "is_apikey_or_bedrock", p.Account != nil && p.Account.IsAPIKeyOrBedrock(), + "service_nil", deps.balanceNotifyService == nil, + ) + return + } + accountCost := p.Cost.TotalCost * p.AccountRateMultiplier + var quotaState *AccountQuotaState + if result != nil { + quotaState = result.QuotaState + } + slog.Debug("notifyAccountQuota: calling CheckAccountQuotaAfterIncrement", + "account_id", p.Account.ID, + "account_cost", accountCost, + "has_quota_state", quotaState != nil, + ) + deps.balanceNotifyService.CheckAccountQuotaAfterIncrement(context.Background(), p.Account, accountCost, quotaState) +} + +func detachedBillingContext(ctx context.Context) (context.Context, context.CancelFunc) { + base := context.Background() + if ctx != nil { + base = context.WithoutCancel(ctx) + } + return context.WithTimeout(base, postUsageBillingTimeout) +} + +func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Context, context.CancelFunc) { + if ctx == nil { + return context.Background(), func() {} + } + if !stream { + return ctx, func() {} + } + return context.WithoutCancel(ctx), func() {} +} + +func detachUpstreamContext(ctx context.Context) (context.Context, context.CancelFunc) { + if ctx == nil { + return context.Background(), func() {} + } + return context.WithoutCancel(ctx), func() {} +} + +// billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供) +type billingDeps struct { + accountRepo AccountRepository + userRepo UserRepository + userSubRepo UserSubscriptionRepository + billingCacheService *BillingCacheService + 
deferredService *DeferredService + balanceNotifyService *BalanceNotifyService +} + +func (s *GatewayService) billingDeps() *billingDeps { + return &billingDeps{ + accountRepo: s.accountRepo, + userRepo: s.userRepo, + userSubRepo: s.userSubRepo, + billingCacheService: s.billingCacheService, + deferredService: s.deferredService, + balanceNotifyService: s.balanceNotifyService, + } +} + +func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usageLog *UsageLog, logKey string) { + if repo == nil || usageLog == nil { + return + } + usageCtx, cancel := detachedBillingContext(ctx) + defer cancel() + + if writer, ok := repo.(usageLogBestEffortWriter); ok { + if err := writer.CreateBestEffort(usageCtx, usageLog); err != nil { + logger.LegacyPrintf(logKey, "Create usage log failed: %v", err) + if IsUsageLogCreateDropped(err) { + return + } + if _, syncErr := repo.Create(usageCtx, usageLog); syncErr != nil { + logger.LegacyPrintf(logKey, "Create usage log sync fallback failed: %v", syncErr) + } + } + return + } + + if _, err := repo.Create(usageCtx, usageLog); err != nil { + logger.LegacyPrintf(logKey, "Create usage log failed: %v", err) + } +} + +// recordUsageOpts 内部选项,参数化 RecordUsage 与 RecordUsageWithLongContext 的差异点。 +type recordUsageOpts struct { + // Claude Max 策略所需的 ParsedRequest(可选,仅 Claude 路径传入) + ParsedRequest *ParsedRequest + + // EnableClaudePath 启用 Claude 路径特有逻辑: + // - Claude Max 缓存计费策略 + EnableClaudePath bool + + // 长上下文计费(仅 Gemini 路径需要) + LongContextThreshold int + LongContextMultiplier float64 +} + +// RecordUsage 记录使用量并扣费(或更新订阅用量) +func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error { + return s.recordUsageCore(ctx, &recordUsageCoreInput{ + Result: input.Result, + APIKey: input.APIKey, + User: input.User, + Account: input.Account, + Subscription: input.Subscription, + InboundEndpoint: input.InboundEndpoint, + UpstreamEndpoint: input.UpstreamEndpoint, + UserAgent: input.UserAgent, + IPAddress: 
// RecordUsageLongContextInput is the input for recording usage with support
// for long-context double billing (used by the Gemini path).
type RecordUsageLongContextInput struct {
	Result                *ForwardResult
	APIKey                *APIKey
	User                  *User
	Account               *Account
	Subscription          *UserSubscription  // optional: subscription info
	InboundEndpoint       string             // inbound endpoint (client request path)
	UpstreamEndpoint      string             // upstream endpoint (normalized upstream path)
	UserAgent             string             // request User-Agent
	IPAddress             string             // client IP address of the request
	RequestPayloadHash    string             // semantic hash of the request body; lowers the risk of silent mis-dedup when a request_id is reused
	LongContextThreshold  int                // long-context threshold (e.g. 200000)
	LongContextMultiplier float64            // multiplier for the portion above the threshold (e.g. 2.0)
	ForceCacheBilling     bool               // force cache billing: bill input_tokens as cache_read (sticky-session switching)
	APIKeyService         APIKeyQuotaUpdater // API key quota service (optional)

	ChannelUsageFields // channel mapping info (resolved by the handler before Forward)
}

// RecordUsageWithLongContext records usage and charges, supporting
// long-context double billing (used for Gemini). It forwards to
// recordUsageCore with the long-context options set.
func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *RecordUsageLongContextInput) error {
	return s.recordUsageCore(ctx, &recordUsageCoreInput{
		Result:             input.Result,
		APIKey:             input.APIKey,
		User:               input.User,
		Account:            input.Account,
		Subscription:       input.Subscription,
		InboundEndpoint:    input.InboundEndpoint,
		UpstreamEndpoint:   input.UpstreamEndpoint,
		UserAgent:          input.UserAgent,
		IPAddress:          input.IPAddress,
		RequestPayloadHash: input.RequestPayloadHash,
		ForceCacheBilling:  input.ForceCacheBilling,
		APIKeyService:      input.APIKeyService,
		ChannelUsageFields: input.ChannelUsageFields,
	}, &recordUsageOpts{
		LongContextThreshold:  input.LongContextThreshold,
		LongContextMultiplier: input.LongContextMultiplier,
	})
}

// recordUsageCoreInput holds the fields common to both public input structs,
// extracted for recordUsageCore.
type recordUsageCoreInput struct {
	Result             *ForwardResult
	APIKey             *APIKey
	User               *User
	Account            *Account
	Subscription       *UserSubscription
	InboundEndpoint    string
	UpstreamEndpoint   string
	UserAgent          string
	IPAddress          string
	RequestPayloadHash string
	ForceCacheBilling  bool
	APIKeyService      APIKeyQuotaUpdater
	ChannelUsageFields
}
// recordUsageCore is the unified implementation behind RecordUsage and
// RecordUsageWithLongContext. Fields in opts control the differences:
//   - ParsedRequest != nil → enables the Claude Max cache billing strategy
//   - LongContextThreshold > 0 → token billing falls back to CalculateCostWithLongContext
func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error {
	result := input.Result
	apiKey := input.APIKey
	user := input.User
	account := input.Account
	subscription := input.Subscription

	// Force cache billing: move input_tokens into cache_read_input_tokens.
	// Special billing treatment used when a sticky session switches accounts.
	if input.ForceCacheBilling && result.Usage.InputTokens > 0 {
		logger.LegacyPrintf("service.gateway", "force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)",
			result.Usage.InputTokens, account.ID)
		result.Usage.CacheReadInputTokens += result.Usage.InputTokens
		result.Usage.InputTokens = 0
	}

	// Cache TTL override: keep the billed token classification consistent with
	// the account's settings. Account-level settings win; with global 1h
	// request injection enabled, usage billing defaults back to the 5m bucket.
	cacheTTLOverridden := false
	if overrideTarget, ok := s.resolveCacheTTLUsageOverrideTarget(ctx, account); ok {
		applyCacheTTLOverride(&result.Usage, overrideTarget)
		cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0
	}

	// Resolve the rate multiplier (priority: per-user > group default > system default).
	multiplier := 1.0
	if s.cfg != nil {
		multiplier = s.cfg.Default.RateMultiplier
	}
	if apiKey.GroupID != nil && apiKey.Group != nil {
		groupDefault := apiKey.Group.RateMultiplier
		multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault)
	}
	imageMultiplier := resolveImageRateMultiplier(apiKey, multiplier)

	// Determine the billing model.
	billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel)
	if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" {
		billingModel = input.ChannelMappedModel
	}
	if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" {
		billingModel = input.OriginalModel
	}

	// Determine RequestedModel (the original model before channel mapping).
	requestedModel := result.Model
	if input.OriginalModel != "" {
		requestedModel = input.OriginalModel
	}

	// Compute the cost.
	cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, imageMultiplier, opts)

	// Decide the billing mode: subscription vs balance.
	isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
	billingType := BillingTypeBalance
	if isSubscriptionBilling {
		billingType = BillingTypeSubscription
	}

	// Build the usage log entry.
	accountRateMultiplier := account.BillingRateMultiplier()
	usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription,
		requestedModel, multiplier, imageMultiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts)

	// Account statistics pricing cost (matches custom rules against the final upstream model).
	if apiKey.GroupID != nil {
		applyAccountStatsCost(ctx, usageLog, s.channelService, s.billingService,
			account.ID, *apiKey.GroupID, result.UpstreamModel, result.Model,
			// Anthropic's input_tokens excludes cache_read and cache_creation (billed separately);
			// OpenAI gateway uses actualInputTokens which also excludes cache_read for the same reason.
			UsageTokens{
				InputTokens:         result.Usage.InputTokens,
				OutputTokens:        result.Usage.OutputTokens,
				CacheCreationTokens: result.Usage.CacheCreationInputTokens,
				CacheReadTokens:     result.Usage.CacheReadInputTokens,
				ImageOutputTokens:   result.Usage.ImageOutputTokens,
			},
			cost.TotalCost,
		)
	}

	// Simple mode: record only, never bill.
	if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
		writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")
		logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
		s.deferredService.ScheduleLastUsedUpdate(account.ID)
		return nil
	}

	// Apply billing first; the usage log is only written after billing succeeds.
	requestID := usageLog.RequestID
	_, billingErr := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{
		Cost:                  cost,
		User:                  user,
		APIKey:                apiKey,
		Account:               account,
		Subscription:          subscription,
		RequestPayloadHash:    resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash),
		IsSubscriptionBill:    isSubscriptionBilling,
		AccountRateMultiplier: accountRateMultiplier,
		APIKeyService:         input.APIKeyService,
	}, s.billingDeps(), s.usageBillingRepo)

	if billingErr != nil {
		return billingErr
	}
	writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway")

	return nil
}

// calculateRecordUsageCost computes the cost according to the request type
// (image generation vs token usage) and the supplied options.
func (s *GatewayService) calculateRecordUsageCost(
	ctx context.Context,
	result *ForwardResult,
	apiKey *APIKey,
	billingModel string,
	multiplier float64,
	imageMultiplier float64,
	opts *recordUsageOpts,
) *CostBreakdown {
	// Image-generation billing.
	if result.ImageCount > 0 {
		return s.calculateImageCost(ctx, result, apiKey, billingModel, imageMultiplier)
	}

	// Token billing.
	return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts)
}

// resolveChannelPricing checks whether channel-level pricing exists for the
// given model. A non-nil ResolvedPricing means channel pricing applies; nil
// means the default pricing path should be used.
func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing {
	if s.resolver == nil || apiKey.Group == nil {
		return nil
	}
	gid := apiKey.Group.ID
	resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid})
	if resolved.Source == PricingSourceChannel {
		return resolved
	}
	return nil
}
*APIKey) *ResolvedPricing { + if s.resolver == nil || apiKey.Group == nil { + return nil + } + gid := apiKey.Group.ID + resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid}) + if resolved.Source == PricingSourceChannel { + return resolved + } + return nil +} + +// calculateImageCost 计算图片生成费用:渠道级别定价优先,否则走按次计费。 +func (s *GatewayService) calculateImageCost( + ctx context.Context, + result *ForwardResult, + apiKey *APIKey, + billingModel string, + multiplier float64, +) *CostBreakdown { + if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil { + tokens := UsageTokens{ + InputTokens: result.Usage.InputTokens, + OutputTokens: result.Usage.OutputTokens, + ImageOutputTokens: result.Usage.ImageOutputTokens, + } + gid := apiKey.Group.ID + cost, err := s.billingService.CalculateCostUnified(CostInput{ + Ctx: ctx, + Model: billingModel, + GroupID: &gid, + Tokens: tokens, + RequestCount: result.ImageCount, + SizeTier: result.ImageSize, + RateMultiplier: multiplier, + Resolver: s.resolver, + Resolved: resolved, + }) + if err != nil { + logger.LegacyPrintf("service.gateway", "Calculate image token cost failed: %v", err) + return &CostBreakdown{ActualCost: 0} + } + return cost + } + + var groupConfig *ImagePriceConfig + if apiKey.Group != nil { + groupConfig = &ImagePriceConfig{ + Price1K: apiKey.Group.ImagePrice1K, + Price2K: apiKey.Group.ImagePrice2K, + Price4K: apiKey.Group.ImagePrice4K, + } + } + return s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier) +} + +// calculateTokenCost 计算 Token 计费:根据 opts 决定走普通/长上下文/渠道统一计费。 +func (s *GatewayService) calculateTokenCost( + ctx context.Context, + result *ForwardResult, + apiKey *APIKey, + billingModel string, + multiplier float64, + opts *recordUsageOpts, +) *CostBreakdown { + tokens := UsageTokens{ + InputTokens: result.Usage.InputTokens, + OutputTokens: result.Usage.OutputTokens, + CacheCreationTokens: 
// buildRecordUsageLog assembles the UsageLog entry and sets the billing mode.
// It only collects already-computed values; no billing decisions are made here.
func (s *GatewayService) buildRecordUsageLog(
	ctx context.Context,
	input *recordUsageCoreInput,
	result *ForwardResult,
	apiKey *APIKey,
	user *User,
	account *Account,
	subscription *UserSubscription,
	requestedModel string,
	multiplier float64,
	imageMultiplier float64,
	accountRateMultiplier float64,
	billingType int8,
	cacheTTLOverridden bool,
	cost *CostBreakdown,
	opts *recordUsageOpts,
) *UsageLog {
	durationMs := int(result.Duration.Milliseconds())
	requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
	usageLog := &UsageLog{
		UserID:         user.ID,
		APIKeyID:       apiKey.ID,
		AccountID:      account.ID,
		RequestID:      requestID,
		Model:          result.Model,
		RequestedModel: requestedModel,
		// Only stored when the upstream model actually differs from the served model.
		UpstreamModel:         optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
		ReasoningEffort:       result.ReasoningEffort,
		InboundEndpoint:       optionalTrimmedStringPtr(input.InboundEndpoint),
		UpstreamEndpoint:      optionalTrimmedStringPtr(input.UpstreamEndpoint),
		InputTokens:           result.Usage.InputTokens,
		OutputTokens:          result.Usage.OutputTokens,
		CacheCreationTokens:   result.Usage.CacheCreationInputTokens,
		CacheReadTokens:       result.Usage.CacheReadInputTokens,
		CacheCreation5mTokens: result.Usage.CacheCreation5mTokens,
		CacheCreation1hTokens: result.Usage.CacheCreation1hTokens,
		ImageOutputTokens:     result.Usage.ImageOutputTokens,
		RateMultiplier:        multiplier,
		AccountRateMultiplier: &accountRateMultiplier,
		BillingType:           billingType,
		BillingMode:           resolveBillingMode(result, cost),
		Stream:                result.Stream,
		DurationMs:            &durationMs,
		FirstTokenMs:          result.FirstTokenMs,
		ImageCount:            result.ImageCount,
		ImageSize:             optionalTrimmedStringPtr(result.ImageSize),
		CacheTTLOverridden:    cacheTTLOverridden,
		ChannelID:             optionalInt64Ptr(input.ChannelID),
		ModelMappingChain:     optionalTrimmedStringPtr(input.ModelMappingChain),
		UserAgent:             optionalTrimmedStringPtr(input.UserAgent),
		IPAddress:             optionalTrimmedStringPtr(input.IPAddress),
		GroupID:               apiKey.GroupID,
		SubscriptionID:        optionalSubscriptionID(subscription),
		CreatedAt:             time.Now(),
	}
	// Image requests are billed with the image multiplier instead.
	if result.ImageCount > 0 {
		usageLog.RateMultiplier = imageMultiplier
	}
	if cost != nil {
		usageLog.InputCost = cost.InputCost
		usageLog.OutputCost = cost.OutputCost
		usageLog.ImageOutputCost = cost.ImageOutputCost
		usageLog.CacheCreationCost = cost.CacheCreationCost
		usageLog.CacheReadCost = cost.CacheReadCost
		usageLog.TotalCost = cost.TotalCost
		usageLog.ActualCost = cost.ActualCost
	}

	return usageLog
}

// resolveBillingMode determines the billing mode from the cost breakdown and
// the request type. The cost's own mode wins; otherwise image vs token is
// inferred from ImageCount.
func resolveBillingMode(result *ForwardResult, cost *CostBreakdown) *string {
	var mode string
	switch {
	case cost != nil && cost.BillingMode != "":
		mode = cost.BillingMode
	case result.ImageCount > 0:
		mode = string(BillingModeImage)
	default:
		mode = string(BillingModeToken)
	}
	return &mode
}
result.ImageCount > 0: + mode = string(BillingModeImage) + default: + mode = string(BillingModeToken) + } + return &mode +} + +func optionalSubscriptionID(subscription *UserSubscription) *int64 { + if subscription != nil { + return &subscription.ID + } + return nil +} diff --git a/backend/internal/service/gateway_channel.go b/backend/internal/service/gateway_channel.go new file mode 100644 index 00000000000..222a9452c59 --- /dev/null +++ b/backend/internal/service/gateway_channel.go @@ -0,0 +1,116 @@ +package service + +import ( + "context" + "log/slog" +) + +// ResolveChannelMapping 委托渠道服务解析模型映射 +func (s *GatewayService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult { + if s.channelService == nil { + return ChannelMappingResult{MappedModel: model} + } + return s.channelService.ResolveChannelMapping(ctx, groupID, model) +} + +// ReplaceModelInBody 替换请求体中的模型名(导出供 handler 使用) +func (s *GatewayService) ReplaceModelInBody(body []byte, newModel string) []byte { + return ReplaceModelInBody(body, newModel) +} + +// IsModelRestricted 检查模型是否被渠道限制 +func (s *GatewayService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool { + if s.channelService == nil { + return false + } + return s.channelService.IsModelRestricted(ctx, groupID, model) +} + +// ResolveChannelMappingAndRestrict 解析渠道映射。 +// 模型限制检查已移至调度阶段(checkChannelPricingRestriction),restricted 始终返回 false。 +func (s *GatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) { + if s.channelService == nil { + return ChannelMappingResult{MappedModel: model}, false + } + return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model) +} + +// checkChannelPricingRestriction 根据渠道计费基准检查模型是否受定价列表限制。 +// 供调度阶段预检查(requested / channel_mapped)。 +// upstream 需逐账号检查,此处返回 false。 +func (s *GatewayService) checkChannelPricingRestriction(ctx context.Context, groupID *int64, requestedModel 
string) bool { + if groupID == nil || s.channelService == nil || requestedModel == "" { + return false + } + mapping := s.channelService.ResolveChannelMapping(ctx, *groupID, requestedModel) + billingModel := billingModelForRestriction(mapping.BillingModelSource, requestedModel, mapping.MappedModel) + if billingModel == "" { + return false + } + return s.channelService.IsModelRestricted(ctx, *groupID, billingModel) +} + +// billingModelForRestriction 根据计费基准确定限制检查使用的模型。 +// upstream 返回空(需逐账号检查)。 +func billingModelForRestriction(source, requestedModel, channelMappedModel string) string { + switch source { + case BillingModelSourceRequested: + return requestedModel + case BillingModelSourceUpstream: + return "" + case BillingModelSourceChannelMapped: + return channelMappedModel + default: + return channelMappedModel + } +} + +// isUpstreamModelRestrictedByChannel 检查账号映射后的上游模型是否受渠道定价限制。 +// 仅在 BillingModelSource="upstream" 且 RestrictModels=true 时由调度循环调用。 +func (s *GatewayService) isUpstreamModelRestrictedByChannel(ctx context.Context, groupID int64, account *Account, requestedModel string) bool { + if s.channelService == nil { + return false + } + upstreamModel := resolveAccountUpstreamModel(account, requestedModel) + if upstreamModel == "" { + return false + } + return s.channelService.IsModelRestricted(ctx, groupID, upstreamModel) +} + +// resolveAccountUpstreamModel 确定账号将请求模型映射为什么上游模型。 +func resolveAccountUpstreamModel(account *Account, requestedModel string) string { + if account.Platform == PlatformAntigravity { + return mapAntigravityModel(account, requestedModel) + } + return account.GetMappedModel(requestedModel) +} + +// needsUpstreamChannelRestrictionCheck 判断是否需要在调度循环中逐账号检查上游模型的渠道限制。 +func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Context, groupID *int64) bool { + if groupID == nil || s.channelService == nil { + return false + } + ch, err := s.channelService.GetChannelForGroup(ctx, *groupID) + if err != nil { + slog.Warn("failed to 
check channel upstream restriction", "group_id", *groupID, "error", err) + return false + } + if ch == nil || !ch.RestrictModels { + return false + } + return ch.BillingModelSource == BillingModelSourceUpstream +} + +// isStickyAccountUpstreamRestricted 检查粘性会话命中的账号是否受 upstream 渠道限制。 +// 合并 needsUpstreamChannelRestrictionCheck + isUpstreamModelRestrictedByChannel 两步调用, +// 供 sticky session 条件链使用,避免内联多个函数调用导致行过长。 +func (s *GatewayService) isStickyAccountUpstreamRestricted(ctx context.Context, groupID *int64, account *Account, requestedModel string) bool { + if groupID == nil { + return false + } + if !s.needsUpstreamChannelRestrictionCheck(ctx, groupID) { + return false + } + return s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel) +} diff --git a/backend/internal/service/gateway_claude_mimic.go b/backend/internal/service/gateway_claude_mimic.go new file mode 100644 index 00000000000..5b93b40a023 --- /dev/null +++ b/backend/internal/service/gateway_claude_mimic.go @@ -0,0 +1,561 @@ +package service + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/google/uuid" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + + "github.com/gin-gonic/gin" +) + +func buildClaudeMimicDebugLine(req *http.Request, body []byte, account *Account, tokenType string, mimicClaudeCode bool) string { + if req == nil { + return "" + } + + // Only log a minimal fingerprint to avoid leaking user content. 
+ interesting := []string{ + "user-agent", + "x-app", + "anthropic-dangerous-direct-browser-access", + "anthropic-version", + "anthropic-beta", + "x-stainless-lang", + "x-stainless-package-version", + "x-stainless-os", + "x-stainless-arch", + "x-stainless-runtime", + "x-stainless-runtime-version", + "x-stainless-retry-count", + "x-stainless-timeout", + "authorization", + "x-api-key", + "content-type", + "accept", + "x-stainless-helper-method", + } + + h := make([]string, 0, len(interesting)) + for _, k := range interesting { + if v := req.Header.Get(k); v != "" { + h = append(h, fmt.Sprintf("%s=%q", k, safeHeaderValueForLog(k, v))) + } + } + + metaUserID := strings.TrimSpace(gjson.GetBytes(body, "metadata.user_id").String()) + sysPreview := strings.TrimSpace(extractSystemPreviewFromBody(body)) + + // Truncate preview to keep logs sane. + if len(sysPreview) > 300 { + sysPreview = sysPreview[:300] + "..." + } + sysPreview = strings.ReplaceAll(sysPreview, "\n", "\\n") + sysPreview = strings.ReplaceAll(sysPreview, "\r", "\\r") + + aid := int64(0) + aname := "" + if account != nil { + aid = account.ID + aname = account.Name + } + + return fmt.Sprintf( + "url=%s account=%d(%s) tokenType=%s mimic=%t meta.user_id=%q system.preview=%q headers={%s}", + req.URL.String(), + aid, + aname, + tokenType, + mimicClaudeCode, + metaUserID, + sysPreview, + strings.Join(h, " "), + ) +} + +func logClaudeMimicDebug(req *http.Request, body []byte, account *Account, tokenType string, mimicClaudeCode bool) { + line := buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode) + if line == "" { + return + } + logger.LegacyPrintf("service.gateway", "[ClaudeMimicDebug] %s", line) +} + +func isClaudeCodeCredentialScopeError(msg string) bool { + m := strings.ToLower(strings.TrimSpace(msg)) + if m == "" { + return false + } + return strings.Contains(m, "only authorized for use with claude code") && + strings.Contains(m, "cannot be used for other api requests") +} + +type 
anthropicCacheControlPayload struct { + Type string `json:"type"` + TTL string `json:"ttl,omitempty"` +} + +type anthropicSystemTextBlockPayload struct { + Type string `json:"type"` + Text string `json:"text"` + CacheControl *anthropicCacheControlPayload `json:"cache_control,omitempty"` +} + +type anthropicMetadataPayload struct { + UserID string `json:"user_id"` +} + +// replaceModelInBody 替换请求体中的model字段 +// 优先使用定点修改,尽量保持客户端原始字段顺序。 +func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte { + return ReplaceModelInBody(body, newModel) +} + +type claudeOAuthNormalizeOptions struct { + injectMetadata bool + metadataUserID string + stripSystemCacheControl bool +} + +// sanitizeSystemText rewrites only the fixed OpenCode identity sentence (if present). +// We intentionally avoid broad keyword replacement in system prompts to prevent +// accidentally changing user-provided instructions. +func sanitizeSystemText(text string) string { + if text == "" { + return text + } + // Some clients include a fixed OpenCode identity sentence. Anthropic may treat + // this as a non-Claude-Code fingerprint, so rewrite it to the canonical + // Claude Code banner before generic "OpenCode"/"opencode" replacements. 
+ text = strings.ReplaceAll( + text, + "You are OpenCode, the best coding agent on the planet.", + strings.TrimSpace(claudeCodeSystemPrompt), + ) + return text +} + +func marshalAnthropicSystemTextBlock(text string, includeCacheControl bool) ([]byte, error) { + block := anthropicSystemTextBlockPayload{ + Type: "text", + Text: text, + } + if includeCacheControl { + block.CacheControl = &anthropicCacheControlPayload{ + Type: "ephemeral", + TTL: claude.DefaultCacheControlTTL, + } + } + return json.Marshal(block) +} + +func marshalAnthropicMetadata(userID string) ([]byte, error) { + return json.Marshal(anthropicMetadataPayload{UserID: userID}) +} + +func buildJSONArrayRaw(items [][]byte) []byte { + if len(items) == 0 { + return []byte("[]") + } + + total := 2 + for _, item := range items { + total += len(item) + } + total += len(items) - 1 + + buf := make([]byte, 0, total) + buf = append(buf, '[') + for i, item := range items { + if i > 0 { + buf = append(buf, ',') + } + buf = append(buf, item...) 
+ } + buf = append(buf, ']') + return buf +} + +func setJSONValueBytes(body []byte, path string, value any) ([]byte, bool) { + next, err := sjson.SetBytes(body, path, value) + if err != nil { + return body, false + } + return next, true +} + +func setJSONRawBytes(body []byte, path string, raw []byte) ([]byte, bool) { + next, err := sjson.SetRawBytes(body, path, raw) + if err != nil { + return body, false + } + return next, true +} + +func deleteJSONPathBytes(body []byte, path string) ([]byte, bool) { + next, err := sjson.DeleteBytes(body, path) + if err != nil { + return body, false + } + return next, true +} + +func normalizeClaudeOAuthSystemBody(body []byte, opts claudeOAuthNormalizeOptions) ([]byte, bool) { + sys := gjson.GetBytes(body, "system") + if !sys.Exists() { + return body, false + } + + out := body + modified := false + + switch { + case sys.Type == gjson.String: + sanitized := sanitizeSystemText(sys.String()) + if sanitized != sys.String() { + if next, ok := setJSONValueBytes(out, "system", sanitized); ok { + out = next + modified = true + } + } + case sys.IsArray(): + index := 0 + sys.ForEach(func(_, item gjson.Result) bool { + if item.Get("type").String() == "text" { + textResult := item.Get("text") + if textResult.Exists() && textResult.Type == gjson.String { + text := textResult.String() + sanitized := sanitizeSystemText(text) + if sanitized != text { + if next, ok := setJSONValueBytes(out, fmt.Sprintf("system.%d.text", index), sanitized); ok { + out = next + modified = true + } + } + } + } + + if opts.stripSystemCacheControl && item.Get("cache_control").Exists() { + if next, ok := deleteJSONPathBytes(out, fmt.Sprintf("system.%d.cache_control", index)); ok { + out = next + modified = true + } + } + + index++ + return true + }) + } + + return out, modified +} + +func ensureClaudeOAuthMetadataUserID(body []byte, userID string) ([]byte, bool) { + if strings.TrimSpace(userID) == "" { + return body, false + } + + metadata := gjson.GetBytes(body, 
"metadata") + if !metadata.Exists() || metadata.Type == gjson.Null { + raw, err := marshalAnthropicMetadata(userID) + if err != nil { + return body, false + } + return setJSONRawBytes(body, "metadata", raw) + } + + trimmedRaw := strings.TrimSpace(metadata.Raw) + if strings.HasPrefix(trimmedRaw, "{") { + existing := metadata.Get("user_id") + if existing.Exists() && existing.Type == gjson.String && existing.String() != "" { + return body, false + } + return setJSONValueBytes(body, "metadata.user_id", userID) + } + + raw, err := marshalAnthropicMetadata(userID) + if err != nil { + return body, false + } + return setJSONRawBytes(body, "metadata", raw) +} + +func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string) { + if len(body) == 0 { + return body, modelID + } + + out := body + modified := false + + if next, changed := normalizeClaudeOAuthSystemBody(out, opts); changed { + out = next + modified = true + } + + rawModel := gjson.GetBytes(out, "model") + if rawModel.Exists() && rawModel.Type == gjson.String { + normalized := claude.NormalizeModelID(rawModel.String()) + if normalized != rawModel.String() { + if next, ok := setJSONValueBytes(out, "model", normalized); ok { + out = next + modified = true + } + modelID = normalized + } + } + + // 确保 tools 字段存在(即使为空数组) + if !gjson.GetBytes(out, "tools").Exists() { + if next, ok := setJSONRawBytes(out, "tools", []byte("[]")); ok { + out = next + modified = true + } + } + + if opts.injectMetadata && opts.metadataUserID != "" { + if next, changed := ensureClaudeOAuthMetadataUserID(out, opts.metadataUserID); changed { + out = next + modified = true + } + } + + // temperature:真实 Claude Code CLI 总是发送 temperature(默认 1,客户端可覆盖)。 + // 之前的实现直接 delete 会导致 payload 缺字段,与真实 CLI 字节级不一致。 + // 策略:客户端传了什么就透传;没传则补默认 1。 + if !gjson.GetBytes(out, "temperature").Exists() { + if next, ok := setJSONValueBytes(out, "temperature", 1); ok { + out = next + modified = true + } + } + + // 
max_tokens:真实 CLI 的默认值是 128000。缺失时补齐以对齐指纹。 + if !gjson.GetBytes(out, "max_tokens").Exists() { + if next, ok := setJSONValueBytes(out, "max_tokens", 128000); ok { + out = next + modified = true + } + } + + // context_management:thinking.type 为 enabled/adaptive 时,真实 CLI 会自动 + // 附带 {"edits":[{"type":"clear_thinking_20251015","keep":"all"}]}。 + // 客户端显式传了就透传;否则按 CLI 行为补齐。 + if !gjson.GetBytes(out, "context_management").Exists() { + thinkingType := gjson.GetBytes(out, "thinking.type").String() + if thinkingType == "enabled" || thinkingType == "adaptive" { + const cmDefault = `{"edits":[{"type":"clear_thinking_20251015","keep":"all"}]}` + if next, ok := setJSONRawBytes(out, "context_management", []byte(cmDefault)); ok { + out = next + modified = true + } + } + } + + // tool_choice:与 Parrot 对齐,不再无条件删除。 + // - 客户端传了 {"type":"tool","name":"X"} → 保留结构,name 由 + // applyToolNameRewriteToBody 同步映射为假名 + // - 其他形态(auto/any/none)原样透传 + // 如果 body 里完全没有 tools(空数组),tool_choice 没意义时才删除 + if !gjson.GetBytes(out, "tools").IsArray() || len(gjson.GetBytes(out, "tools").Array()) == 0 { + if gjson.GetBytes(out, "tool_choice").Exists() { + if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok { + out = next + modified = true + } + } + } + + if !modified { + return body, modelID + } + + return out, modelID +} + +func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string { + if parsed == nil || account == nil { + return "" + } + if parsed.MetadataUserID != "" { + return "" + } + + userID := strings.TrimSpace(account.GetClaudeUserID()) + if userID == "" && fp != nil { + userID = fp.ClientID + } + if userID == "" { + // Fall back to a random, well-formed client id so we can still satisfy + // Claude Code OAuth requirements when account metadata is incomplete. 
// applyClaudeCodeOAuthMimicryToBody applies, to an arbitrary body, the full
// "non-Claude-Code client + Claude OAuth account" disguise that originally
// lived only in the /v1/messages path.
//
// It is the generic version of the rewriteSystemForNonClaudeCode +
// normalizeClaudeOAuthRequestBody pipeline on the /v1/messages main path, for
// reuse by the OpenAI protocol compatibility layer
// (ForwardAsChatCompletions / ForwardAsResponses).
//
// Before this extraction, the OpenAI compatibility layer only did
// injectClaudeCodePrompt (prepend-only), while the in-repo /v1/messages path's
// own comment explicitly stated that prepend-only injection cannot pass
// Anthropic's third-party detection; that comment is the root cause for this
// function's existence.
//
// Parameters:
//   - ctx / c: used to read the fingerprint and gateway settings; c may be nil
//     (e.g. count_tokens).
//   - account: must be an OAuth account, and the caller has already determined
//     the client is not Claude Code.
//   - body: the request body already marshalled into Anthropic /v1/messages form.
//   - systemRaw: the body's original system field (used to decide whether a
//     rewrite is needed).
//   - model: the model ID that will be sent upstream (used for the haiku
//     bypass and metadata version selection).
//
// Returns the rewritten body. If any intermediate step fails it degrades to
// the original body (it never panics).
func (s *GatewayService) applyClaudeCodeOAuthMimicryToBody(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	body []byte,
	systemRaw any,
	model string,
) []byte {
	if account == nil || !account.IsOAuth() || len(body) == 0 {
		return body
	}

	// haiku models bypass the system rewrite; for all others the system field
	// is rewritten to the Claude Code form.
	systemRewritten := false
	if !strings.Contains(strings.ToLower(model), "haiku") {
		body = rewriteSystemForNonClaudeCode(body, systemRaw)
		systemRewritten = true
	}

	// Only strip client cache_control markers when we did NOT rewrite system
	// ourselves (the rewrite already controls them).
	normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: !systemRewritten}

	if s.identityService != nil && c != nil && c.Request != nil {
		if fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header); err == nil && fp != nil {
			mimicMPT := false
			if s.settingService != nil {
				_, mimicMPT, _ = s.settingService.GetGatewayForwardingSettings(ctx)
			}
			if !mimicMPT {
				if uid := s.buildOAuthMetadataUserIDFromBody(ctx, account, fp, body); uid != "" {
					normalizeOpts.injectMetadata = true
					normalizeOpts.metadataUserID = uid
				}
			}
		}
	}

	body, _ = normalizeClaudeOAuthRequestBody(body, model, normalizeOpts)

	// Phase D+E+F: messages cache strategy + tool-name obfuscation + tools[-1]
	// breakpoint. Mirrors the remaining field-level rewrites in Parrot's
	// transform_request. Order matters:
	//  1) messages cache: only when enabled, clear client breakpoints and
	//     inject the proxy breakpoint.
	//  2) tool rewrite: last, rewrite tools[*].name / tool_choice.name and put
	//     a breakpoint on tools[-1]; the mapping is stored in gin.Context so
	//     the response side can restore names via bytes.Replace.
	body = s.rewriteMessageCacheControlIfEnabled(ctx, body)

	if rw := buildToolNameRewriteFromBody(body); rw != nil {
		body = applyToolNameRewriteToBody(body, rw)
		if c != nil {
			c.Set(toolNameRewriteKey, rw)
		}
	} else {
		body = applyToolsLastCacheBreakpoint(body)
	}

	return body
}

// buildOAuthMetadataUserIDFromBody is a variant of buildOAuthMetadataUserID
// for callers that do not have a ParsedRequest at hand (e.g. the OpenAI
// protocol compatibility layer).
//
// The only differences from buildOAuthMetadataUserID:
//   - the session hash is recomputed from the body itself using the same
//     rules, instead of reading the cached ParsedRequest value;
//   - if the body already contains metadata.user_id, it returns "" (and
//     ensureClaudeOAuthMetadataUserID decides whether to overwrite).
func (s *GatewayService) buildOAuthMetadataUserIDFromBody(
	ctx context.Context,
	account *Account,
	fp *Fingerprint,
	body []byte,
) string {
	_ = ctx // kept for signature symmetry with buildOAuthMetadataUserID
	if account == nil {
		return ""
	}
	if existing := gjson.GetBytes(body, "metadata.user_id").String(); existing != "" {
		return ""
	}

	userID := strings.TrimSpace(account.GetClaudeUserID())
	if userID == "" && fp != nil {
		userID = fp.ClientID
	}
	if userID == "" {
		userID = generateClientID()
	}

	// Stable per-account session UUID derived from the body hash when
	// available; otherwise a fresh random UUID.
	sessionID := uuid.NewString()
	if hash := hashBodyForSessionSeed(body); hash != "" {
		sessionID = generateSessionUUID(fmt.Sprintf("%d::%s", account.ID, hash))
	}

	var uaVersion string
	if fp != nil {
		uaVersion = ExtractCLIVersion(fp.UserAgent)
	}
	accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid"))
	return FormatMetadataUserID(userID, accountUUID, sessionID, uaVersion)
}
ExtractCLIVersion(fp.UserAgent) + } + accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid")) + return FormatMetadataUserID(userID, accountUUID, sessionID, uaVersion) +} + +// applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers. +// This mirrors opencode-anthropic-auth behavior: do not trust downstream +// headers when using Claude Code-scoped OAuth credentials. +func applyClaudeCodeMimicHeaders(req *http.Request, isStream bool) { + if req == nil { + return + } + // Start with the standard defaults (fill missing). + applyClaudeOAuthHeaderDefaults(req) + // Then force key headers to match Claude Code fingerprint regardless of what the client sent. + // 使用 resolveWireCasing 确保 key 与真实 wire format 一致(如 "x-app" 而非 "X-App") + for key, value := range claude.DefaultHeaders { + if value == "" { + continue + } + setHeaderRaw(req.Header, resolveWireCasing(key), value) + } + // Real Claude CLI uses Accept: application/json (even for streaming). + setHeaderRaw(req.Header, "Accept", "application/json") + if isStream { + setHeaderRaw(req.Header, "x-stainless-helper-method", "stream") + } + // Real Claude CLI 每个请求都会生成一个新的 UUID 放在 x-client-request-id。 + // 上游会以此作为会话/请求指纹的一部分,缺失或重复都可能触发第三方判定。 + if getHeaderRaw(req.Header, "x-client-request-id") == "" { + setHeaderRaw(req.Header, "x-client-request-id", uuid.NewString()) + } +} diff --git a/backend/internal/service/gateway_claude_mimic_test.go b/backend/internal/service/gateway_claude_mimic_test.go new file mode 100644 index 00000000000..458fc366e7b --- /dev/null +++ b/backend/internal/service/gateway_claude_mimic_test.go @@ -0,0 +1,126 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSanitizeSystemText_ReplacesOpenCodeIdentity(t *testing.T) { + // The exact phrase that gets replaced + input := "You are OpenCode, the best coding agent on the planet." 
+ result := sanitizeSystemText(input) + assert.NotContains(t, result, "best coding agent on the planet") +} + +func TestSanitizeSystemText_PreservesNormalText(t *testing.T) { + input := "You are a helpful assistant for coding tasks." + assert.Equal(t, input, sanitizeSystemText(input)) +} + +func TestSanitizeSystemText_EmptyInput(t *testing.T) { + assert.Equal(t, "", sanitizeSystemText("")) +} + +func TestIsClaudeCodeCredentialScopeError_Matches(t *testing.T) { + tests := []struct { + msg string + want bool + }{ + {"This API key is only authorized for use with Claude Code and cannot be used for other API requests", true}, + {"only authorized for use with claude code and cannot be used for other api requests", true}, + {"rate limit exceeded", false}, + {"only authorized for use with claude code", false}, // needs both parts + {"", false}, + } + for _, tt := range tests { + t.Run(tt.msg, func(t *testing.T) { + assert.Equal(t, tt.want, isClaudeCodeCredentialScopeError(tt.msg)) + }) + } +} + +func TestBuildJSONArrayRaw_Empty(t *testing.T) { + result := buildJSONArrayRaw(nil) + assert.Equal(t, []byte("[]"), result) + + result = buildJSONArrayRaw([][]byte{}) + assert.Equal(t, []byte("[]"), result) +} + +func TestBuildJSONArrayRaw_SingleItem(t *testing.T) { + items := [][]byte{[]byte(`{"key":"value"}`)} + result := buildJSONArrayRaw(items) + assert.Equal(t, `[{"key":"value"}]`, string(result)) +} + +func TestBuildJSONArrayRaw_MultipleItems(t *testing.T) { + items := [][]byte{ + []byte(`{"a":1}`), + []byte(`{"b":2}`), + []byte(`{"c":3}`), + } + result := buildJSONArrayRaw(items) + assert.Equal(t, `[{"a":1},{"b":2},{"c":3}]`, string(result)) +} + +func TestSetJSONValueBytes_SetString(t *testing.T) { + body := []byte(`{"model":"old"}`) + result, ok := setJSONValueBytes(body, "model", "new") + require.True(t, ok) + assert.Contains(t, string(result), `"new"`) +} + +func TestSetJSONValueBytes_AddNewField(t *testing.T) { + body := []byte(`{"model":"test"}`) + result, ok := 
setJSONValueBytes(body, "temperature", 0.7) + require.True(t, ok) + assert.Contains(t, string(result), `"temperature"`) +} + +func TestSetJSONRawBytes_InsertRawJSON(t *testing.T) { + body := []byte(`{"model":"test"}`) + raw := []byte(`[{"type":"text","text":"hello"}]`) + result, ok := setJSONRawBytes(body, "system", raw) + require.True(t, ok) + assert.Contains(t, string(result), `"system"`) +} + +func TestDeleteJSONPathBytes_RemoveField(t *testing.T) { + body := []byte(`{"model":"test","extra":"remove"}`) + result, ok := deleteJSONPathBytes(body, "extra") + require.True(t, ok) + assert.NotContains(t, string(result), `"extra"`) + assert.Contains(t, string(result), `"model"`) +} + +func TestDeleteJSONPathBytes_NonExistentField(t *testing.T) { + body := []byte(`{"model":"test"}`) + result, ok := deleteJSONPathBytes(body, "nonexistent") + // sjson returns the original if path doesn't exist + assert.Equal(t, string(body), string(result)) + _ = ok +} + +func TestMarshalAnthropicSystemTextBlock_WithCache(t *testing.T) { + result, err := marshalAnthropicSystemTextBlock("Hello world", true) + require.NoError(t, err) + assert.Contains(t, string(result), `"text":"Hello world"`) + assert.Contains(t, string(result), `"cache_control"`) +} + +func TestMarshalAnthropicSystemTextBlock_WithoutCache(t *testing.T) { + result, err := marshalAnthropicSystemTextBlock("Hello world", false) + require.NoError(t, err) + assert.Contains(t, string(result), `"text":"Hello world"`) + assert.NotContains(t, string(result), `"cache_control"`) +} + +func TestMarshalAnthropicMetadata(t *testing.T) { + result, err := marshalAnthropicMetadata("user-123") + require.NoError(t, err) + assert.Contains(t, string(result), `"user_id":"user-123"`) +} diff --git a/backend/internal/service/gateway_context_keys.go b/backend/internal/service/gateway_context_keys.go new file mode 100644 index 00000000000..c42614ea817 --- /dev/null +++ b/backend/internal/service/gateway_context_keys.go @@ -0,0 +1,22 @@ +package 
service + +import ( + "context" +) + +// ForceCacheBillingContextKey 强制缓存计费上下文键 +// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费 +type forceCacheBillingKeyType struct{} + +var ForceCacheBillingContextKey = forceCacheBillingKeyType{} + +// IsForceCacheBilling 检查是否启用强制缓存计费 +func IsForceCacheBilling(ctx context.Context) bool { + v, _ := ctx.Value(ForceCacheBillingContextKey).(bool) + return v +} + +// WithForceCacheBilling 返回带有强制缓存计费标记的上下文 +func WithForceCacheBilling(ctx context.Context) context.Context { + return context.WithValue(ctx, ForceCacheBillingContextKey, true) +} diff --git a/backend/internal/service/gateway_count_tokens.go b/backend/internal/service/gateway_count_tokens.go new file mode 100644 index 00000000000..26040c392bf --- /dev/null +++ b/backend/internal/service/gateway_count_tokens.go @@ -0,0 +1,541 @@ +package service + +import ( + "bytes" + "context" + "errors" + "fmt" + "net/http" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/tidwall/gjson" + + "github.com/gin-gonic/gin" +) + +// ForwardCountTokens 转发 count_tokens 请求到上游 API +// 特点:不记录使用量、仅支持非流式响应 +func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error { + if parsed == nil { + s.countTokensError(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return fmt.Errorf("parse request: empty request") + } + + if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() { + passthroughBody := parsed.Body + if reqModel := parsed.Model; reqModel != "" { + if mappedModel := account.GetMappedModel(reqModel); mappedModel != reqModel { + passthroughBody = s.replaceModelInBody(passthroughBody, mappedModel) + logger.LegacyPrintf("service.gateway", "CountTokens passthrough model mapping: %s -> %s (account: %s)", reqModel, mappedModel, account.Name) + } + } + return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, 
c, account, passthroughBody) + } + + // Bedrock 不支持 count_tokens 端点 + if account != nil && account.IsBedrock() { + s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported for Bedrock") + return nil + } + + body := parsed.Body + reqModel := parsed.Model + + // Pre-filter: strip empty text blocks to prevent upstream 400. + body = StripEmptyTextBlocks(body) + + isClaudeCodeCT := IsClaudeCodeClient(ctx) || isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) + shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCodeCT + + if shouldMimicClaudeCode { + normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true} + body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts) + + body = s.rewriteMessageCacheControlIfEnabled(ctx, body) + if rw := buildToolNameRewriteFromBody(body); rw != nil { + body = applyToolNameRewriteToBody(body, rw) + } else { + body = applyToolsLastCacheBreakpoint(body) + } + } + + // Antigravity 账户不支持 count_tokens,返回 404 让客户端 fallback 到本地估算。 + // 返回 nil 避免 handler 层记录为错误,也不设置 ops 上游错误上下文。 + if account.Platform == PlatformAntigravity { + s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported for this platform") + return nil + } + + // 应用模型映射: + // - APIKey 账号:使用账号级别的显式映射(如果配置),否则透传原始模型名 + // - OAuth/SetupToken 账号:使用 Anthropic 标准映射(短ID → 长ID) + if reqModel != "" { + mappedModel := reqModel + mappingSource := "" + if account.Type == AccountTypeAPIKey { + mappedModel = account.GetMappedModel(reqModel) + if mappedModel != reqModel { + mappingSource = "account" + } + } + if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { + normalized := claude.NormalizeModelID(reqModel) + if normalized != reqModel { + mappedModel = normalized + mappingSource = "prefix" + } + } + if mappedModel != reqModel { + body = s.replaceModelInBody(body, mappedModel) + reqModel = 
mappedModel + logger.LegacyPrintf("service.gateway", "CountTokens model mapping applied: %s -> %s (account: %s, source=%s)", parsed.Model, mappedModel, account.Name, mappingSource) + } + } + + // 获取凭证 + token, tokenType, err := s.GetAccessToken(ctx, account) + if err != nil { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to get access token") + return err + } + + // 构建上游请求 + upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel, shouldMimicClaudeCode) + if err != nil { + s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request") + return err + } + + // 获取代理URL(自定义 base URL 模式下,proxy 通过 buildCustomRelayURL 作为查询参数传递) + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + if !account.IsCustomBaseURLEnabled() || account.GetCustomBaseURL() == "" { + proxyURL = account.Proxy.URL() + } + } + + // 发送请求 + resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account)) + if err != nil { + setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(err.Error()), "") + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed") + return fmt.Errorf("upstream request failed: %w", err) + } + + // 读取响应体 + countTokensTooLarge := func(c *gin.Context) { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large") + } + respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, countTokensTooLarge) + _ = resp.Body.Close() + if err != nil { + if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") + } + return err + } + + // 检测 thinking block 签名错误(400)并重试一次(过滤 thinking blocks) + if resp.StatusCode == 400 && s.shouldRectifySignatureError(ctx, account, respBody) { + logger.LegacyPrintf("service.gateway", "Account %d: detected thinking block 
signature error on count_tokens, retrying with filtered thinking blocks", account.ID) + + filteredBody := FilterThinkingBlocksForRetry(body) + retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, shouldMimicClaudeCode) + if buildErr == nil { + retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account)) + if retryErr == nil { + resp = retryResp + respBody, err = ReadUpstreamResponseBody(resp.Body, s.cfg, c, countTokensTooLarge) + _ = resp.Body.Close() + if err != nil { + if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") + } + return err + } + } + } + } + + // 处理错误响应 + if resp.StatusCode >= 400 { + // 标记账号状态(429/529等) + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + + // 记录上游错误摘要便于排障(不回显请求内容) + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + logger.LegacyPrintf("service.gateway", + "count_tokens upstream error %d (account=%d platform=%s type=%s): %s", + resp.StatusCode, + account.ID, + account.Platform, + account.Type, + truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), + ) + } + + // 返回简化的错误响应 + errMsg := "Upstream request failed" + switch resp.StatusCode { + case 429: + errMsg = "Rate limit exceeded" + case 529: + errMsg = "Service overloaded" + } + s.countTokensError(c, resp.StatusCode, "upstream_error", 
errMsg) + if upstreamMsg == "" { + return fmt.Errorf("upstream error: %d", resp.StatusCode) + } + return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) + } + + // 透传成功响应 + c.Data(resp.StatusCode, "application/json", respBody) + return nil +} + +func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx context.Context, c *gin.Context, account *Account, body []byte) error { + token, tokenType, err := s.GetAccessToken(ctx, account) + if err != nil { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to get access token") + return err + } + if tokenType != "apikey" { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Invalid account token type") + return fmt.Errorf("anthropic api key passthrough requires apikey token, got: %s", tokenType) + } + + upstreamReq, err := s.buildCountTokensRequestAnthropicAPIKeyPassthrough(ctx, c, account, body, token) + if err != nil { + s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request") + return err + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account)) + if err != nil { + setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(err.Error()), "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), + Passthrough: true, + Kind: "request_error", + Message: sanitizeUpstreamErrorMessage(err.Error()), + }) + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed") + return fmt.Errorf("upstream request failed: %w", err) + } + + countTokensTooLarge := func(c *gin.Context) { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", 
"Upstream response too large") + } + respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, countTokensTooLarge) + _ = resp.Body.Close() + if err != nil { + if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") + } + return err + } + + if resp.StatusCode >= 400 { + if s.rateLimitService != nil { + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + + // 中转站不支持 count_tokens 端点时(404),返回 404 让客户端 fallback 到本地估算。 + // 仅在错误消息明确指向 count_tokens endpoint 不存在时生效,避免误吞其他 404(如错误 base_url)。 + // 返回 nil 避免 handler 层记录为错误,也不设置 ops 上游错误上下文。 + if isCountTokensUnsupported404(resp.StatusCode, respBody) { + logger.LegacyPrintf("service.gateway", + "[count_tokens] Upstream does not support count_tokens (404), returning 404: account=%d name=%s msg=%s", + account.ID, account.Name, truncateString(upstreamMsg, 512)) + s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported by upstream") + return nil + } + + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), + Passthrough: true, + Kind: "http_error", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + + errMsg := "Upstream request failed" + switch resp.StatusCode { + case 
429: + errMsg = "Rate limit exceeded" + case 529: + errMsg = "Service overloaded" + } + s.countTokensError(c, resp.StatusCode, "upstream_error", errMsg) + if upstreamMsg == "" { + return fmt.Errorf("upstream error: %d", resp.StatusCode) + } + return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) + } + + writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + contentType := strings.TrimSpace(resp.Header.Get("Content-Type")) + if contentType == "" { + contentType = "application/json" + } + c.Data(resp.StatusCode, contentType, respBody) + return nil +} + +func (s *GatewayService) buildCountTokensRequestAnthropicAPIKeyPassthrough( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + token string, +) (*http.Request, error) { + targetURL := claudeAPICountTokensURL + baseURL := account.GetBaseURL() + if baseURL != "" { + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, err + } + targetURL = validatedURL + "/v1/messages/count_tokens?beta=true" + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + if c != nil && c.Request != nil { + for key, values := range c.Request.Header { + lowerKey := strings.ToLower(strings.TrimSpace(key)) + if !allowedHeaders[lowerKey] { + continue + } + wireKey := resolveWireCasing(key) + for _, v := range values { + addHeaderRaw(req.Header, wireKey, v) + } + } + } + + req.Header.Del("authorization") + req.Header.Del("x-api-key") + req.Header.Del("x-goog-api-key") + req.Header.Del("cookie") + req.Header.Set("x-api-key", token) + + if req.Header.Get("content-type") == "" { + req.Header.Set("content-type", "application/json") + } + if req.Header.Get("anthropic-version") == "" { + req.Header.Set("anthropic-version", "2023-06-01") + } + + return req, nil +} + +// buildCountTokensRequest 构建 count_tokens 上游请求 +func (s 
*GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, mimicClaudeCode bool) (*http.Request, error) { + // 确定目标 URL + targetURL := claudeAPICountTokensURL + if account.Type == AccountTypeAPIKey { + baseURL := account.GetBaseURL() + if baseURL != "" { + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, err + } + targetURL = validatedURL + "/v1/messages/count_tokens?beta=true" + } + } else if account.IsCustomBaseURLEnabled() { + customURL := account.GetCustomBaseURL() + if customURL == "" { + return nil, fmt.Errorf("custom_base_url is enabled but not configured for account %d", account.ID) + } + validatedURL, err := s.validateUpstreamBaseURL(customURL) + if err != nil { + return nil, err + } + targetURL = s.buildCustomRelayURL(validatedURL, "/v1/messages/count_tokens", account) + } + + clientHeaders := http.Header{} + if c != nil && c.Request != nil { + clientHeaders = c.Request.Header + } + + // OAuth 账号:应用统一指纹和重写 userID(受设置开关控制) + // 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值 + ctEnableFP, ctEnableMPT, ctEnableCCH := true, false, false + if s.settingService != nil { + ctEnableFP, ctEnableMPT, ctEnableCCH = s.settingService.GetGatewayForwardingSettings(ctx) + } + var ctFingerprint *Fingerprint + if account.IsOAuth() && s.identityService != nil { + fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders) + if err == nil { + ctFingerprint = fp + if !ctEnableMPT { + accountUUID := account.GetExtraString("account_uuid") + if accountUUID != "" && fp.ClientID != "" { + if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID, fp.UserAgent); err == nil && len(newBody) > 0 { + body = newBody + } + } + } + } + } + + // 同步 billing header cc_version 与实际发送的 User-Agent 版本 + if ctFingerprint != nil && ctEnableFP { + body = syncBillingHeaderVersion(body, ctFingerprint.UserAgent) + } + if 
ctEnableCCH { + body = signBillingHeaderCCH(body) + } + + req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + // 设置认证头(保持原始大小写) + if tokenType == "oauth" { + setHeaderRaw(req.Header, "authorization", "Bearer "+token) + } else { + setHeaderRaw(req.Header, "x-api-key", token) + } + + // 白名单透传 headers(恢复真实 wire casing) + for key, values := range clientHeaders { + lowerKey := strings.ToLower(key) + if allowedHeaders[lowerKey] { + wireKey := resolveWireCasing(key) + for _, v := range values { + addHeaderRaw(req.Header, wireKey, v) + } + } + } + + // OAuth 账号:应用指纹到请求头(受设置开关控制) + if ctEnableFP && ctFingerprint != nil { + s.identityService.ApplyFingerprint(req, ctFingerprint) + } + + // 确保必要的 headers 存在(保持原始大小写) + if getHeaderRaw(req.Header, "content-type") == "" { + setHeaderRaw(req.Header, "content-type", "application/json") + } + if getHeaderRaw(req.Header, "anthropic-version") == "" { + setHeaderRaw(req.Header, "anthropic-version", "2023-06-01") + } + if tokenType == "oauth" { + applyClaudeOAuthHeaderDefaults(req) + } + + // Build effective drop set for count_tokens: merge static defaults with dynamic beta policy filter rules + ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account, modelID)) + + // OAuth 账号:处理 anthropic-beta header + if tokenType == "oauth" { + if mimicClaudeCode { + applyClaudeCodeMimicHeaders(req, false) + + incomingBeta := getHeaderRaw(req.Header, "anthropic-beta") + requiredBetas := append(claude.FullClaudeCodeMimicryBetas(), claude.BetaTokenCounting) + setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, ctEffectiveDropSet)) + } else { + clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta") + if clientBetaHeader == "" { + setHeaderRaw(req.Header, "anthropic-beta", claude.CountTokensBetaHeader) + } else { + beta := s.getBetaHeader(modelID, clientBetaHeader) + if !strings.Contains(beta, 
claude.BetaTokenCounting) { + beta = beta + "," + claude.BetaTokenCounting + } + setHeaderRaw(req.Header, "anthropic-beta", stripBetaTokensWithSet(beta, ctEffectiveDropSet)) + } + } + } else { + // API-key accounts: apply beta policy filter to strip controlled tokens + if existingBeta := getHeaderRaw(req.Header, "anthropic-beta"); existingBeta != "" { + setHeaderRaw(req.Header, "anthropic-beta", stripBetaTokensWithSet(existingBeta, ctEffectiveDropSet)) + } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey { + // API-key:与 messages 同步的按需 beta 注入(默认关闭) + if requestNeedsBetaFeatures(body) { + if beta := defaultAPIKeyBetaHeader(body); beta != "" { + setHeaderRaw(req.Header, "anthropic-beta", beta) + } + } + } + } + + // 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖 + if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" { + if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" { + if parsed := ParseMetadataUserID(uid); parsed != nil { + setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID) + } + } + } + + if c != nil && tokenType == "oauth" { + c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode)) + } + if s.debugClaudeMimicEnabled() { + logClaudeMimicDebug(req, body, account, tokenType, mimicClaudeCode) + } + + return req, nil +} + +// countTokensError 返回 count_tokens 错误响应 +func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, message string) { + c.JSON(status, gin.H{ + "type": "error", + "error": gin.H{ + "type": errType, + "message": message, + }, + }) +} diff --git a/backend/internal/service/gateway_debug.go b/backend/internal/service/gateway_debug.go new file mode 100644 index 00000000000..3bdc8caa5db --- /dev/null +++ b/backend/internal/service/gateway_debug.go @@ -0,0 +1,196 @@ +package service + +import ( + "bytes" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "os" + 
"path/filepath" + "sort" + "strings" + "time" + + "github.com/tidwall/gjson" +) + +func (s *GatewayService) debugModelRoutingEnabled() bool { + if s == nil { + return false + } + return s.debugModelRouting.Load() +} + +func (s *GatewayService) debugClaudeMimicEnabled() bool { + if s == nil { + return false + } + return s.debugClaudeMimic.Load() +} + +func parseDebugEnvBool(raw string) bool { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "1", "true", "yes", "on": + return true + default: + return false + } +} + +func redactAuthHeaderValue(v string) string { + v = strings.TrimSpace(v) + if v == "" { + return "" + } + // Keep scheme for debugging, redact secret. + if strings.HasPrefix(strings.ToLower(v), "bearer ") { + return "Bearer [redacted]" + } + return "[redacted]" +} + +func safeHeaderValueForLog(key string, v string) string { + key = strings.ToLower(strings.TrimSpace(key)) + switch key { + case "authorization", "x-api-key": + return redactAuthHeaderValue(v) + default: + return strings.TrimSpace(v) + } +} + +func extractSystemPreviewFromBody(body []byte) string { + if len(body) == 0 { + return "" + } + sys := gjson.GetBytes(body, "system") + if !sys.Exists() { + return "" + } + + switch { + case sys.IsArray(): + for _, item := range sys.Array() { + if !item.IsObject() { + continue + } + if strings.EqualFold(item.Get("type").String(), "text") { + if t := item.Get("text").String(); strings.TrimSpace(t) != "" { + return t + } + } + } + return "" + case sys.Type == gjson.String: + return sys.String() + default: + return "" + } +} + +func truncateForLog(b []byte, maxBytes int) string { + if maxBytes <= 0 { + maxBytes = 2048 + } + if len(b) > maxBytes { + b = b[:maxBytes] + } + s := string(b) + // 保持一行,避免污染日志格式 + s = strings.ReplaceAll(s, "\n", "\\n") + s = strings.ReplaceAll(s, "\r", "\\r") + return s +} + +const debugGatewayBodyDefaultFilename = "gateway_debug.log" + +// initDebugGatewayBodyFile 初始化网关调试日志文件。 +// +// - "1"/"true" 等布尔值 → 当前目录下 
gateway_debug.log +// - 已有目录路径 → 该目录下 gateway_debug.log +// - 其他 → 视为完整文件路径 +func (s *GatewayService) initDebugGatewayBodyFile(path string) { + if parseDebugEnvBool(path) { + path = debugGatewayBodyDefaultFilename + } + + // 如果 path 指向一个已存在的目录,自动追加默认文件名 + if info, err := os.Stat(path); err == nil && info.IsDir() { + path = filepath.Join(path, debugGatewayBodyDefaultFilename) + } + + // 确保父目录存在 + if dir := filepath.Dir(path); dir != "." { + if err := os.MkdirAll(dir, 0755); err != nil { + slog.Error("failed to create gateway debug log directory", "dir", dir, "error", err) + return + } + } + + f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) + if err != nil { + slog.Error("failed to open gateway debug log file", "path", path, "error", err) + return + } + s.debugGatewayBodyFile.Store(f) + slog.Info("gateway debug logging enabled", "path", path) +} + +// debugLogGatewaySnapshot 将网关请求的完整快照(headers + body)写入独立的调试日志文件, +// 用于对比客户端原始请求和上游转发请求。 +// +// 启用方式(环境变量): +// +// SUB2API_DEBUG_GATEWAY_BODY=1 # 写入 gateway_debug.log +// SUB2API_DEBUG_GATEWAY_BODY=/tmp/gateway_debug.log # 写入指定路径 +// +// tag: "CLIENT_ORIGINAL" 或 "UPSTREAM_FORWARD" +func (s *GatewayService) debugLogGatewaySnapshot(tag string, headers http.Header, body []byte, extra map[string]string) { + f := s.debugGatewayBodyFile.Load() + if f == nil { + return + } + + var buf strings.Builder + ts := time.Now().Format("2006-01-02 15:04:05.000") + fmt.Fprintf(&buf, "\n========== [%s] %s ==========\n", ts, tag) + + // 1. context + if len(extra) > 0 { + fmt.Fprint(&buf, "--- context ---\n") + extraKeys := make([]string, 0, len(extra)) + for k := range extra { + extraKeys = append(extraKeys, k) + } + sort.Strings(extraKeys) + for _, k := range extraKeys { + fmt.Fprintf(&buf, " %s: %s\n", k, extra[k]) + } + } + + // 2. 
headers(按真实 Claude CLI wire 顺序排列,便于与抓包对比;auth 脱敏) + fmt.Fprint(&buf, "--- headers ---\n") + for _, k := range sortHeadersByWireOrder(headers) { + for _, v := range headers[k] { + fmt.Fprintf(&buf, " %s: %s\n", k, safeHeaderValueForLog(k, v)) + } + } + + // 3. body(完整输出,格式化 JSON 便于 diff) + fmt.Fprint(&buf, "--- body ---\n") + if len(body) == 0 { + fmt.Fprint(&buf, " (empty)\n") + } else { + var pretty bytes.Buffer + if json.Indent(&pretty, body, " ", " ") == nil { + fmt.Fprintf(&buf, " %s\n", pretty.Bytes()) + } else { + // JSON 格式化失败时原样输出 + fmt.Fprintf(&buf, " %s\n", body) + } + } + + // 写入文件(调试用,并发写入可能交错但不影响可读性) + _, _ = f.WriteString(buf.String()) +} diff --git a/backend/internal/service/gateway_error_handling.go b/backend/internal/service/gateway_error_handling.go new file mode 100644 index 00000000000..17c5e22fdcb --- /dev/null +++ b/backend/internal/service/gateway_error_handling.go @@ -0,0 +1,466 @@ +package service + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/tidwall/gjson" + + "github.com/gin-gonic/gin" +) + +// shouldRectifySignatureError 统一判断是否应触发签名整流(strip thinking blocks 并重试)。 +// 根据账号类型检查对应的开关和匹配模式。 +func (s *GatewayService) shouldRectifySignatureError(ctx context.Context, account *Account, respBody []byte) bool { + if account.Type == AccountTypeAPIKey { + // API Key 账号:独立开关,一次读取配置 + settings, err := s.settingService.GetRectifierSettings(ctx) + if err != nil || !settings.Enabled || !settings.APIKeySignatureEnabled { + return false + } + // 先检查内置模式(同 OAuth),再检查自定义关键词 + if s.isThinkingBlockSignatureError(respBody) { + return true + } + return matchSignaturePatterns(respBody, settings.APIKeySignaturePatterns) + } + // OAuth/SetupToken/Upstream/Bedrock 等:保持原有行为(内置模式 + 原开关) + return s.isThinkingBlockSignatureError(respBody) && s.settingService.IsSignatureRectifierEnabled(ctx) +} + +// isSignatureErrorPattern 仅做模式匹配,不检查开关。 +// 用于已进入重试流程后的二阶段检测(此时开关已在首次调用时验证过)。 
+func (s *GatewayService) isSignatureErrorPattern(ctx context.Context, account *Account, respBody []byte) bool { + if s.isThinkingBlockSignatureError(respBody) { + return true + } + if account.Type == AccountTypeAPIKey { + settings, err := s.settingService.GetRectifierSettings(ctx) + if err != nil { + return false + } + return matchSignaturePatterns(respBody, settings.APIKeySignaturePatterns) + } + return false +} + +// matchSignaturePatterns 检查响应体是否匹配自定义关键词列表(不区分大小写)。 +func matchSignaturePatterns(respBody []byte, patterns []string) bool { + if len(patterns) == 0 { + return false + } + bodyLower := strings.ToLower(string(respBody)) + for _, p := range patterns { + p = strings.TrimSpace(p) + if p == "" { + continue + } + if strings.Contains(bodyLower, strings.ToLower(p)) { + return true + } + } + return false +} + +// isThinkingBlockSignatureError 检测是否是thinking block相关错误 +// 这类错误可以通过过滤thinking blocks并重试来解决 +func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool { + msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) + if msg == "" { + return false + } + + // 检测signature相关的错误(更宽松的匹配) + // 例如: "Invalid `signature` in `thinking` block", "***.signature" 等 + if strings.Contains(msg, "signature") { + return true + } + + // 检测 thinking block 顺序/类型错误 + // 例如: "Expected `thinking` or `redacted_thinking`, but found `text`" + if strings.Contains(msg, "expected") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) { + logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected thinking block type error") + return true + } + + // 检测 thinking block 被修改的错误 + // 例如: "thinking or redacted_thinking blocks in the latest assistant message cannot be modified" + if strings.Contains(msg, "cannot be modified") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) { + logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected thinking block modification error") + 
return true + } + + // 检测空消息内容错误(可能是过滤 thinking blocks 后导致的,或客户端发送了空 text block) + // 例如: "all messages must have non-empty content" + // "messages: text content blocks must be non-empty" + if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") || + strings.Contains(msg, "content blocks must be non-empty") { + logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected empty content error") + return true + } + + return false +} + +func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool { + // 只对"可能是兼容性差异导致"的 400 允许切换,避免无意义重试。 + // 默认保守:无法识别则不切换。 + msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) + if msg == "" { + return false + } + + // 缺少/错误的 beta header:换账号/链路可能成功(尤其是混合调度时)。 + // 更精确匹配 beta 相关的兼容性问题,避免误触发切换。 + if strings.Contains(msg, "anthropic-beta") || + strings.Contains(msg, "beta feature") || + strings.Contains(msg, "requires beta") { + return true + } + + // thinking/tool streaming 等兼容性约束(常见于中间转换链路) + if strings.Contains(msg, "thinking") || strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") { + return true + } + if strings.Contains(msg, "tool_use") || strings.Contains(msg, "tool_result") || strings.Contains(msg, "tools") { + return true + } + + return false +} + +// ExtractUpstreamErrorMessage 从上游响应体中提取错误消息 +// 支持 Claude 风格的错误格式:{"type":"error","error":{"type":"...","message":"..."}} +func ExtractUpstreamErrorMessage(body []byte) string { + return extractUpstreamErrorMessage(body) +} + +func extractUpstreamErrorMessage(body []byte) string { + // Claude 风格:{"type":"error","error":{"type":"...","message":"..."}} + if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" { + inner := strings.TrimSpace(m) + // 有些上游会把完整 JSON 作为字符串塞进 message + if strings.HasPrefix(inner, "{") { + if innerMsg := gjson.Get(inner, "error.message").String(); strings.TrimSpace(innerMsg) != "" { + return innerMsg + } + } + return m + } + + // 
ChatGPT 内部 API 风格:{"detail":"..."} + if d := gjson.GetBytes(body, "detail").String(); strings.TrimSpace(d) != "" { + return d + } + + // 兜底:尝试顶层 message + return gjson.GetBytes(body, "message").String() +} + +func extractUpstreamErrorCode(body []byte) string { + if code := strings.TrimSpace(gjson.GetBytes(body, "error.code").String()); code != "" { + return code + } + + inner := strings.TrimSpace(gjson.GetBytes(body, "error.message").String()) + if !strings.HasPrefix(inner, "{") { + return "" + } + + if code := strings.TrimSpace(gjson.Get(inner, "error.code").String()); code != "" { + return code + } + + if lastBrace := strings.LastIndex(inner, "}"); lastBrace >= 0 { + if code := strings.TrimSpace(gjson.Get(inner[:lastBrace+1], "error.code").String()); code != "" { + return code + } + } + + return "" +} + +func isCountTokensUnsupported404(statusCode int, body []byte) bool { + if statusCode != http.StatusNotFound { + return false + } + msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(body))) + if msg == "" { + return false + } + if strings.Contains(msg, "/v1/messages/count_tokens") { + return true + } + return strings.Contains(msg, "count_tokens") && strings.Contains(msg, "not found") +} + +func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + + // 调试日志:打印上游错误响应 + logger.LegacyPrintf("service.gateway", "[Forward] Upstream error (non-retryable): Account=%d(%s) Status=%d RequestID=%s Body=%s", + account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(body), 1000)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + + // Print a compact upstream request fingerprint when we hit the Claude Code OAuth + // credential scope error. 
This avoids requiring env-var tweaks in a fixed deploy. + if isClaudeCodeCredentialScopeError(upstreamMsg) && c != nil { + if v, ok := c.Get(claudeMimicDebugInfoKey); ok { + if line, ok := v.(string); ok && strings.TrimSpace(line) != "" { + logger.LegacyPrintf("service.gateway", "[ClaudeMimicDebugOnError] status=%d request_id=%s %s", + resp.StatusCode, + resp.Header.Get("x-request-id"), + line, + ) + } + } + } + + // Enrich Ops error logs with upstream status + message, and optionally a truncated body snippet. + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(body), maxBytes) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "http_error", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + + // 处理上游错误,标记账号状态 + shouldDisable := false + if s.rateLimitService != nil { + shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) + } + if shouldDisable { + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body} + } + + // 记录上游错误响应体摘要便于排障(可选:由配置控制;不回显到客户端) + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + logger.LegacyPrintf("service.gateway", + "Upstream error %d (account=%d platform=%s type=%s): %s", + resp.StatusCode, + account.ID, + account.Platform, + account.Type, + truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), + ) + } + + // 非 failover 错误也支持错误透传规则匹配。 + if status, errType, errMsg, matched := applyErrorPassthroughRule( + c, + account.Platform, + resp.StatusCode, + body, + http.StatusBadGateway, + "upstream_error", + "Upstream request failed", + ); matched { + 
c.JSON(status, gin.H{ + "type": "error", + "error": gin.H{ + "type": errType, + "message": errMsg, + }, + }) + + summary := upstreamMsg + if summary == "" { + summary = errMsg + } + if summary == "" { + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, summary) + } + + // 根据状态码返回适当的自定义错误响应(不透传上游详细信息) + var errType, errMsg string + var statusCode int + + switch resp.StatusCode { + case 400: + c.Data(http.StatusBadRequest, "application/json", body) + summary := upstreamMsg + if summary == "" { + summary = truncateForLog(body, 512) + } + if summary == "" { + return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, summary) + case 401: + statusCode = http.StatusBadGateway + errType = "upstream_error" + errMsg = "Upstream authentication failed, please contact administrator" + case 403: + statusCode = http.StatusBadGateway + errType = "upstream_error" + errMsg = "Upstream access forbidden, please contact administrator" + case 429: + statusCode = http.StatusTooManyRequests + errType = "rate_limit_error" + errMsg = "Upstream rate limit exceeded, please retry later" + case 529: + statusCode = http.StatusServiceUnavailable + errType = "overloaded_error" + errMsg = "Upstream service overloaded, please retry later" + case 500, 502, 503, 504: + statusCode = http.StatusBadGateway + errType = "upstream_error" + errMsg = "Upstream service temporarily unavailable" + default: + statusCode = http.StatusBadGateway + errType = "upstream_error" + errMsg = "Upstream request failed" + } + + // 返回自定义错误响应 + c.JSON(statusCode, gin.H{ + "type": "error", + "error": gin.H{ + "type": errType, + "message": errMsg, + }, + }) + + if upstreamMsg == "" { + return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d 
message=%s", resp.StatusCode, upstreamMsg) +} + +func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, resp *http.Response, account *Account) { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + statusCode := resp.StatusCode + + // OAuth/Setup Token 账号的 403:标记账号异常 + if account.IsOAuth() && statusCode == 403 { + s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, resp.Header, body) + logger.LegacyPrintf("service.gateway", "Account %d: marked as error after %d retries for status %d", account.ID, maxRetryAttempts, statusCode) + } else { + // API Key 未配置错误码:不标记账号状态 + logger.LegacyPrintf("service.gateway", "Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetryAttempts) + } +} + +func (s *GatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) +} + +// handleRetryExhaustedError 处理重试耗尽后的错误 +// OAuth 403:标记账号异常 +// API Key 未配置错误码:仅返回错误,不标记账号 +func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { + // Capture upstream error body before side-effects consume the stream. 
+ respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + s.handleRetryExhaustedSideEffects(ctx, resp, account) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + + if isClaudeCodeCredentialScopeError(upstreamMsg) && c != nil { + if v, ok := c.Get(claudeMimicDebugInfoKey); ok { + if line, ok := v.(string); ok && strings.TrimSpace(line) != "" { + logger.LegacyPrintf("service.gateway", "[ClaudeMimicDebugOnError] status=%d request_id=%s %s", + resp.StatusCode, + resp.Header.Get("x-request-id"), + line, + ) + } + } + } + + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "retry_exhausted", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + logger.LegacyPrintf("service.gateway", + "Upstream error %d retries_exhausted (account=%d platform=%s type=%s): %s", + resp.StatusCode, + account.ID, + account.Platform, + account.Type, + truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), + ) + } + + if status, errType, errMsg, matched := applyErrorPassthroughRule( + c, + account.Platform, + resp.StatusCode, + respBody, + http.StatusBadGateway, + "upstream_error", + "Upstream request failed after retries", + ); matched { + c.JSON(status, gin.H{ + "type": "error", + "error": gin.H{ + "type": errType, + "message": errMsg, + }, + }) + + summary := upstreamMsg + if 
summary == "" { + summary = errMsg + } + if summary == "" { + return nil, fmt.Errorf("upstream error: %d (retries exhausted, passthrough rule matched)", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d (retries exhausted, passthrough rule matched) message=%s", resp.StatusCode, summary) + } + + // 返回统一的重试耗尽错误响应 + c.JSON(http.StatusBadGateway, gin.H{ + "type": "error", + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream request failed after retries", + }, + }) + + if upstreamMsg == "" { + return nil, fmt.Errorf("upstream error: %d (retries exhausted)", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d (retries exhausted) message=%s", resp.StatusCode, upstreamMsg) +} diff --git a/backend/internal/service/gateway_error_handling_test.go b/backend/internal/service/gateway_error_handling_test.go new file mode 100644 index 00000000000..611ab2a437c --- /dev/null +++ b/backend/internal/service/gateway_error_handling_test.go @@ -0,0 +1,139 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMatchSignaturePatterns_EmptyPatterns(t *testing.T) { + body := []byte(`{"error":{"message":"signature error"}}`) + assert.False(t, matchSignaturePatterns(body, nil)) + assert.False(t, matchSignaturePatterns(body, []string{})) +} + +func TestMatchSignaturePatterns_MatchesCaseInsensitive(t *testing.T) { + body := []byte(`{"error":{"message":"Invalid Signature detected in response"}}`) + assert.True(t, matchSignaturePatterns(body, []string{"invalid signature"})) + assert.True(t, matchSignaturePatterns(body, []string{"INVALID SIGNATURE"})) +} + +func TestMatchSignaturePatterns_NoMatch(t *testing.T) { + body := []byte(`{"error":{"message":"rate limit exceeded"}}`) + assert.False(t, matchSignaturePatterns(body, []string{"signature", "thinking block"})) +} + +func TestMatchSignaturePatterns_MultiplePatterns(t *testing.T) { + body := 
[]byte(`{"error":{"message":"thinking block order violation"}}`) + assert.True(t, matchSignaturePatterns(body, []string{"signature", "thinking block order"})) +} + +func TestMatchSignaturePatterns_EmptyBody(t *testing.T) { + assert.False(t, matchSignaturePatterns(nil, []string{"error"})) + assert.False(t, matchSignaturePatterns([]byte{}, []string{"error"})) +} + +func TestIsThinkingBlockSignatureError_SignatureError(t *testing.T) { + svc := &GatewayService{} + body := []byte(`{"error":{"message":"could not validate the signature on one of the thinking blocks"}}`) + assert.True(t, svc.isThinkingBlockSignatureError(body)) +} + +func TestIsThinkingBlockSignatureError_ThinkingBlockTypeError(t *testing.T) { + svc := &GatewayService{} + // "Expected `thinking` or `redacted_thinking`, but found `text`" + body := []byte(`{"error":{"message":"Expected thinking or redacted_thinking, but found text"}}`) + assert.True(t, svc.isThinkingBlockSignatureError(body)) +} + +func TestIsThinkingBlockSignatureError_NoMatch(t *testing.T) { + svc := &GatewayService{} + body := []byte(`{"error":{"message":"rate limit exceeded"}}`) + assert.False(t, svc.isThinkingBlockSignatureError(body)) +} + +func TestIsThinkingBlockSignatureError_EmptyBody(t *testing.T) { + svc := &GatewayService{} + assert.False(t, svc.isThinkingBlockSignatureError(nil)) + assert.False(t, svc.isThinkingBlockSignatureError([]byte{})) +} + +func TestShouldFailoverOn400_BetaRequired(t *testing.T) { + svc := &GatewayService{} + body := []byte(`{"error":{"message":"header 'anthropic-beta' must include 'interleaved-thinking-2025-05-14'"}}`) + assert.True(t, svc.shouldFailoverOn400(body)) +} + +func TestShouldFailoverOn400_ThinkingToolIncompatible(t *testing.T) { + svc := &GatewayService{} + body := []byte(`{"error":{"message":"thinking is not supported with tool_choice"}}`) + assert.True(t, svc.shouldFailoverOn400(body)) +} + +func TestShouldFailoverOn400_NormalError(t *testing.T) { + svc := &GatewayService{} + body := 
[]byte(`{"error":{"message":"invalid model specified"}}`) + assert.False(t, svc.shouldFailoverOn400(body)) +} + +func TestExtractUpstreamErrorMessage_ClaudeFormat(t *testing.T) { + body := []byte(`{"error":{"type":"invalid_request_error","message":"model not found"}}`) + assert.Equal(t, "model not found", ExtractUpstreamErrorMessage(body)) +} + +func TestExtractUpstreamErrorMessage_OpenAIFormat(t *testing.T) { + body := []byte(`{"error":{"message":"Rate limit reached","type":"rate_limit_error"}}`) + assert.Equal(t, "Rate limit reached", ExtractUpstreamErrorMessage(body)) +} + +func TestExtractUpstreamErrorMessage_EmptyBody(t *testing.T) { + assert.Equal(t, "", ExtractUpstreamErrorMessage(nil)) + assert.Equal(t, "", ExtractUpstreamErrorMessage([]byte{})) +} + +func TestExtractUpstreamErrorMessage_MalformedJSON(t *testing.T) { + assert.Equal(t, "", ExtractUpstreamErrorMessage([]byte(`not json`))) +} + +func TestExtractUpstreamErrorCode_FromCodeField(t *testing.T) { + body := []byte(`{"error":{"code":"overloaded_error","message":"overloaded"}}`) + require.Equal(t, "overloaded_error", extractUpstreamErrorCode(body)) +} + +func TestExtractUpstreamErrorCode_NoCodeField(t *testing.T) { + body := []byte(`{"error":{"type":"overloaded_error","message":"overloaded"}}`) + assert.Equal(t, "", extractUpstreamErrorCode(body)) +} + +func TestExtractUpstreamErrorCode_NestedJSON(t *testing.T) { + body := []byte(`{"error":{"message":"{\"error\":{\"type\":\"inner_error\"}}"}}`) + code := extractUpstreamErrorCode(body) + assert.Contains(t, []string{"inner_error", ""}, code) +} + +func TestExtractUpstreamErrorCode_NoError(t *testing.T) { + body := []byte(`{"result":"ok"}`) + assert.Equal(t, "", extractUpstreamErrorCode(body)) +} + +func TestIsCountTokensUnsupported404_Extended(t *testing.T) { + tests := []struct { + name string + status int + body string + want bool + }{ + {"404 with count_tokens not found", 404, `{"error":{"message":"count_tokens endpoint not found"}}`, true}, + {"404 
without matching type", 404, `{"error":{"type":"invalid_request"}}`, false}, + {"200 with not_found_error", 200, `{"error":{"type":"not_found_error"}}`, false}, + {"404 empty body", 404, `{}`, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, isCountTokensUnsupported404(tt.status, []byte(tt.body))) + }) + } +} diff --git a/backend/internal/service/gateway_errors.go b/backend/internal/service/gateway_errors.go new file mode 100644 index 00000000000..a863d4bc92e --- /dev/null +++ b/backend/internal/service/gateway_errors.go @@ -0,0 +1,37 @@ +package service + +import ( + "context" + "errors" + "net/http" +) + +// ErrNoAvailableAccounts 表示没有可用的账号 +var ErrNoAvailableAccounts = errors.New("no available accounts") + +// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问 +var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients") + +// UpstreamFailoverError indicates an upstream error that should trigger account failover. +type UpstreamFailoverError struct { + StatusCode int + ResponseBody []byte // 上游响应体,用于错误透传规则匹配 + ResponseHeaders http.Header // 上游响应头,用于透传 cf-ray/cf-mitigated/content-type 等诊断信息 + ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true + RetryableOnSameAccount bool // 临时性错误(如 Google 间歇性 400、空响应),应在同一账号上重试 N 次再切换 +} + +// TempUnscheduleRetryableError 对 RetryableOnSameAccount 类型的 failover 错误触发临时封禁。 +// 由 handler 层在同账号重试全部用尽、切换账号时调用。 +func (s *GatewayService) TempUnscheduleRetryableError(ctx context.Context, accountID int64, failoverErr *UpstreamFailoverError) { + if failoverErr == nil || !failoverErr.RetryableOnSameAccount { + return + } + // 根据状态码选择封禁策略 + switch failoverErr.StatusCode { + case http.StatusBadRequest: + tempUnscheduleGoogleConfigError(ctx, s.accountRepo, accountID, "[handler]") + case http.StatusBadGateway: + tempUnscheduleEmptyResponse(ctx, s.accountRepo, accountID, "[handler]") + } +} diff --git a/backend/internal/service/gateway_forward_helpers.go 
b/backend/internal/service/gateway_forward_helpers.go new file mode 100644 index 00000000000..10f59b35cf5 --- /dev/null +++ b/backend/internal/service/gateway_forward_helpers.go @@ -0,0 +1,254 @@ +package service + +import ( + "context" + "fmt" + "net/http" + "strconv" + "strings" + "time" + + "github.com/gin-gonic/gin" + + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" +) + +// ───────────────────────────────────────────────────────────────────────────── +// Forward 函数的子方法:将请求转发流程按阶段组织 +// ───────────────────────────────────────────────────────────────────────────── + +// forwardContext 封装 Forward 流程中各阶段共享的中间状态。 +type forwardContext struct { + body []byte // 转换后的请求体 + originalModel string // 原始模型名(用于计费和日志) + mappedModel string // 映射后的模型名(发送到上游) + reqModel string // 当前使用的模型名 + reqStream bool // 是否流式请求 + mimicClaudeCode bool // 是否启用 Claude Code 伪装 + token string // 认证凭证 + tokenType string // 凭证类型 (oauth/apikey/setup-token) + proxyURL string // 代理 URL + tlsProfile *tlsfingerprint.Profile // TLS 指纹配置 + startTime time.Time +} + +// checkForwardEarlyRoutes 检查是否命中快速退出路径(web search、passthrough、bedrock)。 +// 返回非nil结果表示已完成转发,主函数应直接返回。 +func (s *GatewayService) checkForwardEarlyRoutes(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest, startTime time.Time) (*ForwardResult, error) { + if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.GroupID, parsed.Body) { + return s.handleWebSearchEmulation(ctx, c, account, parsed) + } + + if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() { + passthroughBody := parsed.Body + passthroughModel := parsed.Model + if passthroughModel != "" { + if mappedModel := account.GetMappedModel(passthroughModel); mappedModel != passthroughModel { + passthroughBody = s.replaceModelInBody(passthroughBody, mappedModel) + logger.LegacyPrintf("service.gateway", "Passthrough model mapping: 
%s -> %s (account: %s)", parsed.Model, mappedModel, account.Name) + passthroughModel = mappedModel + } + } + return s.forwardAnthropicAPIKeyPassthroughWithInput(ctx, c, account, anthropicPassthroughForwardInput{ + Body: passthroughBody, + RequestModel: passthroughModel, + OriginalModel: parsed.Model, + RequestStream: parsed.Stream, + StartTime: startTime, + }) + } + + if account != nil && account.IsBedrock() { + return s.forwardBedrock(ctx, c, account, parsed, startTime) + } + + return nil, nil +} + +// prepareForwardBody 执行请求体转换并获取凭证。 +// 包括:Beta 策略评估、Claude Code 伪装、模型映射、cache 控制、凭证获取。 +func (s *GatewayService) prepareForwardBody(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest, startTime time.Time) (*forwardContext, error) { + // Beta 策略评估 + if account.Platform == PlatformAnthropic && c != nil { + policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account, parsed.Model) + if policy.blockErr != nil { + return nil, policy.blockErr + } + filterSet := policy.filterSet + if filterSet == nil { + filterSet = map[string]struct{}{} + } + c.Set(betaPolicyFilterSetKey, filterSet) + } + + body := parsed.Body + reqModel := parsed.Model + reqStream := parsed.Stream + originalModel := reqModel + + if c != nil { + s.debugLogGatewaySnapshot("CLIENT_ORIGINAL", c.Request.Header, body, map[string]string{ + "account": fmt.Sprintf("%d(%s)", account.ID, account.Name), + "account_type": string(account.Type), + "model": reqModel, + "stream": strconv.FormatBool(reqStream), + }) + } + + // Claude Code 伪装判定 + isClaudeCode := IsClaudeCodeClient(ctx) || isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) + shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode + + if shouldMimicClaudeCode { + systemRewritten := false + if !strings.Contains(strings.ToLower(reqModel), "haiku") { + body = rewriteSystemForNonClaudeCode(body, parsed.System) + systemRewritten = true + } + + normalizeOpts := 
claudeOAuthNormalizeOptions{stripSystemCacheControl: !systemRewritten} + if s.identityService != nil { + fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) + if err == nil && fp != nil { + _, mimicMPT, _ := s.settingService.GetGatewayForwardingSettings(ctx) + if !mimicMPT { + if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" { + normalizeOpts.injectMetadata = true + normalizeOpts.metadataUserID = metadataUserID + } + } + } + } + + body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts) + + body = s.rewriteMessageCacheControlIfEnabled(ctx, body) + if rw := buildToolNameRewriteFromBody(body); rw != nil { + body = applyToolNameRewriteToBody(body, rw) + c.Set(toolNameRewriteKey, rw) + } else { + body = applyToolsLastCacheBreakpoint(body) + } + } + + body = enforceCacheControlLimit(body) + + // 模型映射 + mappedModel := reqModel + mappingSource := "" + if account.Type == AccountTypeAPIKey { + mappedModel = account.GetMappedModel(reqModel) + if mappedModel != reqModel { + mappingSource = "account" + } + } + if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount { + if candidate, matched := account.ResolveMappedModel(reqModel); matched { + mappedModel = candidate + mappingSource = "account" + } else { + normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(reqModel)) + if normalized != reqModel { + mappedModel = normalized + mappingSource = "vertex" + } + } + } + if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { + normalized := claude.NormalizeModelID(reqModel) + if normalized != reqModel { + mappedModel = normalized + mappingSource = "prefix" + } + } + if mappedModel != reqModel { + body = s.replaceModelInBody(body, mappedModel) + reqModel = mappedModel + logger.LegacyPrintf("service.gateway", "Model mapping applied: %s -> %s (account: %s, source=%s)", 
originalModel, mappedModel, account.Name, mappingSource) + } + + if s.shouldInjectAnthropicCacheTTL1h(ctx, account) { + body = injectAnthropicCacheControlTTL1h(body) + } + + // 获取凭证 + token, tokenType, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, err + } + + // 解析代理 + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + if !account.IsCustomBaseURLEnabled() || account.GetCustomBaseURL() == "" { + proxyURL = account.Proxy.URL() + } + } + + tlsProfile := s.tlsFPProfileService.ResolveTLSProfile(account) + + logger.LegacyPrintf("service.gateway", "[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s", + account.ID, account.Name, account.Platform, account.Type, tlsProfile, proxyURL) + + body = StripEmptyTextBlocks(body) + setOpsUpstreamRequestBody(c, body) + + return &forwardContext{ + body: body, + originalModel: originalModel, + mappedModel: mappedModel, + reqModel: reqModel, + reqStream: reqStream, + mimicClaudeCode: shouldMimicClaudeCode, + token: token, + tokenType: tokenType, + proxyURL: proxyURL, + tlsProfile: tlsProfile, + startTime: startTime, + }, nil +} + +// processForwardResponse 处理上游响应并构建最终结果。 +// 处理流式/非流式响应,构建 ForwardResult。 +func (s *GatewayService) processForwardResponse(ctx context.Context, c *gin.Context, account *Account, resp *http.Response, fc *forwardContext, parsed *ParsedRequest) (*ForwardResult, error) { + // 触发上游接受回调 + if parsed.OnUpstreamAccepted != nil { + parsed.OnUpstreamAccepted() + } + + var usage *ClaudeUsage + var firstTokenMs *int + var clientDisconnect bool + var err error + + if fc.reqStream { + streamResult, streamErr := s.handleStreamingResponse(ctx, resp, c, account, fc.startTime, fc.originalModel, fc.reqModel, fc.mimicClaudeCode) + if streamErr != nil { + if streamErr.Error() == "have error in stream" { + return nil, &UpstreamFailoverError{StatusCode: 403} + } + return nil, streamErr + } + usage = streamResult.usage + firstTokenMs = 
streamResult.firstTokenMs + clientDisconnect = streamResult.clientDisconnect + } else { + usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, fc.originalModel, fc.reqModel) + if err != nil { + return nil, err + } + } + + return &ForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: *usage, + Model: fc.originalModel, + UpstreamModel: fc.mappedModel, + Stream: fc.reqStream, + Duration: time.Since(fc.startTime), + FirstTokenMs: firstTokenMs, + ClientDisconnect: clientDisconnect, + }, nil +} diff --git a/backend/internal/service/gateway_group_rate.go b/backend/internal/service/gateway_group_rate.go new file mode 100644 index 00000000000..75d8a7621c6 --- /dev/null +++ b/backend/internal/service/gateway_group_rate.go @@ -0,0 +1,154 @@ +package service + +import ( + "context" + "fmt" + "sort" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +func resolveUserGroupRateCacheTTL(cfg *config.Config) time.Duration { + if cfg == nil || cfg.Gateway.UserGroupRateCacheTTLSeconds <= 0 { + return defaultUserGroupRateCacheTTL + } + return time.Duration(cfg.Gateway.UserGroupRateCacheTTLSeconds) * time.Second +} + +func resolveModelsListCacheTTL(cfg *config.Config) time.Duration { + if cfg == nil || cfg.Gateway.ModelsListCacheTTLSeconds <= 0 { + return defaultModelsListCacheTTL + } + return time.Duration(cfg.Gateway.ModelsListCacheTTLSeconds) * time.Second +} + +func modelsListCacheKey(groupID *int64, platform string) string { + return fmt.Sprintf("%d|%s", derefGroupID(groupID), strings.TrimSpace(platform)) +} + +func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 { + if s == nil { + return groupDefaultMultiplier + } + resolver := s.userGroupRateResolver + if resolver == nil { + resolver = newUserGroupRateResolver( + s.userGroupRateRepo, + s.userGroupRateCache, + resolveUserGroupRateCacheTTL(s.cfg), + &s.userGroupRateSF, + 
"service.gateway", + ) + } + return resolver.Resolve(ctx, userID, groupID, groupDefaultMultiplier) +} + +// GetAvailableModels returns the list of models available for a group +// It aggregates model_mapping keys from all schedulable accounts in the group +func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string { + cacheKey := modelsListCacheKey(groupID, platform) + if s.modelsListCache != nil { + if cached, found := s.modelsListCache.Get(cacheKey); found { + if models, ok := cached.([]string); ok { + modelsListCacheHitTotal.Add(1) + return cloneStringSlice(models) + } + } + } + modelsListCacheMissTotal.Add(1) + + var accounts []Account + var err error + + if groupID != nil { + accounts, err = s.accountRepo.ListSchedulableByGroupID(ctx, *groupID) + } else { + accounts, err = s.accountRepo.ListSchedulable(ctx) + } + + if err != nil || len(accounts) == 0 { + return nil + } + + // Filter by platform if specified + if platform != "" { + filtered := make([]Account, 0) + for _, acc := range accounts { + if acc.Platform == platform { + filtered = append(filtered, acc) + } + } + accounts = filtered + } + + // Collect unique models from all accounts + modelSet := make(map[string]struct{}) + hasAnyMapping := false + + for _, acc := range accounts { + mapping := acc.GetModelMapping() + if len(mapping) > 0 { + hasAnyMapping = true + for model := range mapping { + modelSet[model] = struct{}{} + } + } + } + + // If no account has model_mapping, return nil (use default) + if !hasAnyMapping { + if s.modelsListCache != nil { + s.modelsListCache.Set(cacheKey, []string(nil), s.modelsListCacheTTL) + modelsListCacheStoreTotal.Add(1) + } + return nil + } + + // Convert to slice + models := make([]string, 0, len(modelSet)) + for model := range modelSet { + models = append(models, model) + } + sort.Strings(models) + + if s.modelsListCache != nil { + s.modelsListCache.Set(cacheKey, cloneStringSlice(models), s.modelsListCacheTTL) + 
modelsListCacheStoreTotal.Add(1) + } + return cloneStringSlice(models) +} + +func (s *GatewayService) InvalidateAvailableModelsCache(groupID *int64, platform string) { + if s == nil || s.modelsListCache == nil { + return + } + + normalizedPlatform := strings.TrimSpace(platform) + // 完整匹配时精准失效;否则按维度批量失效。 + if groupID != nil && normalizedPlatform != "" { + s.modelsListCache.Delete(modelsListCacheKey(groupID, normalizedPlatform)) + return + } + + targetGroup := derefGroupID(groupID) + for key := range s.modelsListCache.Items() { + parts := strings.SplitN(key, "|", 2) + if len(parts) != 2 { + continue + } + groupPart, parseErr := strconv.ParseInt(parts[0], 10, 64) + if parseErr != nil { + continue + } + if groupID != nil && groupPart != targetGroup { + continue + } + if normalizedPlatform != "" && parts[1] != normalizedPlatform { + continue + } + s.modelsListCache.Delete(key) + } +} diff --git a/backend/internal/service/gateway_metrics.go b/backend/internal/service/gateway_metrics.go new file mode 100644 index 00000000000..d59b70d329e --- /dev/null +++ b/backend/internal/service/gateway_metrics.go @@ -0,0 +1,43 @@ +package service + +import ( + "sync/atomic" +) + +var ( + windowCostPrefetchCacheHitTotal atomic.Int64 + windowCostPrefetchCacheMissTotal atomic.Int64 + windowCostPrefetchBatchSQLTotal atomic.Int64 + windowCostPrefetchFallbackTotal atomic.Int64 + windowCostPrefetchErrorTotal atomic.Int64 + + userGroupRateCacheHitTotal atomic.Int64 + userGroupRateCacheMissTotal atomic.Int64 + userGroupRateCacheLoadTotal atomic.Int64 + userGroupRateCacheSFSharedTotal atomic.Int64 + userGroupRateCacheFallbackTotal atomic.Int64 + + modelsListCacheHitTotal atomic.Int64 + modelsListCacheMissTotal atomic.Int64 + modelsListCacheStoreTotal atomic.Int64 +) + +func GatewayWindowCostPrefetchStats() (cacheHit, cacheMiss, batchSQL, fallback, errCount int64) { + return windowCostPrefetchCacheHitTotal.Load(), + windowCostPrefetchCacheMissTotal.Load(), + 
windowCostPrefetchBatchSQLTotal.Load(), + windowCostPrefetchFallbackTotal.Load(), + windowCostPrefetchErrorTotal.Load() +} + +func GatewayUserGroupRateCacheStats() (cacheHit, cacheMiss, load, singleflightShared, fallback int64) { + return userGroupRateCacheHitTotal.Load(), + userGroupRateCacheMissTotal.Load(), + userGroupRateCacheLoadTotal.Load(), + userGroupRateCacheSFSharedTotal.Load(), + userGroupRateCacheFallbackTotal.Load() +} + +func GatewayModelsListCacheStats() (cacheHit, cacheMiss, store int64) { + return modelsListCacheHitTotal.Load(), modelsListCacheMissTotal.Load(), modelsListCacheStoreTotal.Load() +} diff --git a/backend/internal/service/gateway_scheduling_utils.go b/backend/internal/service/gateway_scheduling_utils.go new file mode 100644 index 00000000000..d4bd9037152 --- /dev/null +++ b/backend/internal/service/gateway_scheduling_utils.go @@ -0,0 +1,99 @@ +package service + +import ( + "context" + "sort" +) + +// ───────────────────────────────────────────────────────────────────────────── +// 调度共享工具:消除账号选择中的重复模式 +// ───────────────────────────────────────────────────────────────────────────── + +// eligibilityOpts 账号调度资格检查的选项。 +type eligibilityOpts struct { + platform string // 目标平台 + useMixed bool // 是否混合调度模式 + requestedModel string // 请求的模型名 + isSticky bool // 是否粘性路径(影响 WindowCost/RPM 阈值宽松度) +} + +// isAccountEligibleForScheduling 执行完整的 7 项调度资格检查。 +// 统一了 tryModelRoutingSelection / selectByLoadBalance / tryStickySessionSelection 中 +// 重复出现的门控检查逻辑。isSticky=true 时 WindowCost/RPM 使用宽松阈值。 +func (s *GatewayService) isAccountEligibleForScheduling(ctx context.Context, acc *Account, opts eligibilityOpts) bool { + if !s.isAccountSchedulableForSelection(acc) { + return false + } + if !s.isAccountAllowedForPlatform(acc, opts.platform, opts.useMixed) { + return false + } + if opts.requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, opts.requestedModel) { + return false + } + if !s.isAccountSchedulableForModelSelection(ctx, acc, 
opts.requestedModel) { + return false + } + if !s.isAccountSchedulableForQuota(acc) { + return false + } + if !s.isAccountSchedulableForWindowCost(ctx, acc, opts.isSticky) { + return false + } + if !s.isAccountSchedulableForRPM(ctx, acc, opts.isSticky) { + return false + } + return true +} + +// isBetterAccountCandidate 判断 candidate 是否优于 current(优先级→LRU→OAuth偏好)。 +// oauthPlatformFilter 非空时,仅在两者都属于该平台时才应用 OAuth 偏好。 +func isBetterAccountCandidate(candidate, current *Account, preferOAuth bool, oauthPlatformFilter string) bool { + if candidate.Priority < current.Priority { + return true + } + if candidate.Priority > current.Priority { + return false + } + // 同优先级:比较 LastUsedAt(nil 视为最优) + switch { + case candidate.LastUsedAt == nil && current.LastUsedAt != nil: + return true + case candidate.LastUsedAt != nil && current.LastUsedAt == nil: + return false + case candidate.LastUsedAt == nil && current.LastUsedAt == nil: + if preferOAuth && candidate.Type != current.Type && candidate.Type == AccountTypeOAuth { + if oauthPlatformFilter == "" { + return true + } + return candidate.Platform == oauthPlatformFilter && current.Platform == oauthPlatformFilter + } + return false + default: + return candidate.LastUsedAt.Before(*current.LastUsedAt) + } +} + +// sortAccountsWithLoadByPriority 对带负载信息的账号列表按优先级→负载率→LRU排序, +// 并在同组内随机打乱以防止热点。 +func sortAccountsWithLoadByPriority(accounts []accountWithLoad) { + sort.SliceStable(accounts, func(i, j int) bool { + a, b := accounts[i], accounts[j] + if a.account.Priority != b.account.Priority { + return a.account.Priority < b.account.Priority + } + if a.loadInfo.LoadRate != b.loadInfo.LoadRate { + return a.loadInfo.LoadRate < b.loadInfo.LoadRate + } + switch { + case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: + return true + case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil: + return false + case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil: + return false + default: + return 
a.account.LastUsedAt.Before(*b.account.LastUsedAt) + } + }) + shuffleWithinSortGroups(accounts) +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 6151d78ecde..74e92cda008 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -4,24 +4,18 @@ import ( "bufio" "bytes" "context" - "crypto/sha256" "encoding/json" "errors" "fmt" "io" "log/slog" mathrand "math/rand" - "net" "net/http" - "net/url" "os" - "path/filepath" "regexp" "sort" - "strconv" "strings" "sync/atomic" - "syscall" "time" "github.com/Wei-Shaw/sub2api/internal/config" @@ -30,9 +24,6 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" - "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" - "github.com/cespare/xxhash/v2" - "github.com/google/uuid" gocache "github.com/patrickmn/go-cache" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -67,86 +58,6 @@ const ( cacheTTLTarget1h = "1h" ) -// ForceCacheBillingContextKey 强制缓存计费上下文键 -// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费 -type forceCacheBillingKeyType struct{} - -// accountWithLoad 账号与负载信息的组合,用于负载感知调度 -type accountWithLoad struct { - account *Account - loadInfo *AccountLoadInfo -} - -var ForceCacheBillingContextKey = forceCacheBillingKeyType{} - -var ( - windowCostPrefetchCacheHitTotal atomic.Int64 - windowCostPrefetchCacheMissTotal atomic.Int64 - windowCostPrefetchBatchSQLTotal atomic.Int64 - windowCostPrefetchFallbackTotal atomic.Int64 - windowCostPrefetchErrorTotal atomic.Int64 - - userGroupRateCacheHitTotal atomic.Int64 - userGroupRateCacheMissTotal atomic.Int64 - userGroupRateCacheLoadTotal atomic.Int64 - userGroupRateCacheSFSharedTotal atomic.Int64 - userGroupRateCacheFallbackTotal atomic.Int64 - - modelsListCacheHitTotal atomic.Int64 - modelsListCacheMissTotal atomic.Int64 - 
modelsListCacheStoreTotal atomic.Int64 -) - -func GatewayWindowCostPrefetchStats() (cacheHit, cacheMiss, batchSQL, fallback, errCount int64) { - return windowCostPrefetchCacheHitTotal.Load(), - windowCostPrefetchCacheMissTotal.Load(), - windowCostPrefetchBatchSQLTotal.Load(), - windowCostPrefetchFallbackTotal.Load(), - windowCostPrefetchErrorTotal.Load() -} - -func GatewayUserGroupRateCacheStats() (cacheHit, cacheMiss, load, singleflightShared, fallback int64) { - return userGroupRateCacheHitTotal.Load(), - userGroupRateCacheMissTotal.Load(), - userGroupRateCacheLoadTotal.Load(), - userGroupRateCacheSFSharedTotal.Load(), - userGroupRateCacheFallbackTotal.Load() -} - -func GatewayModelsListCacheStats() (cacheHit, cacheMiss, store int64) { - return modelsListCacheHitTotal.Load(), modelsListCacheMissTotal.Load(), modelsListCacheStoreTotal.Load() -} - -func openAIStreamEventIsTerminal(data string) bool { - trimmed := strings.TrimSpace(data) - if trimmed == "" { - return false - } - if trimmed == "[DONE]" { - return true - } - switch gjson.Get(trimmed, "type").String() { - case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled": - return true - default: - return false - } -} - -func anthropicStreamEventIsTerminal(eventName, data string) bool { - if strings.EqualFold(strings.TrimSpace(eventName), "message_stop") { - return true - } - trimmed := strings.TrimSpace(data) - if trimmed == "" { - return false - } - if trimmed == "[DONE]" { - return true - } - return gjson.Get(trimmed, "type").String() == "message_stop" -} - func cloneStringSlice(src []string) []string { if len(src) == 0 { return nil @@ -156,182 +67,6 @@ func cloneStringSlice(src []string) []string { return dst } -// IsForceCacheBilling 检查是否启用强制缓存计费 -func IsForceCacheBilling(ctx context.Context) bool { - v, _ := ctx.Value(ForceCacheBillingContextKey).(bool) - return v -} - -// WithForceCacheBilling 返回带有强制缓存计费标记的上下文 -func 
WithForceCacheBilling(ctx context.Context) context.Context { - return context.WithValue(ctx, ForceCacheBillingContextKey, true) -} - -func (s *GatewayService) debugModelRoutingEnabled() bool { - if s == nil { - return false - } - return s.debugModelRouting.Load() -} - -func (s *GatewayService) debugClaudeMimicEnabled() bool { - if s == nil { - return false - } - return s.debugClaudeMimic.Load() -} - -func parseDebugEnvBool(raw string) bool { - switch strings.ToLower(strings.TrimSpace(raw)) { - case "1", "true", "yes", "on": - return true - default: - return false - } -} - -func shortSessionHash(sessionHash string) string { - if sessionHash == "" { - return "" - } - if len(sessionHash) <= 8 { - return sessionHash - } - return sessionHash[:8] -} - -func redactAuthHeaderValue(v string) string { - v = strings.TrimSpace(v) - if v == "" { - return "" - } - // Keep scheme for debugging, redact secret. - if strings.HasPrefix(strings.ToLower(v), "bearer ") { - return "Bearer [redacted]" - } - return "[redacted]" -} - -func safeHeaderValueForLog(key string, v string) string { - key = strings.ToLower(strings.TrimSpace(key)) - switch key { - case "authorization", "x-api-key": - return redactAuthHeaderValue(v) - default: - return strings.TrimSpace(v) - } -} - -func extractSystemPreviewFromBody(body []byte) string { - if len(body) == 0 { - return "" - } - sys := gjson.GetBytes(body, "system") - if !sys.Exists() { - return "" - } - - switch { - case sys.IsArray(): - for _, item := range sys.Array() { - if !item.IsObject() { - continue - } - if strings.EqualFold(item.Get("type").String(), "text") { - if t := item.Get("text").String(); strings.TrimSpace(t) != "" { - return t - } - } - } - return "" - case sys.Type == gjson.String: - return sys.String() - default: - return "" - } -} - -func buildClaudeMimicDebugLine(req *http.Request, body []byte, account *Account, tokenType string, mimicClaudeCode bool) string { - if req == nil { - return "" - } - - // Only log a minimal 
fingerprint to avoid leaking user content. - interesting := []string{ - "user-agent", - "x-app", - "anthropic-dangerous-direct-browser-access", - "anthropic-version", - "anthropic-beta", - "x-stainless-lang", - "x-stainless-package-version", - "x-stainless-os", - "x-stainless-arch", - "x-stainless-runtime", - "x-stainless-runtime-version", - "x-stainless-retry-count", - "x-stainless-timeout", - "authorization", - "x-api-key", - "content-type", - "accept", - "x-stainless-helper-method", - } - - h := make([]string, 0, len(interesting)) - for _, k := range interesting { - if v := req.Header.Get(k); v != "" { - h = append(h, fmt.Sprintf("%s=%q", k, safeHeaderValueForLog(k, v))) - } - } - - metaUserID := strings.TrimSpace(gjson.GetBytes(body, "metadata.user_id").String()) - sysPreview := strings.TrimSpace(extractSystemPreviewFromBody(body)) - - // Truncate preview to keep logs sane. - if len(sysPreview) > 300 { - sysPreview = sysPreview[:300] + "..." - } - sysPreview = strings.ReplaceAll(sysPreview, "\n", "\\n") - sysPreview = strings.ReplaceAll(sysPreview, "\r", "\\r") - - aid := int64(0) - aname := "" - if account != nil { - aid = account.ID - aname = account.Name - } - - return fmt.Sprintf( - "url=%s account=%d(%s) tokenType=%s mimic=%t meta.user_id=%q system.preview=%q headers={%s}", - req.URL.String(), - aid, - aname, - tokenType, - mimicClaudeCode, - metaUserID, - sysPreview, - strings.Join(h, " "), - ) -} - -func logClaudeMimicDebug(req *http.Request, body []byte, account *Account, tokenType string, mimicClaudeCode bool) { - line := buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode) - if line == "" { - return - } - logger.LegacyPrintf("service.gateway", "[ClaudeMimicDebug] %s", line) -} - -func isClaudeCodeCredentialScopeError(msg string) bool { - m := strings.ToLower(strings.TrimSpace(msg)) - if m == "" { - return false - } - return strings.Contains(m, "only authorized for use with claude code") && - strings.Contains(m, "cannot be used for 
other api requests") -} - // sseDataRe matches SSE data lines with optional whitespace after colon. // Some upstream APIs return non-standard "data:" without space (should be "data: "). var ( @@ -349,12 +84,6 @@ var ( } ) -// ErrNoAvailableAccounts 表示没有可用的账号 -var ErrNoAvailableAccounts = errors.New("no available accounts") - -// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问 -var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients") - // allowedHeaders 白名单headers(参考CRS项目) var allowedHeaders = map[string]bool{ "accept": true, @@ -408,59 +137,6 @@ func derefGroupID(groupID *int64) int64 { return *groupID } -func resolveUserGroupRateCacheTTL(cfg *config.Config) time.Duration { - if cfg == nil || cfg.Gateway.UserGroupRateCacheTTLSeconds <= 0 { - return defaultUserGroupRateCacheTTL - } - return time.Duration(cfg.Gateway.UserGroupRateCacheTTLSeconds) * time.Second -} - -func resolveModelsListCacheTTL(cfg *config.Config) time.Duration { - if cfg == nil || cfg.Gateway.ModelsListCacheTTLSeconds <= 0 { - return defaultModelsListCacheTTL - } - return time.Duration(cfg.Gateway.ModelsListCacheTTLSeconds) * time.Second -} - -func modelsListCacheKey(groupID *int64, platform string) string { - return fmt.Sprintf("%d|%s", derefGroupID(groupID), strings.TrimSpace(platform)) -} - -func prefetchedStickyGroupIDFromContext(ctx context.Context) (int64, bool) { - return PrefetchedStickyGroupIDFromContext(ctx) -} - -func prefetchedStickyAccountIDFromContext(ctx context.Context, groupID *int64) int64 { - prefetchedGroupID, ok := prefetchedStickyGroupIDFromContext(ctx) - if !ok || prefetchedGroupID != derefGroupID(groupID) { - return 0 - } - if accountID, ok := PrefetchedStickyAccountIDFromContext(ctx); ok && accountID > 0 { - return accountID - } - return 0 -} - -// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。 -// 委托 IsSchedulable() 判断账号级可调度性(状态、配额、过载、限流等), -// 额外检查模型级限流。 -// -// shouldClearStickySession checks if an account is in an unschedulable state -// and 
the sticky session binding should be cleared. -// Delegates to IsSchedulable() for account-level checks, plus model-level rate limiting. -func shouldClearStickySession(account *Account, requestedModel string) bool { - if account == nil { - return false - } - if !account.IsSchedulable() { - return true - } - if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > 0 { - return true - } - return false -} - type AccountWaitPlan struct { AccountID int64 MaxConcurrency int @@ -505,34 +181,10 @@ type ForwardResult struct { ImageSize string // 图片尺寸 "1K", "2K", "4K" } -// UpstreamFailoverError indicates an upstream error that should trigger account failover. -type UpstreamFailoverError struct { - StatusCode int - ResponseBody []byte // 上游响应体,用于错误透传规则匹配 - ResponseHeaders http.Header // 上游响应头,用于透传 cf-ray/cf-mitigated/content-type 等诊断信息 - ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true - RetryableOnSameAccount bool // 临时性错误(如 Google 间歇性 400、空响应),应在同一账号上重试 N 次再切换 -} - func (e *UpstreamFailoverError) Error() string { return fmt.Sprintf("upstream error: %d (failover)", e.StatusCode) } -// TempUnscheduleRetryableError 对 RetryableOnSameAccount 类型的 failover 错误触发临时封禁。 -// 由 handler 层在同账号重试全部用尽、切换账号时调用。 -func (s *GatewayService) TempUnscheduleRetryableError(ctx context.Context, accountID int64, failoverErr *UpstreamFailoverError) { - if failoverErr == nil || !failoverErr.RetryableOnSameAccount { - return - } - // 根据状态码选择封禁策略 - switch failoverErr.StatusCode { - case http.StatusBadRequest: - tempUnscheduleGoogleConfigError(ctx, s.accountRepo, accountID, "[handler]") - case http.StatusBadGateway: - tempUnscheduleEmptyResponse(ctx, s.accountRepo, accountID, "[handler]") - } -} - // GatewayService handles API gateway operations type GatewayService struct { accountRepo AccountRepository @@ -651,8482 +303,2491 @@ func NewGatewayService( return svc } -// GenerateSessionHash 从预解析请求计算粘性会话 hash -func (s *GatewayService) 
GenerateSessionHash(parsed *ParsedRequest) string { - if parsed == nil { - return "" - } - - // 1. 最高优先级:从 metadata.user_id 提取 session_xxx - if parsed.MetadataUserID != "" { - uid := ParseMetadataUserID(parsed.MetadataUserID) - if uid != nil && uid.SessionID != "" { - slog.Info("sticky.hash_source", - "source", "metadata_user_id", - "session_id", uid.SessionID, - "device_id", uid.DeviceID, - "is_new_format", uid.IsNewFormat, - ) - return uid.SessionID - } - slog.Info("sticky.hash_metadata_parse_failed", - "metadata_user_id", parsed.MetadataUserID, - "parsed_nil", uid == nil, - ) - } - - // 2. 提取带 cache_control: {type: "ephemeral"} 的内容 - cacheableContent := s.extractCacheableContent(parsed) - if cacheableContent != "" { - hash := s.hashContent(cacheableContent) - slog.Info("sticky.hash_source", - "source", "cacheable_content", - "hash", hash, - ) - return hash - } - - // 3. 最后 fallback: 使用 session上下文 + system + 所有消息的完整摘要串 - var combined strings.Builder - // 混入请求上下文区分因子,避免不同用户相同消息产生相同 hash - if parsed.SessionContext != nil { - _, _ = combined.WriteString(parsed.SessionContext.ClientIP) - _, _ = combined.WriteString(":") - _, _ = combined.WriteString(NormalizeSessionUserAgent(parsed.SessionContext.UserAgent)) - _, _ = combined.WriteString(":") - _, _ = combined.WriteString(strconv.FormatInt(parsed.SessionContext.APIKeyID, 10)) - _, _ = combined.WriteString("|") - } - if parsed.System != nil { - systemText := s.extractTextFromSystem(parsed.System) - if systemText != "" { - _, _ = combined.WriteString(systemText) +// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts. 
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
//
// Platform resolution order: a forced platform from the request context wins;
// otherwise the group (after Claude-Code fallback resolution) determines it;
// with no group the native anthropic platform is used. Channel pricing
// restrictions are checked before any account selection; a blocked request
// returns a wrapped ErrNoAvailableAccounts.
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
	// Highest priority: forced platform from context (/antigravity route).
	var platform string
	forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
	if hasForcePlatform && forcePlatform != "" {
		platform = forcePlatform
	} else if groupID != nil {
		group, resolvedGroupID, err := s.resolveGatewayGroup(ctx, groupID)
		if err != nil {
			return nil, err
		}
		groupID = resolvedGroupID
		ctx = s.withGroupContext(ctx, group)
		platform = group.Platform
	} else {
		// Without a group, only the native anthropic platform is used.
		platform = PlatformAnthropic
	}

	// The Claude Code restriction may have resolved groupID to a fallback
	// group; the channel-restriction precheck must use the resolved group.
	if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
		slog.Warn("channel pricing restriction blocked request",
			"group_id", derefGroupID(groupID),
			"model", requestedModel)
		return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
	}

	// anthropic/gemini groups support mixed scheduling (including antigravity
	// accounts that enabled mixed_scheduling).
	// Note: forced-platform mode never uses mixed scheduling.
	if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
		account, err := s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
		if err != nil {
			return nil, err
		}
		return s.hydrateSelectedAccount(ctx, account)
	}

	// antigravity groups, forced-platform mode, or no group: single-platform
	// selection. Forced-platform mode still honors group limits and no longer
	// falls back to an all-platform query.
	account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
	if err != nil {
		return nil, err
	}
	return s.hydrateSelectedAccount(ctx, account)
}

// tryAcquireByLegacyOrder walks the candidates in legacy priority/LRU order
// and returns the first account whose concurrency slot can be acquired and
// whose session-count limit admits this session. On success the sticky
// session binding is refreshed (best effort) and a selection result owning
// the acquired slot is returned with ok=true. ok=false with a nil error
// means no candidate could be acquired.
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool, error) {
	// Sort a copy so the caller's slice order is left untouched.
	ordered := append([]*Account(nil), candidates...)
	sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)

	for _, acc := range ordered {
		result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
		if err == nil && result.Acquired {
			// Session-count limit check.
			if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
				result.ReleaseFunc() // release the slot, try the next account
				continue
			}
			if sessionHash != "" && s.cache != nil {
				// Best effort: a failed sticky rebind is not fatal.
				_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL)
			}
			selection, err := s.newSelectionResult(ctx, acc, true, result.ReleaseFunc, nil)
			if err != nil {
				return nil, false, err
			}
			return selection, true, nil
		}
	}

	return nil, false, nil
}
-func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID *int64, sessionHash string) (int64, error) { - if sessionHash == "" || s.cache == nil { - return 0, nil +func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig { + if s.cfg != nil { + return s.cfg.Gateway.Scheduling } - accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) - if err != nil { - return 0, err + return config.GatewaySchedulingConfig{ + StickySessionMaxWaiting: 3, + StickySessionWaitTimeout: 45 * time.Second, + FallbackWaitTimeout: 30 * time.Second, + FallbackMaxWaiting: 100, + LoadBatchEnabled: true, + SlotCleanupInterval: 30 * time.Second, } - return accountID, nil } -// FindGeminiSession 查找 Gemini 会话(基于内容摘要链的 Fallback 匹配) -// 返回最长匹配的会话信息(uuid, accountID) -func (s *GatewayService) FindGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) { - if digestChain == "" || s.digestStore == nil { - return "", 0, "", false +func (s *GatewayService) withGroupContext(ctx context.Context, group *Group) context.Context { + if !IsGroupContextValid(group) { + return ctx + } + if existing, ok := ctx.Value(ctxkey.Group).(*Group); ok && existing != nil && existing.ID == group.ID && IsGroupContextValid(existing) { + return ctx } - return s.digestStore.Find(groupID, prefixHash, digestChain) + return context.WithValue(ctx, ctxkey.Group, group) } -// SaveGeminiSession 保存 Gemini 会话。oldDigestChain 为 Find 返回的 matchedChain,用于删旧 key。 -func (s *GatewayService) SaveGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error { - if digestChain == "" || s.digestStore == nil { - return nil +func (s *GatewayService) groupFromContext(ctx context.Context, groupID int64) *Group { + if group, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(group) && group.ID == groupID { + return 
group } - s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain) return nil } -// FindAnthropicSession 查找 Anthropic 会话(基于内容摘要链的 Fallback 匹配) -func (s *GatewayService) FindAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) { - if digestChain == "" || s.digestStore == nil { - return "", 0, "", false +func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (*Group, error) { + if group := s.groupFromContext(ctx, groupID); group != nil { + return group, nil + } + group, err := s.groupRepo.GetByIDLite(ctx, groupID) + if err != nil { + return nil, fmt.Errorf("get group failed: %w", err) } - return s.digestStore.Find(groupID, prefixHash, digestChain) + return group, nil } -// SaveAnthropicSession 保存 Anthropic 会话 -func (s *GatewayService) SaveAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error { - if digestChain == "" || s.digestStore == nil { - return nil - } - s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain) - return nil +func (s *GatewayService) ResolveGroupByID(ctx context.Context, groupID int64) (*Group, error) { + return s.resolveGroupByID(ctx, groupID) } -func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string { - if parsed == nil { - return "" +func (s *GatewayService) routingAccountIDsForRequest(ctx context.Context, groupID *int64, requestedModel string, platform string) []int64 { + if groupID == nil || requestedModel == "" || platform != PlatformAnthropic { + return nil } - - var builder strings.Builder - - // 检查 system 中的 cacheable 内容 - if system, ok := parsed.System.([]any); ok { - for _, part := range system { - if partMap, ok := part.(map[string]any); ok { - if cc, ok := partMap["cache_control"].(map[string]any); ok { - if cc["type"] == "ephemeral" { - if text, ok := 
partMap["text"].(string); ok { - _, _ = builder.WriteString(text) - } - } - } - } + group, err := s.resolveGroupByID(ctx, *groupID) + if err != nil || group == nil { + if s.debugModelRoutingEnabled() { + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] resolve group failed: group_id=%v model=%s platform=%s err=%v", derefGroupID(groupID), requestedModel, platform, err) } + return nil } - systemText := builder.String() - - // 检查 messages 中的 cacheable 内容 - for _, msg := range parsed.Messages { - if msgMap, ok := msg.(map[string]any); ok { - if msgContent, ok := msgMap["content"].([]any); ok { - for _, part := range msgContent { - if partMap, ok := part.(map[string]any); ok { - if cc, ok := partMap["cache_control"].(map[string]any); ok { - if cc["type"] == "ephemeral" { - return s.extractTextFromContent(msgMap["content"]) - } - } - } - } - } + // Preserve existing behavior: model routing only applies to anthropic groups. + if group.Platform != PlatformAnthropic { + if s.debugModelRoutingEnabled() { + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] skip: non-anthropic group platform: group_id=%d group_platform=%s model=%s", group.ID, group.Platform, requestedModel) } + return nil } - - return systemText + ids := group.GetRoutingAccountIDs(requestedModel) + if s.debugModelRoutingEnabled() { + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routing lookup: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v", + group.ID, requestedModel, group.ModelRoutingEnabled, len(group.ModelRouting), ids) + } + return ids } -func (s *GatewayService) extractTextFromSystem(system any) string { - switch v := system.(type) { - case string: - return v - case []any: - var texts []string - for _, part := range v { - if partMap, ok := part.(map[string]any); ok { - if text, ok := partMap["text"].(string); ok { - texts = append(texts, text) - } - } - } - return strings.Join(texts, "") +func (s *GatewayService) resolveGatewayGroup(ctx context.Context, 
groupID *int64) (*Group, *int64, error) { + if groupID == nil { + return nil, nil, nil } - return "" -} -func (s *GatewayService) extractTextFromContent(content any) string { - switch v := content.(type) { - case string: - return v - case []any: - var texts []string - for _, part := range v { - if partMap, ok := part.(map[string]any); ok { - if partMap["type"] == "text" { - if text, ok := partMap["text"].(string); ok { - texts = append(texts, text) - } - } - } + currentID := *groupID + visited := map[int64]struct{}{} + for { + if _, seen := visited[currentID]; seen { + return nil, nil, fmt.Errorf("fallback group cycle detected") } - return strings.Join(texts, "") - } - return "" -} + visited[currentID] = struct{}{} -func (s *GatewayService) hashContent(content string) string { - h := xxhash.Sum64String(content) - return strconv.FormatUint(h, 36) -} + group, err := s.resolveGroupByID(ctx, currentID) + if err != nil { + return nil, nil, err + } -type anthropicCacheControlPayload struct { - Type string `json:"type"` - TTL string `json:"ttl,omitempty"` -} + if !group.ClaudeCodeOnly || IsClaudeCodeClient(ctx) { + return group, ¤tID, nil + } -type anthropicSystemTextBlockPayload struct { - Type string `json:"type"` - Text string `json:"text"` - CacheControl *anthropicCacheControlPayload `json:"cache_control,omitempty"` + if group.FallbackGroupID == nil { + return nil, nil, ErrClaudeCodeOnly + } + currentID = *group.FallbackGroupID + } } -type anthropicMetadataPayload struct { - UserID string `json:"user_id"` -} +// checkClaudeCodeRestriction 检查分组的 Claude Code 客户端限制 +// 如果分组启用了 claude_code_only 且请求不是来自 Claude Code 客户端: +// - 有降级分组:返回降级分组的 ID +// - 无降级分组:返回 ErrClaudeCodeOnly 错误 +func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID *int64) (*Group, *int64, error) { + if groupID == nil { + return nil, groupID, nil + } -// replaceModelInBody 替换请求体中的model字段 -// 优先使用定点修改,尽量保持客户端原始字段顺序。 -func (s *GatewayService) replaceModelInBody(body []byte, 
newModel string) []byte { - return ReplaceModelInBody(body, newModel) -} + // 强制平台模式不检查 Claude Code 限制 + if forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform && forcePlatform != "" { + return nil, groupID, nil + } -type claudeOAuthNormalizeOptions struct { - injectMetadata bool - metadataUserID string - stripSystemCacheControl bool -} + group, resolvedID, err := s.resolveGatewayGroup(ctx, groupID) + if err != nil { + return nil, nil, err + } -// sanitizeSystemText rewrites only the fixed OpenCode identity sentence (if present). -// We intentionally avoid broad keyword replacement in system prompts to prevent -// accidentally changing user-provided instructions. -func sanitizeSystemText(text string) string { - if text == "" { - return text - } - // Some clients include a fixed OpenCode identity sentence. Anthropic may treat - // this as a non-Claude-Code fingerprint, so rewrite it to the canonical - // Claude Code banner before generic "OpenCode"/"opencode" replacements. 
- text = strings.ReplaceAll( - text, - "You are OpenCode, the best coding agent on the planet.", - strings.TrimSpace(claudeCodeSystemPrompt), - ) - return text + return group, resolvedID, nil } -func marshalAnthropicSystemTextBlock(text string, includeCacheControl bool) ([]byte, error) { - block := anthropicSystemTextBlockPayload{ - Type: "text", - Text: text, +func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, group *Group) (string, bool, error) { + forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) + if hasForcePlatform && forcePlatform != "" { + return forcePlatform, true, nil } - if includeCacheControl { - block.CacheControl = &anthropicCacheControlPayload{ - Type: "ephemeral", - TTL: claude.DefaultCacheControlTTL, + if group != nil { + return group.Platform, false, nil + } + if groupID != nil { + group, err := s.resolveGroupByID(ctx, *groupID) + if err != nil { + return "", false, err } + return group.Platform, false, nil } - return json.Marshal(block) -} - -func marshalAnthropicMetadata(userID string) ([]byte, error) { - return json.Marshal(anthropicMetadataPayload{UserID: userID}) + return PlatformAnthropic, false, nil } -func buildJSONArrayRaw(items [][]byte) []byte { - if len(items) == 0 { - return []byte("[]") - } - - total := 2 - for _, item := range items { - total += len(item) +func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) { + if s.schedulerSnapshot != nil { + accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) + if err == nil { + slog.Debug("account_scheduling_list_snapshot", + "group_id", derefGroupID(groupID), + "platform", platform, + "use_mixed", useMixed, + "count", len(accounts)) + for _, acc := range accounts { + slog.Debug("account_scheduling_account_detail", + "account_id", acc.ID, + "name", acc.Name, + "platform", acc.Platform, 
+ "type", acc.Type, + "status", acc.Status, + "tls_fingerprint", acc.IsTLSFingerprintEnabled()) + } + } + return accounts, useMixed, err } - total += len(items) - 1 - - buf := make([]byte, 0, total) - buf = append(buf, '[') - for i, item := range items { - if i > 0 { - buf = append(buf, ',') + useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform + if useMixed { + platforms := []string{platform, PlatformAntigravity} + var accounts []Account + var err error + if groupID != nil { + accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms) + } else if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) + } else { + accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatforms(ctx, platforms) + } + if err != nil { + slog.Debug("account_scheduling_list_failed", + "group_id", derefGroupID(groupID), + "platform", platform, + "error", err) + return nil, useMixed, err } - buf = append(buf, item...) 
+ filtered := make([]Account, 0, len(accounts)) + for _, acc := range accounts { + if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { + continue + } + filtered = append(filtered, acc) + } + slog.Debug("account_scheduling_list_mixed", + "group_id", derefGroupID(groupID), + "platform", platform, + "raw_count", len(accounts), + "filtered_count", len(filtered)) + for _, acc := range filtered { + slog.Debug("account_scheduling_account_detail", + "account_id", acc.ID, + "name", acc.Name, + "platform", acc.Platform, + "type", acc.Type, + "status", acc.Status, + "tls_fingerprint", acc.IsTLSFingerprintEnabled()) + } + return filtered, useMixed, nil } - buf = append(buf, ']') - return buf -} -func setJSONValueBytes(body []byte, path string, value any) ([]byte, bool) { - next, err := sjson.SetBytes(body, path, value) - if err != nil { - return body, false + var accounts []Account + var err error + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) + } else if groupID != nil { + accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform) + // 分组内无账号则返回空列表,由上层处理错误,不再回退到全平台查询 + } else { + accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, platform) } - return next, true -} - -func setJSONRawBytes(body []byte, path string, raw []byte) ([]byte, bool) { - next, err := sjson.SetRawBytes(body, path, raw) if err != nil { - return body, false + slog.Debug("account_scheduling_list_failed", + "group_id", derefGroupID(groupID), + "platform", platform, + "error", err) + return nil, useMixed, err + } + slog.Debug("account_scheduling_list_single", + "group_id", derefGroupID(groupID), + "platform", platform, + "count", len(accounts)) + for _, acc := range accounts { + slog.Debug("account_scheduling_account_detail", + "account_id", acc.ID, + "name", acc.Name, + "platform", acc.Platform, + "type", acc.Type, + "status", acc.Status, + 
"tls_fingerprint", acc.IsTLSFingerprintEnabled()) } - return next, true + return accounts, useMixed, nil } -func deleteJSONPathBytes(body []byte, path string) ([]byte, bool) { - next, err := sjson.DeleteBytes(body, path) +// IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。 +// 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context, +// 避免单账号分组收到 503 时错误地设置模型限流标记导致后续请求连续快速失败。 +func (s *GatewayService) IsSingleAntigravityAccountGroup(ctx context.Context, groupID *int64) bool { + accounts, _, err := s.listSchedulableAccounts(ctx, groupID, PlatformAntigravity, true) if err != nil { - return body, false + return false } - return next, true + return len(accounts) == 1 } -func normalizeClaudeOAuthSystemBody(body []byte, opts claudeOAuthNormalizeOptions) ([]byte, bool) { - sys := gjson.GetBytes(body, "system") - if !sys.Exists() { - return body, false +func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool { + if account == nil { + return false } - - out := body - modified := false - - switch { - case sys.Type == gjson.String: - sanitized := sanitizeSystemText(sys.String()) - if sanitized != sys.String() { - if next, ok := setJSONValueBytes(out, "system", sanitized); ok { - out = next - modified = true - } - } - case sys.IsArray(): - index := 0 - sys.ForEach(func(_, item gjson.Result) bool { - if item.Get("type").String() == "text" { - textResult := item.Get("text") - if textResult.Exists() && textResult.Type == gjson.String { - text := textResult.String() - sanitized := sanitizeSystemText(text) - if sanitized != text { - if next, ok := setJSONValueBytes(out, fmt.Sprintf("system.%d.text", index), sanitized); ok { - out = next - modified = true - } - } - } - } - - if opts.stripSystemCacheControl && item.Get("cache_control").Exists() { - if next, ok := deleteJSONPathBytes(out, fmt.Sprintf("system.%d.cache_control", index)); ok { - out = next - modified = true - } - } - - index++ + if useMixed { + if 
account.Platform == platform { return true - }) + } + return account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() } - - return out, modified + return account.Platform == platform } -func ensureClaudeOAuthMetadataUserID(body []byte, userID string) ([]byte, bool) { - if strings.TrimSpace(userID) == "" { - return body, false +func (s *GatewayService) isAccountSchedulableForSelection(account *Account) bool { + if account == nil { + return false } + return account.IsSchedulable() +} - metadata := gjson.GetBytes(body, "metadata") - if !metadata.Exists() || metadata.Type == gjson.Null { - raw, err := marshalAnthropicMetadata(userID) - if err != nil { - return body, false - } - return setJSONRawBytes(body, "metadata", raw) +func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Context, account *Account, requestedModel string) bool { + if account == nil { + return false } + return account.IsSchedulableForModelWithContext(ctx, requestedModel) +} - trimmedRaw := strings.TrimSpace(metadata.Raw) - if strings.HasPrefix(trimmedRaw, "{") { - existing := metadata.Get("user_id") - if existing.Exists() && existing.Type == gjson.String && existing.String() != "" { - return body, false +// isAccountInGroup checks if the account belongs to the specified group. +// When groupID is nil, returns true only for ungrouped accounts (no group assignments). 
+func (s *GatewayService) isAccountInGroup(account *Account, groupID *int64) bool { + if account == nil { + return false + } + if groupID == nil { + // 无分组的 API Key 只能使用未分组的账号 + return len(account.AccountGroups) == 0 + } + for _, ag := range account.AccountGroups { + if ag.GroupID == *groupID { + return true } - return setJSONValueBytes(body, "metadata.user_id", userID) } + return false +} - raw, err := marshalAnthropicMetadata(userID) - if err != nil { - return body, false +func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) { + if s.concurrencyService == nil { + return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil } - return setJSONRawBytes(body, "metadata", raw) + return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) } -func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string) { - if len(body) == 0 { - return body, modelID - } +type usageLogWindowStatsBatchProvider interface { + GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error) +} - out := body - modified := false +type windowCostPrefetchContextKeyType struct{} - if next, changed := normalizeClaudeOAuthSystemBody(out, opts); changed { - out = next - modified = true - } +var windowCostPrefetchContextKey = windowCostPrefetchContextKeyType{} - rawModel := gjson.GetBytes(out, "model") - if rawModel.Exists() && rawModel.Type == gjson.String { - normalized := claude.NormalizeModelID(rawModel.String()) - if normalized != rawModel.String() { - if next, ok := setJSONValueBytes(out, "model", normalized); ok { - out = next - modified = true - } - modelID = normalized - } +func windowCostFromPrefetchContext(ctx context.Context, accountID int64) (float64, bool) { + if ctx == nil || accountID <= 0 { + return 0, false + } + m, ok := 
ctx.Value(windowCostPrefetchContextKey).(map[int64]float64) + if !ok || len(m) == 0 { + return 0, false } + v, exists := m[accountID] + return v, exists +} - // 确保 tools 字段存在(即使为空数组) - if !gjson.GetBytes(out, "tools").Exists() { - if next, ok := setJSONRawBytes(out, "tools", []byte("[]")); ok { - out = next - modified = true - } +func (s *GatewayService) withWindowCostPrefetch(ctx context.Context, accounts []Account) context.Context { + if ctx == nil || len(accounts) == 0 || s.sessionLimitCache == nil || s.usageLogRepo == nil { + return ctx } - if opts.injectMetadata && opts.metadataUserID != "" { - if next, changed := ensureClaudeOAuthMetadataUserID(out, opts.metadataUserID); changed { - out = next - modified = true + accountByID := make(map[int64]*Account) + accountIDs := make([]int64, 0, len(accounts)) + for i := range accounts { + account := &accounts[i] + if account == nil || !account.IsAnthropicOAuthOrSetupToken() { + continue } + if account.GetWindowCostLimit() <= 0 { + continue + } + accountByID[account.ID] = account + accountIDs = append(accountIDs, account.ID) + } + if len(accountIDs) == 0 { + return ctx } - // temperature:真实 Claude Code CLI 总是发送 temperature(默认 1,客户端可覆盖)。 - // 之前的实现直接 delete 会导致 payload 缺字段,与真实 CLI 字节级不一致。 - // 策略:客户端传了什么就透传;没传则补默认 1。 - if !gjson.GetBytes(out, "temperature").Exists() { - if next, ok := setJSONValueBytes(out, "temperature", 1); ok { - out = next - modified = true + costs := make(map[int64]float64, len(accountIDs)) + cacheValues, err := s.sessionLimitCache.GetWindowCostBatch(ctx, accountIDs) + if err == nil { + for accountID, cost := range cacheValues { + costs[accountID] = cost } + windowCostPrefetchCacheHitTotal.Add(int64(len(cacheValues))) + } else { + windowCostPrefetchErrorTotal.Add(1) + logger.LegacyPrintf("service.gateway", "window_cost batch cache read failed: %v", err) + } + cacheMissCount := len(accountIDs) - len(costs) + if cacheMissCount < 0 { + cacheMissCount = 0 } + 
windowCostPrefetchCacheMissTotal.Add(int64(cacheMissCount)) - // max_tokens:真实 CLI 的默认值是 128000。缺失时补齐以对齐指纹。 - if !gjson.GetBytes(out, "max_tokens").Exists() { - if next, ok := setJSONValueBytes(out, "max_tokens", 128000); ok { - out = next - modified = true + missingByStart := make(map[int64][]int64) + startTimes := make(map[int64]time.Time) + for _, accountID := range accountIDs { + if _, ok := costs[accountID]; ok { + continue + } + account := accountByID[accountID] + if account == nil { + continue } + startTime := account.GetCurrentWindowStartTime() + startKey := startTime.Unix() + missingByStart[startKey] = append(missingByStart[startKey], accountID) + startTimes[startKey] = startTime + } + if len(missingByStart) == 0 { + return context.WithValue(ctx, windowCostPrefetchContextKey, costs) } - // context_management:thinking.type 为 enabled/adaptive 时,真实 CLI 会自动 - // 附带 {"edits":[{"type":"clear_thinking_20251015","keep":"all"}]}。 - // 客户端显式传了就透传;否则按 CLI 行为补齐。 - if !gjson.GetBytes(out, "context_management").Exists() { - thinkingType := gjson.GetBytes(out, "thinking.type").String() - if thinkingType == "enabled" || thinkingType == "adaptive" { - const cmDefault = `{"edits":[{"type":"clear_thinking_20251015","keep":"all"}]}` - if next, ok := setJSONRawBytes(out, "context_management", []byte(cmDefault)); ok { - out = next - modified = true + batchReader, hasBatch := s.usageLogRepo.(usageLogWindowStatsBatchProvider) + for startKey, ids := range missingByStart { + startTime := startTimes[startKey] + + if hasBatch { + windowCostPrefetchBatchSQLTotal.Add(1) + queryStart := time.Now() + statsByAccount, err := batchReader.GetAccountWindowStatsBatch(ctx, ids, startTime) + if err == nil { + slog.Debug("window_cost_batch_query_ok", + "accounts", len(ids), + "window_start", startTime.Format(time.RFC3339), + "duration_ms", time.Since(queryStart).Milliseconds()) + for _, accountID := range ids { + stats := statsByAccount[accountID] + cost := 0.0 + if stats != nil { + cost = 
stats.StandardCost + } + costs[accountID] = cost + _ = s.sessionLimitCache.SetWindowCost(ctx, accountID, cost) + } + continue } + windowCostPrefetchErrorTotal.Add(1) + logger.LegacyPrintf("service.gateway", "window_cost batch db query failed: start=%s err=%v", startTime.Format(time.RFC3339), err) } - } - // tool_choice:与 Parrot 对齐,不再无条件删除。 - // - 客户端传了 {"type":"tool","name":"X"} → 保留结构,name 由 - // applyToolNameRewriteToBody 同步映射为假名 - // - 其他形态(auto/any/none)原样透传 - // 如果 body 里完全没有 tools(空数组),tool_choice 没意义时才删除 - if !gjson.GetBytes(out, "tools").IsArray() || len(gjson.GetBytes(out, "tools").Array()) == 0 { - if gjson.GetBytes(out, "tool_choice").Exists() { - if next, ok := deleteJSONPathBytes(out, "tool_choice"); ok { - out = next - modified = true + // 回退路径:缺少批量仓储能力或批量查询失败时,按账号单查(失败开放)。 + windowCostPrefetchFallbackTotal.Add(int64(len(ids))) + for _, accountID := range ids { + stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, accountID, startTime) + if err != nil { + windowCostPrefetchErrorTotal.Add(1) + continue } + cost := stats.StandardCost + costs[accountID] = cost + _ = s.sessionLimitCache.SetWindowCost(ctx, accountID, cost) } } - if !modified { - return body, modelID - } + return context.WithValue(ctx, windowCostPrefetchContextKey, costs) +} - return out, modelID +// isAccountSchedulableForQuota 检查账号是否在配额限制内 +// 适用于配置了 quota_limit 的 apikey 和 bedrock 类型账号 +func (s *GatewayService) isAccountSchedulableForQuota(account *Account) bool { + if !account.IsAPIKeyOrBedrock() { + return true + } + return !account.IsQuotaExceeded() } -func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string { - if parsed == nil || account == nil { - return "" +// isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度 +// 仅适用于 Anthropic OAuth/SetupToken 账号 +// 返回 true 表示可调度,false 表示不可调度 +func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context, account *Account, isSticky bool) bool { + // 只检查 Anthropic 
OAuth/SetupToken 账号 + if !account.IsAnthropicOAuthOrSetupToken() { + return true } - if parsed.MetadataUserID != "" { - return "" + + limit := account.GetWindowCostLimit() + if limit <= 0 { + return true // 未启用窗口费用限制 } - userID := strings.TrimSpace(account.GetClaudeUserID()) - if userID == "" && fp != nil { - userID = fp.ClientID + // 尝试从缓存获取窗口费用 + var currentCost float64 + if cost, ok := windowCostFromPrefetchContext(ctx, account.ID); ok { + currentCost = cost + goto checkSchedulability } - if userID == "" { - // Fall back to a random, well-formed client id so we can still satisfy - // Claude Code OAuth requirements when account metadata is incomplete. - userID = generateClientID() + if s.sessionLimitCache != nil { + if cost, hit, err := s.sessionLimitCache.GetWindowCost(ctx, account.ID); err == nil && hit { + currentCost = cost + goto checkSchedulability + } } - sessionHash := s.GenerateSessionHash(parsed) - sessionID := uuid.NewString() - if sessionHash != "" { - seed := fmt.Sprintf("%d::%s", account.ID, sessionHash) - sessionID = generateSessionUUID(seed) - } + // 缓存未命中,从数据库查询 + { + // 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况) + startTime := account.GetCurrentWindowStartTime() - // 根据指纹 UA 版本选择输出格式 - var uaVersion string - if fp != nil { - uaVersion = ExtractCLIVersion(fp.UserAgent) - } - accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid")) - return FormatMetadataUserID(userID, accountUUID, sessionID, uaVersion) -} + stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime) + if err != nil { + // 失败开放:查询失败时允许调度 + return true + } -// applyClaudeCodeOAuthMimicryToBody 将"非 Claude Code 客户端 + Claude OAuth 账号" -// 路径上原本只在 /v1/messages 里做的完整伪装应用到任意 body 上。 -// -// 这是 /v1/messages 主路径上 rewriteSystemForNonClaudeCode + -// normalizeClaudeOAuthRequestBody 流程的通用版,供 OpenAI 协议兼容层 -// (ForwardAsChatCompletions / ForwardAsResponses) 复用。 -// -// 未抽离之前,OpenAI 协议兼容层仅做 injectClaudeCodePrompt(前置追加), -// 而仓内 /v1/messages 路径自己的注释明确说过"仅前置追加无法通过 Anthropic -// 
第三方检测";那条注释就是本函数存在的根因。 -// -// 参数: -// - ctx / c:用于读取指纹和 gateway settings;c 可为 nil(如 count_tokens)。 -// - account:必须是 OAuth 账号,且调用方已判断不是 Claude Code 客户端。 -// - body:已经 marshal 成 Anthropic /v1/messages 格式的请求体。 -// - systemRaw:body 中原始 system 字段(用于判断是否需要 rewrite)。 -// - model:最终会发给上游的模型 ID(用于 haiku 旁路 + metadata 版本选择)。 -// -// 返回:改写后的 body。即使中间任何一步失败,也会退化成原 body(不会 panic)。 -func (s *GatewayService) applyClaudeCodeOAuthMimicryToBody( - ctx context.Context, - c *gin.Context, - account *Account, - body []byte, - systemRaw any, - model string, -) []byte { - if account == nil || !account.IsOAuth() || len(body) == 0 { - return body - } + // 使用标准费用(不含账号倍率) + currentCost = stats.StandardCost - systemRewritten := false - if !strings.Contains(strings.ToLower(model), "haiku") { - body = rewriteSystemForNonClaudeCode(body, systemRaw) - systemRewritten = true + // 设置缓存(忽略错误) + if s.sessionLimitCache != nil { + _ = s.sessionLimitCache.SetWindowCost(ctx, account.ID, currentCost) + } } - normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: !systemRewritten} +checkSchedulability: + schedulability := account.CheckWindowCostSchedulability(currentCost) - if s.identityService != nil && c != nil && c.Request != nil { - if fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header); err == nil && fp != nil { - mimicMPT := false - if s.settingService != nil { - _, mimicMPT, _ = s.settingService.GetGatewayForwardingSettings(ctx) - } - if !mimicMPT { - if uid := s.buildOAuthMetadataUserIDFromBody(ctx, account, fp, body); uid != "" { - normalizeOpts.injectMetadata = true - normalizeOpts.metadataUserID = uid - } - } - } + switch schedulability { + case WindowCostSchedulable: + return true + case WindowCostStickyOnly: + return isSticky + case WindowCostNotSchedulable: + return false } + return true +} - body, _ = normalizeClaudeOAuthRequestBody(body, model, normalizeOpts) +// rpmPrefetchContextKey is the context key for prefetched RPM counts. 
+type rpmPrefetchContextKeyType struct{} - // Phase D+E+F: messages cache 策略 + 工具名混淆 + tools[-1] 断点 - // 对齐 Parrot transform_request 里剩余的字段级改写。顺序有语义约束: - // 1) messages cache:仅在配置开启时清除客户端断点并注入代理断点 - // 2) tool rewrite:最后改 tools[*].name / tool_choice.name 并在 tools[-1] - // 上打断点;mapping 存入 gin.Context 供响应侧 bytes.Replace 还原。 - body = s.rewriteMessageCacheControlIfEnabled(ctx, body) +var rpmPrefetchContextKey = rpmPrefetchContextKeyType{} - if rw := buildToolNameRewriteFromBody(body); rw != nil { - body = applyToolNameRewriteToBody(body, rw) - if c != nil { - c.Set(toolNameRewriteKey, rw) - } - } else { - body = applyToolsLastCacheBreakpoint(body) +func rpmFromPrefetchContext(ctx context.Context, accountID int64) (int, bool) { + if v, ok := ctx.Value(rpmPrefetchContextKey).(map[int64]int); ok { + count, found := v[accountID] + return count, found } - - return body + return 0, false } -// buildOAuthMetadataUserIDFromBody 是 buildOAuthMetadataUserID 的变体, -// 适用于调用方手上没有 ParsedRequest 的场景(如 OpenAI 协议兼容层)。 -// -// 与 buildOAuthMetadataUserID 的唯一区别: -// - session hash 从 body 本体按同样规则重算,而不是读取 ParsedRequest 缓存值。 -// - 如果 body 里已经存在 metadata.user_id,则返回空(由 ensureClaudeOAuthMetadataUserID -// 自行决定是否覆盖)。 -func (s *GatewayService) buildOAuthMetadataUserIDFromBody( - ctx context.Context, - account *Account, - fp *Fingerprint, - body []byte, -) string { - _ = ctx - if account == nil { - return "" - } - if existing := gjson.GetBytes(body, "metadata.user_id").String(); existing != "" { - return "" +// withRPMPrefetch 批量预取所有候选账号的 RPM 计数 +func (s *GatewayService) withRPMPrefetch(ctx context.Context, accounts []Account) context.Context { + if s.rpmCache == nil { + return ctx } - userID := strings.TrimSpace(account.GetClaudeUserID()) - if userID == "" && fp != nil { - userID = fp.ClientID - } - if userID == "" { - userID = generateClientID() + var ids []int64 + for i := range accounts { + if accounts[i].IsAnthropicOAuthOrSetupToken() && accounts[i].GetBaseRPM() > 0 { + ids = append(ids, 
accounts[i].ID) + } } - - sessionID := uuid.NewString() - if hash := hashBodyForSessionSeed(body); hash != "" { - sessionID = generateSessionUUID(fmt.Sprintf("%d::%s", account.ID, hash)) + if len(ids) == 0 { + return ctx } - var uaVersion string - if fp != nil { - uaVersion = ExtractCLIVersion(fp.UserAgent) + counts, err := s.rpmCache.GetRPMBatch(ctx, ids) + if err != nil { + return ctx // 失败开放 } - accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid")) - return FormatMetadataUserID(userID, accountUUID, sessionID, uaVersion) + return context.WithValue(ctx, rpmPrefetchContextKey, counts) } -// hashBodyForSessionSeed 为 sessionID 提供一个稳定但仅对本次请求特征化的种子。 -// 复用 SHA-256 + 截断,与 generateSessionUUID 的输入格式对齐。 -func hashBodyForSessionSeed(body []byte) string { - if len(body) == 0 { - return "" +// isAccountSchedulableForRPM 检查账号是否可根据 RPM 进行调度 +// 仅适用于 Anthropic OAuth/SetupToken 账号 +func (s *GatewayService) isAccountSchedulableForRPM(ctx context.Context, account *Account, isSticky bool) bool { + if !account.IsAnthropicOAuthOrSetupToken() { + return true + } + baseRPM := account.GetBaseRPM() + if baseRPM <= 0 { + return true } - sum := sha256.Sum256(body) - return fmt.Sprintf("%x", sum[:16]) -} -// GenerateSessionUUID creates a deterministic UUID4 from a seed string. 
-func GenerateSessionUUID(seed string) string { - return generateSessionUUID(seed) -} + // 尝试从预取缓存获取 + var currentRPM int + if count, ok := rpmFromPrefetchContext(ctx, account.ID); ok { + currentRPM = count + } else if s.rpmCache != nil { + if count, err := s.rpmCache.GetRPM(ctx, account.ID); err == nil { + currentRPM = count + } + // 失败开放:GetRPM 错误时允许调度 + } -func generateSessionUUID(seed string) string { - if seed == "" { - return uuid.NewString() + schedulability := account.CheckRPMSchedulability(currentRPM) + switch schedulability { + case WindowCostSchedulable: + return true + case WindowCostStickyOnly: + return isSticky + case WindowCostNotSchedulable: + return false } - hash := sha256.Sum256([]byte(seed)) - bytes := hash[:16] - bytes[6] = (bytes[6] & 0x0f) | 0x40 - bytes[8] = (bytes[8] & 0x3f) | 0x80 - return fmt.Sprintf("%x-%x-%x-%x-%x", - bytes[0:4], bytes[4:6], bytes[6:8], bytes[8:10], bytes[10:16]) + return true } -// SelectAccount 选择账号(粘性会话+优先级) -func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) { - return s.SelectAccountForModel(ctx, groupID, sessionHash, "") +// IncrementAccountRPM increments the RPM counter for the given account. +// 已知 TOCTOU 竞态:调度时读取 RPM 计数与此处递增之间存在时间窗口, +// 高并发下可能短暂超出 RPM 限制。这是与 WindowCost 一致的 soft-limit +// 设计权衡——可接受的少量超额优于加锁带来的延迟和复杂度。 +func (s *GatewayService) IncrementAccountRPM(ctx context.Context, accountID int64) error { + if s.rpmCache == nil { + return nil + } + _, err := s.rpmCache.IncrementRPM(ctx, accountID) + return err } -// SelectAccountForModel 选择支持指定模型的账号(粘性会话+优先级+模型映射) -func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) { - return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil) -} - -// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts. 
-func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { - // 优先检查 context 中的强制平台(/antigravity 路由) - var platform string - forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) - if hasForcePlatform && forcePlatform != "" { - platform = forcePlatform - } else if groupID != nil { - group, resolvedGroupID, err := s.resolveGatewayGroup(ctx, groupID) - if err != nil { - return nil, err - } - groupID = resolvedGroupID - ctx = s.withGroupContext(ctx, group) - platform = group.Platform - } else { - // 无分组时只使用原生 anthropic 平台 - platform = PlatformAnthropic +// checkAndRegisterSession 检查并注册会话,用于会话数量限制 +// 仅适用于 Anthropic OAuth/SetupToken 账号 +// sessionID: 会话标识符(使用粘性会话的 hash) +// 返回 true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话) +func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionID string) bool { + // 只检查 Anthropic OAuth/SetupToken 账号 + if !account.IsAnthropicOAuthOrSetupToken() { + return true } - // Claude Code 限制可能已将 groupID 解析为 fallback group, - // 渠道限制预检查必须使用解析后的分组。 - if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { - slog.Warn("channel pricing restriction blocked request", - "group_id", derefGroupID(groupID), - "model", requestedModel) - return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) + maxSessions := account.GetMaxSessions() + if maxSessions <= 0 || sessionID == "" { + return true // 未启用会话限制或无会话ID } - // anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户) - // 注意:强制平台模式不走混合调度 - if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform { - account, err := s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) - if err != nil { - return nil, err - } - return s.hydrateSelectedAccount(ctx, account) + 
if s.sessionLimitCache == nil { + return true // 缓存不可用时允许通过 } - // antigravity 分组、强制平台模式或无分组使用单平台选择 - // 注意:强制平台模式也必须遵守分组限制,不再回退到全平台查询 - account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) + idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute + + allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionID, maxSessions, idleTimeout) if err != nil { - return nil, err + // 失败开放:缓存错误时允许通过 + return true } - return s.hydrateSelectedAccount(ctx, account) + return allowed } -// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan. -// 调度流程文档见 docs/ACCOUNT_SCHEDULING_FLOW.md 。 -// metadataUserID: 用于客户端亲和调度,从中提取客户端 ID -// sub2apiUserID: 系统用户 ID,用于二维亲和调度 -func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string, sub2apiUserID int64) (*AccountSelectionResult, error) { - // 调试日志:记录调度入口参数 - excludedIDsList := make([]int64, 0, len(excludedIDs)) - for id := range excludedIDs { - excludedIDsList = append(excludedIDsList, id) - } - slog.Debug("account_scheduling_starting", - "group_id", derefGroupID(groupID), - "model", requestedModel, - "session", shortSessionHash(sessionHash), - "excluded_ids", excludedIDsList) - - cfg := s.schedulingConfig() +func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) { + if s.schedulerSnapshot != nil { + return s.schedulerSnapshot.GetAccount(ctx, accountID) + } + return s.accountRepo.GetByID(ctx, accountID) +} - // 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组) - group, groupID, err := s.checkClaudeCodeRestriction(ctx, groupID) +func (s *GatewayService) hydrateSelectedAccount(ctx context.Context, account *Account) (*Account, error) { + if account == nil || s.schedulerSnapshot == nil { + return account, nil + } + hydrated, err := 
s.schedulerSnapshot.GetAccount(ctx, account.ID) if err != nil { return nil, err } - ctx = s.withGroupContext(ctx, group) - - // Claude Code 限制可能已将 groupID 解析为 fallback group, - // 渠道限制预检查必须使用解析后的分组。 - if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { - slog.Warn("channel pricing restriction blocked request", - "group_id", derefGroupID(groupID), - "model", requestedModel) - return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) + if hydrated == nil { + return nil, fmt.Errorf("selected gateway account %d not found during hydration", account.ID) } + return hydrated, nil +} - var stickyAccountID int64 - var stickySource string - if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 { - stickyAccountID = prefetch - stickySource = "prefetch" - } else if sessionHash != "" && s.cache != nil { - if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil { - stickyAccountID = accountID - stickySource = "cache" - } +func (s *GatewayService) newSelectionResult(ctx context.Context, account *Account, acquired bool, release func(), waitPlan *AccountWaitPlan) (*AccountSelectionResult, error) { + hydrated, err := s.hydrateSelectedAccount(ctx, account) + if err != nil { + return nil, err } + return &AccountSelectionResult{ + Account: hydrated, + Acquired: acquired, + ReleaseFunc: release, + WaitPlan: waitPlan, + }, nil +} - // [DEBUG-STICKY] 调度器入口日志 - slog.Info("sticky.scheduler_entry", - "group_id", derefGroupID(groupID), - "session_hash", shortSessionHash(sessionHash), - "sticky_account_id", stickyAccountID, - "sticky_source", stickySource, - "model", requestedModel, - "load_batch", cfg.LoadBatchEnabled, - "has_concurrency_svc", s.concurrencyService != nil, - "excluded_count", len(excludedIDs), - ) - - if s.debugModelRoutingEnabled() && requestedModel != "" { - groupPlatform := "" - if group != nil { - groupPlatform = 
group.Platform - } - logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] select entry: group_id=%v group_platform=%s model=%s session=%s sticky_account=%d load_batch=%v concurrency=%v", - derefGroupID(groupID), groupPlatform, requestedModel, shortSessionHash(sessionHash), stickyAccountID, cfg.LoadBatchEnabled, s.concurrencyService != nil) +// filterByMinPriority 过滤出优先级最小的账号集合 +func filterByMinPriority(accounts []accountWithLoad) []accountWithLoad { + if len(accounts) == 0 { + return accounts } - - if s.concurrencyService == nil || !cfg.LoadBatchEnabled { - // 复制排除列表,用于会话限制拒绝时的重试 - localExcluded := make(map[int64]struct{}) - for k, v := range excludedIDs { - localExcluded[k] = v + minPriority := accounts[0].account.Priority + for _, acc := range accounts[1:] { + if acc.account.Priority < minPriority { + minPriority = acc.account.Priority } - - for { - account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, localExcluded) - if err != nil { - return nil, err - } - - result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) - if err == nil && result.Acquired { - // 获取槽位后检查会话限制(使用 sessionHash 作为会话标识符) - if !s.checkAndRegisterSession(ctx, account, sessionHash) { - result.ReleaseFunc() // 释放槽位 - localExcluded[account.ID] = struct{}{} // 排除此账号 - continue // 重新选择 - } - return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil) - } - - // 对于等待计划的情况,也需要先检查会话限制 - if !s.checkAndRegisterSession(ctx, account, sessionHash) { - localExcluded[account.ID] = struct{}{} - continue - } - - if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { - waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) - if waitingCount < cfg.StickySessionMaxWaiting { - return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: 
cfg.StickySessionMaxWaiting, - }) - } - } - return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }) + } + result := make([]accountWithLoad, 0, len(accounts)) + for _, acc := range accounts { + if acc.account.Priority == minPriority { + result = append(result, acc) } } + return result +} - platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID, group) - if err != nil { - return nil, err - } - preferOAuth := platform == PlatformGemini - if s.debugModelRoutingEnabled() && platform == PlatformAnthropic && requestedModel != "" { - logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform) +// filterByMinLoadRate 过滤出负载率最低的账号集合 +func filterByMinLoadRate(accounts []accountWithLoad) []accountWithLoad { + if len(accounts) == 0 { + return accounts } - - accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) - if err != nil { - return nil, err + minLoadRate := accounts[0].loadInfo.LoadRate + for _, acc := range accounts[1:] { + if acc.loadInfo.LoadRate < minLoadRate { + minLoadRate = acc.loadInfo.LoadRate + } } - if len(accounts) == 0 { - return nil, ErrNoAvailableAccounts + result := make([]accountWithLoad, 0, len(accounts)) + for _, acc := range accounts { + if acc.loadInfo.LoadRate == minLoadRate { + result = append(result, acc) + } } - ctx = s.withWindowCostPrefetch(ctx, accounts) - ctx = s.withRPMPrefetch(ctx, accounts) + return result +} - // 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用) - accountByID := make(map[int64]*Account, len(accounts)) - for i := range accounts { - accountByID[accounts[i].ID] = &accounts[i] +// selectByLRU 从集合中选择最久未用的账号 +// 如果有多个账号具有相同的最小 LastUsedAt,则随机选择一个 +func selectByLRU(accounts 
[]accountWithLoad, preferOAuth bool) *accountWithLoad { + if len(accounts) == 0 { + return nil } - isExcluded := func(accountID int64) bool { - if excludedIDs == nil { - return false - } - _, excluded := excludedIDs[accountID] - return excluded + if len(accounts) == 1 { + return &accounts[0] } - // 获取模型路由配置(仅 anthropic 平台) - var routingAccountIDs []int64 - if group != nil && requestedModel != "" && group.Platform == PlatformAnthropic { - routingAccountIDs = group.GetRoutingAccountIDs(requestedModel) - if s.debugModelRoutingEnabled() { - logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] context group routing: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v session=%s sticky_account=%d", - group.ID, requestedModel, group.ModelRoutingEnabled, len(group.ModelRouting), routingAccountIDs, shortSessionHash(sessionHash), stickyAccountID) - if len(routingAccountIDs) == 0 && group.ModelRoutingEnabled && len(group.ModelRouting) > 0 { - keys := make([]string, 0, len(group.ModelRouting)) - for k := range group.ModelRouting { - keys = append(keys, k) - } - sort.Strings(keys) - const maxKeys = 20 - if len(keys) > maxKeys { - keys = keys[:maxKeys] - } - logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] context group routing miss: group_id=%d model=%s patterns(sample)=%v", group.ID, requestedModel, keys) - } + // 1. 找到最小的 LastUsedAt(nil 被视为最小) + var minTime *time.Time + hasNil := false + for _, acc := range accounts { + if acc.account.LastUsedAt == nil { + hasNil = true + break + } + if minTime == nil || acc.account.LastUsedAt.Before(*minTime) { + minTime = acc.account.LastUsedAt } } - // ============ Layer 1: 模型路由优先选择(优先级高于粘性会话) ============ - if len(routingAccountIDs) > 0 && s.concurrencyService != nil { - // 1. 
过滤出路由列表中可调度的账号 - var routingCandidates []*Account - var filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost int - var modelScopeSkippedIDs []int64 // 记录因模型限流被跳过的账号 ID - for _, routingAccountID := range routingAccountIDs { - if isExcluded(routingAccountID) { - filteredExcluded++ - continue - } - account, ok := accountByID[routingAccountID] - if !ok || !s.isAccountSchedulableForSelection(account) { - if !ok { - filteredMissing++ - } else { - filteredUnsched++ - } - continue - } - if !s.isAccountAllowedForPlatform(account, platform, useMixed) { - filteredPlatform++ - continue - } - if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, account, requestedModel) { - filteredModelMapping++ - continue - } - if !s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) { - filteredModelScope++ - modelScopeSkippedIDs = append(modelScopeSkippedIDs, account.ID) - continue - } - // 配额检查 - if !s.isAccountSchedulableForQuota(account) { - continue - } - // 窗口费用检查(非粘性会话路径) - if !s.isAccountSchedulableForWindowCost(ctx, account, false) { - filteredWindowCost++ - continue + // 2. 
收集所有具有最小 LastUsedAt 的账号索引 + var candidateIdxs []int + for i, acc := range accounts { + if hasNil { + if acc.account.LastUsedAt == nil { + candidateIdxs = append(candidateIdxs, i) } - // RPM 检查(非粘性会话路径) - if !s.isAccountSchedulableForRPM(ctx, account, false) { - continue + } else { + if acc.account.LastUsedAt != nil && acc.account.LastUsedAt.Equal(*minTime) { + candidateIdxs = append(candidateIdxs, i) } - routingCandidates = append(routingCandidates, account) } + } - if s.debugModelRoutingEnabled() { - logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d window_cost=%d)", - derefGroupID(groupID), requestedModel, len(routingAccountIDs), len(routingCandidates), - filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost) - if len(modelScopeSkippedIDs) > 0 { - logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] model_rate_limited accounts skipped: group_id=%v model=%s account_ids=%v", - derefGroupID(groupID), requestedModel, modelScopeSkippedIDs) + // 3. 如果只有一个候选,直接返回 + if len(candidateIdxs) == 1 { + return &accounts[candidateIdxs[0]] + } + + // 4. 如果有多个候选且 preferOAuth,优先选择 OAuth 类型 + if preferOAuth { + var oauthIdxs []int + for _, idx := range candidateIdxs { + if accounts[idx].account.Type == AccountTypeOAuth { + oauthIdxs = append(oauthIdxs, idx) } } + if len(oauthIdxs) > 0 { + candidateIdxs = oauthIdxs + } + } - if len(routingCandidates) > 0 { - // 1.5. 
在路由账号范围内检查粘性会话 - if sessionHash != "" && stickyAccountID > 0 { - slog.Debug("sticky.layer1_5_checking", - "sticky_account_id", stickyAccountID, - "in_routing_list", containsInt64(routingAccountIDs, stickyAccountID), - "is_excluded", isExcluded(stickyAccountID), - "in_account_map", func() bool { _, ok := accountByID[stickyAccountID]; return ok }(), - "session", shortSessionHash(sessionHash), - ) - if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) { - // 粘性账号在路由列表中,优先使用 - if stickyAccount, ok := accountByID[stickyAccountID]; ok { - var stickyCacheMissReason string - - gatePass := s.isAccountSchedulableForSelection(stickyAccount) && - s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) && - (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) && - s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) && - s.isAccountSchedulableForQuota(stickyAccount) && - s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) - - rpmPass := gatePass && s.isAccountSchedulableForRPM(ctx, stickyAccount, true) - - if rpmPass { // 粘性会话窗口费用+RPM 检查 - result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency) - if err == nil && result.Acquired { - // 会话数量限制检查 - if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { - result.ReleaseFunc() // 释放槽位 - stickyCacheMissReason = "session_limit" - // 继续到负载感知选择 - } else { - slog.Debug("sticky.layer1_5_hit", - "account_id", stickyAccountID, - "session", shortSessionHash(sessionHash), - "result", "slot_acquired", - ) - if s.debugModelRoutingEnabled() { - logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID) - } - return s.newSelectionResult(ctx, stickyAccount, true, result.ReleaseFunc, nil) - } - } - - if stickyCacheMissReason == "" { - 
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID) - if waitingCount < cfg.StickySessionMaxWaiting { - // 会话数量限制检查(等待计划也需要占用会话配额) - if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { - stickyCacheMissReason = "session_limit" - // 会话限制已满,继续到负载感知选择 - } else { - return &AccountSelectionResult{ - Account: stickyAccount, - WaitPlan: &AccountWaitPlan{ - AccountID: stickyAccountID, - MaxConcurrency: stickyAccount.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil - } - } else { - stickyCacheMissReason = "wait_queue_full" - } - } - // 粘性账号槽位满且等待队列已满,继续使用负载感知选择 - } else if !gatePass { - stickyCacheMissReason = "gate_check" - } else { - stickyCacheMissReason = "rpm_red" - } - - // 记录粘性缓存未命中的结构化日志 - if stickyCacheMissReason != "" { - baseRPM := stickyAccount.GetBaseRPM() - var currentRPM int - if count, ok := rpmFromPrefetchContext(ctx, stickyAccount.ID); ok { - currentRPM = count - } - logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=%s account_id=%d session=%s current_rpm=%d base_rpm=%d", - stickyCacheMissReason, stickyAccountID, shortSessionHash(sessionHash), currentRPM, baseRPM) - } - } else { - _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) - logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=account_cleared account_id=%d session=%s current_rpm=0 base_rpm=0", - stickyAccountID, shortSessionHash(sessionHash)) - } - } - } + // 5. 随机选择一个 + selectedIdx := candidateIdxs[mathrand.Intn(len(candidateIdxs))] + return &accounts[selectedIdx] +} - // 2. 批量获取负载信息 - routingLoads := make([]AccountWithConcurrency, 0, len(routingCandidates)) - for _, acc := range routingCandidates { - routingLoads = append(routingLoads, AccountWithConcurrency{ - ID: acc.ID, - MaxConcurrency: acc.EffectiveLoadFactor(), - }) - } - routingLoadMap, _ := s.concurrencyService.GetAccountsLoadBatch(ctx, routingLoads) - - // 3. 
按负载感知排序 - var routingAvailable []accountWithLoad - for _, acc := range routingCandidates { - loadInfo := routingLoadMap[acc.ID] - if loadInfo == nil { - loadInfo = &AccountLoadInfo{AccountID: acc.ID} - } - if loadInfo.LoadRate < 100 { - routingAvailable = append(routingAvailable, accountWithLoad{account: acc, loadInfo: loadInfo}) - } - } +func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { + sort.SliceStable(accounts, func(i, j int) bool { + return isBetterAccountCandidate(accounts[i], accounts[j], preferOAuth, "") + }) + shuffleWithinPriorityAndLastUsed(accounts, preferOAuth) +} - if len(routingAvailable) > 0 { - // 排序:优先级 > 负载率 > 最后使用时间 - sort.SliceStable(routingAvailable, func(i, j int) bool { - a, b := routingAvailable[i], routingAvailable[j] - if a.account.Priority != b.account.Priority { - return a.account.Priority < b.account.Priority - } - if a.loadInfo.LoadRate != b.loadInfo.LoadRate { - return a.loadInfo.LoadRate < b.loadInfo.LoadRate - } - switch { - case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: - return true - case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil: - return false - case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil: - return false - default: - return a.account.LastUsedAt.Before(*b.account.LastUsedAt) - } - }) - shuffleWithinSortGroups(routingAvailable) - - // 4. 
尝试获取槽位 - for _, item := range routingAvailable { - result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) - if err == nil && result.Acquired { - // 会话数量限制检查 - if !s.checkAndRegisterSession(ctx, item.account, sessionHash) { - result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 - continue - } - if sessionHash != "" && s.cache != nil { - _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL) - } - if s.debugModelRoutingEnabled() { - logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) - } - return s.newSelectionResult(ctx, item.account, true, result.ReleaseFunc, nil) - } - } - - // 5. 所有路由账号槽位满,尝试返回等待计划(选择负载最低的) - // 遍历找到第一个满足会话限制的账号 - for _, item := range routingAvailable { - if !s.checkAndRegisterSession(ctx, item.account, sessionHash) { - continue // 会话限制已满,尝试下一个 - } - if s.debugModelRoutingEnabled() { - logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) - } - return s.newSelectionResult(ctx, item.account, false, nil, &AccountWaitPlan{ - AccountID: item.account.ID, - MaxConcurrency: item.account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }) - } - // 所有路由账号会话限制都已满,继续到 Layer 2 回退 - } - // 路由列表中的账号都不可用(负载率 >= 100),继续到 Layer 2 回退 - logger.LegacyPrintf("service.gateway", "[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", requestedModel) - } - } - - // ============ Layer 1.5: 粘性会话(仅在无模型路由配置时生效) ============ - if len(routingAccountIDs) == 0 && sessionHash != "" && stickyAccountID > 0 && !isExcluded(stickyAccountID) { - accountID := stickyAccountID - if accountID > 0 && !isExcluded(accountID) { - 
account, ok := accountByID[accountID] - if ok { - // 检查账户是否需要清理粘性会话绑定 - clearSticky := shouldClearStickySession(account, requestedModel) - if clearSticky { - slog.Debug("sticky.layer1_5_no_routing_clear", - "account_id", accountID, - "reason", "should_clear_sticky_session", - "session", shortSessionHash(sessionHash), - ) - _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) - } - - // 注意:不再检查 isAccountInGroup,因为 accountByID 已经从按分组过滤的 - // accounts 列表构建,账号一定在分组内。而 scheduler snapshot 缓存 - // 反序列化后 AccountGroups 字段为空,导致 isAccountInGroup 永远返回 false。 - platformOK := s.isAccountAllowedForPlatform(account, platform, useMixed) - modelSupported := requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel) - modelSchedulable := s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) - quotaOK := s.isAccountSchedulableForQuota(account) - windowCostOK := s.isAccountSchedulableForWindowCost(ctx, account, true) - rpmOK := s.isAccountSchedulableForRPM(ctx, account, true) - schedulable := s.isAccountSchedulableForSelection(account) - - slog.Debug("sticky.layer1_5_no_routing_checks", - "account_id", accountID, - "session", shortSessionHash(sessionHash), - "clear_sticky", clearSticky, - "schedulable", schedulable, - "platform_ok", platformOK, - "model_supported", modelSupported, - "model_schedulable", modelSchedulable, - "quota_ok", quotaOK, - "window_cost_ok", windowCostOK, - "rpm_ok", rpmOK, - ) - - if !clearSticky && platformOK && modelSupported && modelSchedulable && quotaOK && windowCostOK && rpmOK && schedulable { - result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) - if err == nil && result.Acquired { - // 会话数量限制检查 - if !s.checkAndRegisterSession(ctx, account, sessionHash) { - result.ReleaseFunc() // 释放槽位,继续到 Layer 2 - slog.Debug("sticky.layer1_5_no_routing_miss", - "account_id", accountID, - "reason", "session_limit", - "session", shortSessionHash(sessionHash), - ) - } else { - 
slog.Debug("sticky.layer1_5_no_routing_hit", - "account_id", accountID, - "session", shortSessionHash(sessionHash), - "result", "slot_acquired", - ) - if s.cache != nil { - _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) - } - return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil) - } - } else { - slog.Debug("sticky.layer1_5_no_routing_slot_busy", - "account_id", accountID, - "session", shortSessionHash(sessionHash), - ) - } - - waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) - if waitingCount < cfg.StickySessionMaxWaiting { - // 会话数量限制检查(等待计划也需要占用会话配额) - if !s.checkAndRegisterSession(ctx, account, sessionHash) { - // 会话限制已满,继续到 Layer 2 - } else { - slog.Debug("sticky.layer1_5_no_routing_hit", - "account_id", accountID, - "session", shortSessionHash(sessionHash), - "result", "wait_plan", - ) - return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ - AccountID: accountID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }) - } - } - } else if !clearSticky { - slog.Debug("sticky.layer1_5_no_routing_miss", - "account_id", accountID, - "reason", "gate_check_failed", - "session", shortSessionHash(sessionHash), - ) - } - } else { - slog.Debug("sticky.layer1_5_no_routing_miss", - "account_id", accountID, - "reason", "account_not_in_map", - "session", shortSessionHash(sessionHash), - ) - } - } - } else if len(routingAccountIDs) == 0 && sessionHash != "" { - slog.Debug("sticky.layer1_5_no_routing_skip", - "sticky_account_id", stickyAccountID, - "is_excluded", func() bool { return stickyAccountID > 0 && isExcluded(stickyAccountID) }(), - "session", shortSessionHash(sessionHash), - "reason", func() string { - if stickyAccountID == 0 { - return "no_sticky_binding" - } - return "sticky_account_excluded" - }(), - ) +// shuffleWithinSortGroups 对排序后的 accountWithLoad 切片,按 (Priority, LoadRate, 
LastUsedAt) 分组后组内随机打乱。 +// 防止并发请求读取同一快照时,确定性排序导致所有请求命中相同账号。 +func shuffleWithinSortGroups(accounts []accountWithLoad) { + if len(accounts) <= 1 { + return } - - // ============ Layer 2: 负载感知选择 ============ - slog.Debug("sticky.layer2_fallback", - "session", shortSessionHash(sessionHash), - "sticky_account_id", stickyAccountID, - "reason", "sticky_not_used_falling_back_to_load_balance", - "total_accounts", len(accounts), - ) - candidates := make([]*Account, 0, len(accounts)) - for i := range accounts { - acc := &accounts[i] - if isExcluded(acc.ID) { - continue - } - // Scheduler snapshots can be temporarily stale (bucket rebuild is throttled); - // re-check schedulability here so recently rate-limited/overloaded accounts - // are not selected again before the bucket is rebuilt. - if !s.isAccountSchedulableForSelection(acc) { - continue - } - if !s.isAccountAllowedForPlatform(acc, platform, useMixed) { - continue - } - if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { - continue - } - if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { - continue - } - // 配额检查 - if !s.isAccountSchedulableForQuota(acc) { - continue - } - // 窗口费用检查(非粘性会话路径) - if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { - continue + i := 0 + for i < len(accounts) { + j := i + 1 + for j < len(accounts) && sameAccountWithLoadGroup(accounts[i], accounts[j]) { + j++ } - // RPM 检查(非粘性会话路径) - if !s.isAccountSchedulableForRPM(ctx, acc, false) { - continue + if j-i > 1 { + mathrand.Shuffle(j-i, func(a, b int) { + accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a] + }) } - candidates = append(candidates, acc) + i = j } +} - if len(candidates) == 0 { - return nil, ErrNoAvailableAccounts +// sameAccountWithLoadGroup 判断两个 accountWithLoad 是否属于同一排序组 +func sameAccountWithLoadGroup(a, b accountWithLoad) bool { + if a.account.Priority != b.account.Priority { + return false } - - accountLoads := make([]AccountWithConcurrency, 0, 
len(candidates)) - for _, acc := range candidates { - accountLoads = append(accountLoads, AccountWithConcurrency{ - ID: acc.ID, - MaxConcurrency: acc.EffectiveLoadFactor(), - }) + if a.loadInfo.LoadRate != b.loadInfo.LoadRate { + return false } + return sameLastUsedAt(a.account.LastUsedAt, b.account.LastUsedAt) +} - loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) - if err != nil { - if result, ok, legacyErr := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); legacyErr != nil { - return nil, legacyErr - } else if ok { - return result, nil - } - } else { - var available []accountWithLoad - for _, acc := range candidates { - loadInfo := loadMap[acc.ID] - if loadInfo == nil { - loadInfo = &AccountLoadInfo{AccountID: acc.ID} - } - if loadInfo.LoadRate < 100 { - available = append(available, accountWithLoad{ - account: acc, - loadInfo: loadInfo, - }) - } +// shuffleWithinPriorityAndLastUsed 对排序后的 []*Account 切片,按 (Priority, LastUsedAt) 分组后组内随机打乱。 +// +// 注意:当 preferOAuth=true 时,需要保证 OAuth 账号在同组内仍然优先,否则会把排序时的偏好打散掉。 +// 因此这里采用"组内分区 + 分区内 shuffle"的方式: +// - 先把同组账号按 (OAuth / 非 OAuth) 拆成两段,保持 OAuth 段在前; +// - 再分别在各段内随机打散,避免热点。 +func shuffleWithinPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { + if len(accounts) <= 1 { + return + } + i := 0 + for i < len(accounts) { + j := i + 1 + for j < len(accounts) && sameAccountGroup(accounts[i], accounts[j]) { + j++ } - - // 分层过滤选择:优先级 → 负载率 → LRU - for len(available) > 0 { - // 1. 取优先级最小的集合 - candidates := filterByMinPriority(available) - // 2. 取负载率最低的集合 - candidates = filterByMinLoadRate(candidates) - // 3. 
LRU 选择最久未用的账号 - selected := selectByLRU(candidates, preferOAuth) - if selected == nil { - break - } - - result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency) - if err == nil && result.Acquired { - // 会话数量限制检查 - if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) { - result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 - } else { - if sessionHash != "" && s.cache != nil { - _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL) + if j-i > 1 { + if preferOAuth { + oauth := make([]*Account, 0, j-i) + others := make([]*Account, 0, j-i) + for _, acc := range accounts[i:j] { + if acc.Type == AccountTypeOAuth { + oauth = append(oauth, acc) + } else { + others = append(others, acc) } - return s.newSelectionResult(ctx, selected.account, true, result.ReleaseFunc, nil) } - } - - // 移除已尝试的账号,重新进行分层过滤 - selectedID := selected.account.ID - newAvailable := make([]accountWithLoad, 0, len(available)-1) - for _, acc := range available { - if acc.account.ID != selectedID { - newAvailable = append(newAvailable, acc) + if len(oauth) > 1 { + mathrand.Shuffle(len(oauth), func(a, b int) { oauth[a], oauth[b] = oauth[b], oauth[a] }) + } + if len(others) > 1 { + mathrand.Shuffle(len(others), func(a, b int) { others[a], others[b] = others[b], others[a] }) } + copy(accounts[i:], oauth) + copy(accounts[i+len(oauth):], others) + } else { + mathrand.Shuffle(j-i, func(a, b int) { + accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a] + }) } - available = newAvailable - } - } - - // ============ Layer 3: 兜底排队 ============ - s.sortCandidatesForFallback(candidates, preferOAuth, cfg.FallbackSelectionMode) - for _, acc := range candidates { - // 会话数量限制检查(等待计划也需要占用会话配额) - if !s.checkAndRegisterSession(ctx, acc, sessionHash) { - continue // 会话限制已满,尝试下一个账号 } - return s.newSelectionResult(ctx, acc, false, nil, &AccountWaitPlan{ - AccountID: acc.ID, - MaxConcurrency: acc.Concurrency, - Timeout: 
cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }) + i = j } - return nil, ErrNoAvailableAccounts } -func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool, error) { - ordered := append([]*Account(nil), candidates...) - sortAccountsByPriorityAndLastUsed(ordered, preferOAuth) - - for _, acc := range ordered { - result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) - if err == nil && result.Acquired { - // 会话数量限制检查 - if !s.checkAndRegisterSession(ctx, acc, sessionHash) { - result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 - continue - } - if sessionHash != "" && s.cache != nil { - _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL) - } - selection, err := s.newSelectionResult(ctx, acc, true, result.ReleaseFunc, nil) - if err != nil { - return nil, false, err - } - return selection, true, nil - } +// sameAccountGroup 判断两个 Account 是否属于同一排序组(Priority + LastUsedAt) +func sameAccountGroup(a, b *Account) bool { + if a.Priority != b.Priority { + return false } - - return nil, false, nil + return sameLastUsedAt(a.LastUsedAt, b.LastUsedAt) } -func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig { - if s.cfg != nil { - return s.cfg.Gateway.Scheduling - } - return config.GatewaySchedulingConfig{ - StickySessionMaxWaiting: 3, - StickySessionWaitTimeout: 45 * time.Second, - FallbackWaitTimeout: 30 * time.Second, - FallbackMaxWaiting: 100, - LoadBatchEnabled: true, - SlotCleanupInterval: 30 * time.Second, +// sameLastUsedAt 判断两个 LastUsedAt 是否相同(精度到秒) +func sameLastUsedAt(a, b *time.Time) bool { + switch { + case a == nil && b == nil: + return true + case a == nil || b == nil: + return false + default: + return a.Unix() == b.Unix() } } -func (s *GatewayService) withGroupContext(ctx context.Context, group *Group) context.Context { - if !IsGroupContextValid(group) { - return 
ctx - } - if existing, ok := ctx.Value(ctxkey.Group).(*Group); ok && existing != nil && existing.ID == group.ID && IsGroupContextValid(existing) { - return ctx +// sortCandidatesForFallback 根据配置选择排序策略 +// mode: "last_used"(按最后使用时间) 或 "random"(随机) +func (s *GatewayService) sortCandidatesForFallback(accounts []*Account, preferOAuth bool, mode string) { + if mode == "random" { + // 先按优先级排序,然后在同优先级内随机打乱 + sortAccountsByPriorityOnly(accounts, preferOAuth) + shuffleWithinPriority(accounts) + } else { + // 默认按最后使用时间排序 + sortAccountsByPriorityAndLastUsed(accounts, preferOAuth) } - return context.WithValue(ctx, ctxkey.Group, group) } -func (s *GatewayService) groupFromContext(ctx context.Context, groupID int64) *Group { - if group, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(group) && group.ID == groupID { - return group - } - return nil +// sortAccountsByPriorityOnly 仅按优先级排序 +func sortAccountsByPriorityOnly(accounts []*Account, preferOAuth bool) { + sort.SliceStable(accounts, func(i, j int) bool { + a, b := accounts[i], accounts[j] + if a.Priority != b.Priority { + return a.Priority < b.Priority + } + if preferOAuth && a.Type != b.Type { + return a.Type == AccountTypeOAuth + } + return false + }) } -func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (*Group, error) { - if group := s.groupFromContext(ctx, groupID); group != nil { - return group, nil +// shuffleWithinPriority 在同优先级内随机打乱顺序 +func shuffleWithinPriority(accounts []*Account) { + if len(accounts) <= 1 { + return } - group, err := s.groupRepo.GetByIDLite(ctx, groupID) - if err != nil { - return nil, fmt.Errorf("get group failed: %w", err) + r := mathrand.New(mathrand.NewSource(time.Now().UnixNano())) + start := 0 + for start < len(accounts) { + priority := accounts[start].Priority + end := start + 1 + for end < len(accounts) && accounts[end].Priority == priority { + end++ + } + // 对 [start, end) 范围内的账户随机打乱 + if end-start > 1 { + r.Shuffle(end-start, func(i, j int) { 
+ accounts[start+i], accounts[start+j] = accounts[start+j], accounts[start+i] + }) + } + start = end } - return group, nil } -func (s *GatewayService) ResolveGroupByID(ctx context.Context, groupID int64) (*Group, error) { - return s.resolveGroupByID(ctx, groupID) +type selectionFailureStats struct { + Total int + Eligible int + Excluded int + Unschedulable int + PlatformFiltered int + ModelUnsupported int + ModelRateLimited int + SamplePlatformIDs []int64 + SampleMappingIDs []int64 + SampleRateLimitIDs []string } -func (s *GatewayService) routingAccountIDsForRequest(ctx context.Context, groupID *int64, requestedModel string, platform string) []int64 { - if groupID == nil || requestedModel == "" || platform != PlatformAnthropic { - return nil - } - group, err := s.resolveGroupByID(ctx, *groupID) - if err != nil || group == nil { - if s.debugModelRoutingEnabled() { - logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] resolve group failed: group_id=%v model=%s platform=%s err=%v", derefGroupID(groupID), requestedModel, platform, err) - } - return nil +type selectionFailureDiagnosis struct { + Category string + Detail string +} + +func (s *GatewayService) logDetailedSelectionFailure( + ctx context.Context, + groupID *int64, + sessionHash string, + requestedModel string, + platform string, + accounts []Account, + excludedIDs map[int64]struct{}, + allowMixedScheduling bool, +) selectionFailureStats { + stats := s.collectSelectionFailureStats(ctx, accounts, requestedModel, platform, excludedIDs, allowMixedScheduling) + logger.LegacyPrintf( + "service.gateway", + "[SelectAccountDetailed] group_id=%v model=%s platform=%s session=%s total=%d eligible=%d excluded=%d unschedulable=%d platform_filtered=%d model_unsupported=%d model_rate_limited=%d sample_platform_filtered=%v sample_model_unsupported=%v sample_model_rate_limited=%v", + derefGroupID(groupID), + requestedModel, + platform, + shortSessionHash(sessionHash), + stats.Total, + stats.Eligible, + stats.Excluded, 
+ stats.Unschedulable, + stats.PlatformFiltered, + stats.ModelUnsupported, + stats.ModelRateLimited, + stats.SamplePlatformIDs, + stats.SampleMappingIDs, + stats.SampleRateLimitIDs, + ) + return stats +} + +func (s *GatewayService) collectSelectionFailureStats( + ctx context.Context, + accounts []Account, + requestedModel string, + platform string, + excludedIDs map[int64]struct{}, + allowMixedScheduling bool, +) selectionFailureStats { + stats := selectionFailureStats{ + Total: len(accounts), } - // Preserve existing behavior: model routing only applies to anthropic groups. - if group.Platform != PlatformAnthropic { - if s.debugModelRoutingEnabled() { - logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] skip: non-anthropic group platform: group_id=%d group_platform=%s model=%s", group.ID, group.Platform, requestedModel) + + for i := range accounts { + acc := &accounts[i] + diagnosis := s.diagnoseSelectionFailure(ctx, acc, requestedModel, platform, excludedIDs, allowMixedScheduling) + switch diagnosis.Category { + case "excluded": + stats.Excluded++ + case "unschedulable": + stats.Unschedulable++ + case "platform_filtered": + stats.PlatformFiltered++ + stats.SamplePlatformIDs = appendSelectionFailureSampleID(stats.SamplePlatformIDs, acc.ID) + case "model_unsupported": + stats.ModelUnsupported++ + stats.SampleMappingIDs = appendSelectionFailureSampleID(stats.SampleMappingIDs, acc.ID) + case "model_rate_limited": + stats.ModelRateLimited++ + remaining := acc.GetRateLimitRemainingTimeWithContext(ctx, requestedModel).Truncate(time.Second) + stats.SampleRateLimitIDs = appendSelectionFailureRateSample(stats.SampleRateLimitIDs, acc.ID, remaining) + default: + stats.Eligible++ } - return nil - } - ids := group.GetRoutingAccountIDs(requestedModel) - if s.debugModelRoutingEnabled() { - logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routing lookup: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v", - group.ID, requestedModel, 
group.ModelRoutingEnabled, len(group.ModelRouting), ids) } - return ids + + return stats } -func (s *GatewayService) resolveGatewayGroup(ctx context.Context, groupID *int64) (*Group, *int64, error) { - if groupID == nil { - return nil, nil, nil +func (s *GatewayService) diagnoseSelectionFailure( + ctx context.Context, + acc *Account, + requestedModel string, + platform string, + excludedIDs map[int64]struct{}, + allowMixedScheduling bool, +) selectionFailureDiagnosis { + if acc == nil { + return selectionFailureDiagnosis{Category: "unschedulable", Detail: "account_nil"} } - - currentID := *groupID - visited := map[int64]struct{}{} - for { - if _, seen := visited[currentID]; seen { - return nil, nil, fmt.Errorf("fallback group cycle detected") + if _, excluded := excludedIDs[acc.ID]; excluded { + return selectionFailureDiagnosis{Category: "excluded"} + } + if !s.isAccountSchedulableForSelection(acc) { + return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"} + } + if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) { + return selectionFailureDiagnosis{ + Category: "platform_filtered", + Detail: fmt.Sprintf("account_platform=%s requested_platform=%s", acc.Platform, strings.TrimSpace(platform)), } - visited[currentID] = struct{}{} - - group, err := s.resolveGroupByID(ctx, currentID) - if err != nil { - return nil, nil, err + } + if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { + return selectionFailureDiagnosis{ + Category: "model_unsupported", + Detail: fmt.Sprintf("model=%s", requestedModel), } - - if !group.ClaudeCodeOnly || IsClaudeCodeClient(ctx) { - return group, ¤tID, nil + } + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { + remaining := acc.GetRateLimitRemainingTimeWithContext(ctx, requestedModel).Truncate(time.Second) + return selectionFailureDiagnosis{ + Category: "model_rate_limited", + Detail: fmt.Sprintf("remaining=%s", 
remaining), } + } + return selectionFailureDiagnosis{Category: "eligible"} +} - if group.FallbackGroupID == nil { - return nil, nil, ErrClaudeCodeOnly +func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool { + if acc == nil { + return true + } + if allowMixedScheduling { + if acc.Platform == PlatformAntigravity { + return !acc.IsMixedSchedulingEnabled() } - currentID = *group.FallbackGroupID + return acc.Platform != platform + } + if strings.TrimSpace(platform) == "" { + return false } + return acc.Platform != platform } -// checkClaudeCodeRestriction 检查分组的 Claude Code 客户端限制 -// 如果分组启用了 claude_code_only 且请求不是来自 Claude Code 客户端: -// - 有降级分组:返回降级分组的 ID -// - 无降级分组:返回 ErrClaudeCodeOnly 错误 -func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID *int64) (*Group, *int64, error) { - if groupID == nil { - return nil, groupID, nil +func appendSelectionFailureSampleID(samples []int64, id int64) []int64 { + const limit = 5 + if len(samples) >= limit { + return samples } + return append(samples, id) +} - // 强制平台模式不检查 Claude Code 限制 - if forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform && forcePlatform != "" { - return nil, groupID, nil +func appendSelectionFailureRateSample(samples []string, accountID int64, remaining time.Duration) []string { + const limit = 5 + if len(samples) >= limit { + return samples } + return append(samples, fmt.Sprintf("%d(%s)", accountID, remaining)) +} - group, resolvedID, err := s.resolveGatewayGroup(ctx, groupID) - if err != nil { - return nil, nil, err - } +func summarizeSelectionFailureStats(stats selectionFailureStats) string { + return fmt.Sprintf( + "total=%d eligible=%d excluded=%d unschedulable=%d platform_filtered=%d model_unsupported=%d model_rate_limited=%d", + stats.Total, + stats.Eligible, + stats.Excluded, + stats.Unschedulable, + stats.PlatformFiltered, + stats.ModelUnsupported, + stats.ModelRateLimited, + ) +} - return 
group, resolvedID, nil +// isModelSupportedByAccountWithContext 根据账户平台检查模型支持(带 context) +// 对于 Antigravity 平台,会先获取映射后的最终模型名(包括 thinking 后缀)再检查支持 +func (s *GatewayService) isModelSupportedByAccountWithContext(ctx context.Context, account *Account, requestedModel string) bool { + if account.Platform == PlatformAntigravity { + if strings.TrimSpace(requestedModel) == "" { + return true + } + // 使用与转发阶段一致的映射逻辑:自定义映射优先 → 默认映射兜底 + mapped := mapAntigravityModel(account, requestedModel) + if mapped == "" { + return false + } + // 应用 thinking 后缀后检查最终模型是否在账号映射中 + if enabled, ok := ThinkingEnabledFromContext(ctx); ok { + finalModel := applyThinkingModelSuffix(mapped, enabled) + if finalModel == mapped { + return true // thinking 后缀未改变模型名,映射已通过 + } + return account.IsModelSupported(finalModel) + } + return true + } + return s.isModelSupportedByAccount(account, requestedModel) } -func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, group *Group) (string, bool, error) { - forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) - if hasForcePlatform && forcePlatform != "" { - return forcePlatform, true, nil +// isModelSupportedByAccount 根据账户平台检查模型支持(无 context,用于非 Antigravity 平台) +func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedModel string) bool { + if account.Platform == PlatformAntigravity { + if strings.TrimSpace(requestedModel) == "" { + return true + } + return mapAntigravityModel(account, requestedModel) != "" } - if group != nil { - return group.Platform, false, nil + if account.IsBedrock() { + _, ok := ResolveBedrockModelID(account, requestedModel) + return ok } - if groupID != nil { - group, err := s.resolveGroupByID(ctx, *groupID) - if err != nil { - return "", false, err + // OpenAI 透传模式:仅替换认证,允许所有模型 + if account.Platform == PlatformOpenAI && account.IsOpenAIPassthroughEnabled() { + return true + } + // OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID) + if account.Platform == PlatformAnthropic && 
account.Type != AccountTypeAPIKey { + if account.Type == AccountTypeServiceAccount { + requestedModel = normalizeVertexAnthropicModelID(claude.NormalizeModelID(requestedModel)) + } else { + requestedModel = claude.NormalizeModelID(requestedModel) } - return group.Platform, false, nil } - return PlatformAnthropic, false, nil + // 其他平台使用账户的模型支持检查 + return account.IsModelSupported(requestedModel) } -func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) { - if s.schedulerSnapshot != nil { - accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) - if err == nil { - slog.Debug("account_scheduling_list_snapshot", - "group_id", derefGroupID(groupID), - "platform", platform, - "use_mixed", useMixed, - "count", len(accounts)) - for _, acc := range accounts { - slog.Debug("account_scheduling_account_detail", - "account_id", acc.ID, - "name", acc.Name, - "platform", acc.Platform, - "type", acc.Type, - "status", acc.Status, - "tls_fingerprint", acc.IsTLSFingerprintEnabled()) - } +// GetAccessToken 获取账号凭证 +func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { + switch account.Type { + case AccountTypeOAuth, AccountTypeSetupToken: + // Both oauth and setup-token use OAuth token flow + return s.getOAuthToken(ctx, account) + case AccountTypeAPIKey: + apiKey := account.GetCredential("api_key") + if apiKey == "" { + return "", "", errors.New("api_key not found in credentials") } - return accounts, useMixed, err - } - useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform - if useMixed { - platforms := []string{platform, PlatformAntigravity} - var accounts []Account - var err error - if groupID != nil { - accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms) - } else if s.cfg != nil && s.cfg.RunMode == 
config.RunModeSimple { - accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) - } else { - accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatforms(ctx, platforms) + return apiKey, "apikey", nil + case AccountTypeBedrock: + return "", "bedrock", nil // Bedrock 使用 SigV4 签名或 API Key,由 forwardBedrock 处理 + case AccountTypeServiceAccount: + if account.Platform != PlatformAnthropic { + return "", "", fmt.Errorf("unsupported service account platform: %s", account.Platform) } - if err != nil { - slog.Debug("account_scheduling_list_failed", - "group_id", derefGroupID(groupID), - "platform", platform, - "error", err) - return nil, useMixed, err + if s.claudeTokenProvider == nil { + return "", "", errors.New("claude token provider not configured") } - filtered := make([]Account, 0, len(accounts)) - for _, acc := range accounts { - if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { - continue - } - filtered = append(filtered, acc) + accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account) + if err != nil { + return "", "", err } - slog.Debug("account_scheduling_list_mixed", - "group_id", derefGroupID(groupID), - "platform", platform, - "raw_count", len(accounts), - "filtered_count", len(filtered)) - for _, acc := range filtered { - slog.Debug("account_scheduling_account_detail", - "account_id", acc.ID, - "name", acc.Name, - "platform", acc.Platform, - "type", acc.Type, - "status", acc.Status, - "tls_fingerprint", acc.IsTLSFingerprintEnabled()) + return accessToken, "service_account", nil + default: + return "", "", fmt.Errorf("unsupported account type: %s", account.Type) + } +} + +func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (string, string, error) { + // 对于 Anthropic OAuth 账号,使用 ClaudeTokenProvider 获取缓存的 token + if account.Platform == PlatformAnthropic && account.Type == AccountTypeOAuth && s.claudeTokenProvider != nil { + accessToken, err := 
s.claudeTokenProvider.GetAccessToken(ctx, account) + if err != nil { + return "", "", err } - return filtered, useMixed, nil + return accessToken, "oauth", nil } - var accounts []Account - var err error - if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) - } else if groupID != nil { - accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform) - // 分组内无账号则返回空列表,由上层处理错误,不再回退到全平台查询 - } else { - accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, platform) + // 其他情况(Gemini 有自己的 TokenProvider,setup-token 类型等)直接从账号读取 + accessToken := account.GetCredential("access_token") + if accessToken == "" { + return "", "", errors.New("access_token not found in credentials") } - if err != nil { - slog.Debug("account_scheduling_list_failed", - "group_id", derefGroupID(groupID), - "platform", platform, - "error", err) - return nil, useMixed, err + // Token刷新由后台 TokenRefreshService 处理,此处只返回当前token + return accessToken, "oauth", nil +} + +// 重试相关常量 +const ( + // 最大尝试次数(包含首次请求)。过多重试会导致请求堆积与资源耗尽。 + maxRetryAttempts = 5 + + // 指数退避:第 N 次失败后的等待 = retryBaseDelay * 2^(N-1),并且上限为 retryMaxDelay。 + retryBaseDelay = 300 * time.Millisecond + retryMaxDelay = 3 * time.Second + + // 最大重试耗时(包含请求本身耗时 + 退避等待时间)。 + // 用于防止极端情况下 goroutine 长时间堆积导致资源耗尽。 + maxRetryElapsed = 10 * time.Second +) + +func (s *GatewayService) shouldRetryUpstreamError(account *Account, statusCode int) bool { + // OAuth/Setup Token 账号:仅 403 重试 + if account.IsOAuth() { + return statusCode == 403 } - slog.Debug("account_scheduling_list_single", - "group_id", derefGroupID(groupID), - "platform", platform, - "count", len(accounts)) - for _, acc := range accounts { - slog.Debug("account_scheduling_account_detail", - "account_id", acc.ID, - "name", acc.Name, - "platform", acc.Platform, - "type", acc.Type, - "status", acc.Status, - "tls_fingerprint", acc.IsTLSFingerprintEnabled()) + + // API Key 账号:未配置的错误码重试 
+ return !account.ShouldHandleErrorCode(statusCode) +} + +// shouldFailoverUpstreamError determines whether an upstream error should trigger account failover. +func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool { + switch statusCode { + case 401, 403, 429, 529: + return true + default: + return statusCode >= 500 } - return accounts, useMixed, nil } -// IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。 -// 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context, -// 避免单账号分组收到 503 时错误地设置模型限流标记导致后续请求连续快速失败。 -func (s *GatewayService) IsSingleAntigravityAccountGroup(ctx context.Context, groupID *int64) bool { - accounts, _, err := s.listSchedulableAccounts(ctx, groupID, PlatformAntigravity, true) - if err != nil { - return false +func retryBackoffDelay(attempt int) time.Duration { + // attempt 从 1 开始,表示第 attempt 次请求刚失败,需要等待后进行第 attempt+1 次请求。 + if attempt <= 0 { + return retryBaseDelay } - return len(accounts) == 1 + delay := retryBaseDelay * time.Duration(1<<(attempt-1)) + if delay > retryMaxDelay { + return retryMaxDelay + } + return delay } -func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool { - if account == nil { - return false +func sleepWithContext(ctx context.Context, d time.Duration) error { + if d <= 0 { + return nil } - if useMixed { - if account.Platform == platform { - return true + timer := time.NewTimer(d) + defer func() { + if !timer.Stop() { + select { + case <-timer.C: + default: + } } - return account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil } - return account.Platform == platform } -func (s *GatewayService) isAccountSchedulableForSelection(account *Account) bool { - if account == nil { +// isClaudeCodeClient 判断请求是否来自真正的 Claude Code 客户端。 +// 判定条件: +// 1. User-Agent 匹配 claude-cli/X.Y.Z(大小写不敏感) +// 2. 
metadata.user_id 符合 Claude Code 格式(legacy 或 JSON 格式) +// +// 只检查 metadata.user_id 非空不够严格:第三方工具(opencode 等)可能伪造 UA +// 并附带任意 metadata.user_id 字符串,从而绕过 mimicry。必须通过 ParseMetadataUserID +// 验证格式才能确认是真正的 Claude Code 客户端。 +func isClaudeCodeClient(userAgent string, metadataUserID string) bool { + if !claudeCliUserAgentRe.MatchString(userAgent) { return false } - return account.IsSchedulable() + return ParseMetadataUserID(metadataUserID) != nil } -func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Context, account *Account, requestedModel string) bool { - if account == nil { - return false +// normalizeSystemParam 将 json.RawMessage 类型的 system 参数转为标准 Go 类型(string / []any / nil), +// 避免 type switch 中 json.RawMessage(底层 []byte)无法匹配 case string / case []any / case nil 的问题。 +// 这是 Go 的 typed nil 陷阱:(json.RawMessage, nil) ≠ (nil, nil)。 +func normalizeSystemParam(system any) any { + raw, ok := system.(json.RawMessage) + if !ok { + return system } - return account.IsSchedulableForModelWithContext(ctx, requestedModel) + if len(raw) == 0 { + return nil + } + var parsed any + if err := json.Unmarshal(raw, &parsed); err != nil { + return nil + } + return parsed } -// isAccountInGroup checks if the account belongs to the specified group. -// When groupID is nil, returns true only for ungrouped accounts (no group assignments). 
-func (s *GatewayService) isAccountInGroup(account *Account, groupID *int64) bool { - if account == nil { - return false - } - if groupID == nil { - // 无分组的 API Key 只能使用未分组的账号 - return len(account.AccountGroups) == 0 +// systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词 +// 使用前缀匹配支持多种变体(标准版、Agent SDK 版等) +func systemIncludesClaudeCodePrompt(system any) bool { + system = normalizeSystemParam(system) + switch v := system.(type) { + case string: + return hasClaudeCodePrefix(v) + case []any: + for _, item := range v { + if m, ok := item.(map[string]any); ok { + if text, ok := m["text"].(string); ok && hasClaudeCodePrefix(text) { + return true + } + } + } } - for _, ag := range account.AccountGroups { - if ag.GroupID == *groupID { + return false +} + +// hasClaudeCodePrefix 检查文本是否以 Claude Code 提示词的特征前缀开头 +func hasClaudeCodePrefix(text string) bool { + for _, prefix := range claudeCodePromptPrefixes { + if strings.HasPrefix(text, prefix) { return true } } return false } -func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) { - if s.concurrencyService == nil { - return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil - } - return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) -} - -type usageLogWindowStatsBatchProvider interface { - GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error) -} - -type windowCostPrefetchContextKeyType struct{} - -var windowCostPrefetchContextKey = windowCostPrefetchContextKeyType{} - -func windowCostFromPrefetchContext(ctx context.Context, accountID int64) (float64, bool) { - if ctx == nil || accountID <= 0 { - return 0, false - } - m, ok := ctx.Value(windowCostPrefetchContextKey).(map[int64]float64) - if !ok || len(m) == 0 { - return 0, false - } - v, exists := m[accountID] - return v, exists -} - -func (s *GatewayService) 
withWindowCostPrefetch(ctx context.Context, accounts []Account) context.Context { - if ctx == nil || len(accounts) == 0 || s.sessionLimitCache == nil || s.usageLogRepo == nil { - return ctx - } - - accountByID := make(map[int64]*Account) - accountIDs := make([]int64, 0, len(accounts)) - for i := range accounts { - account := &accounts[i] - if account == nil || !account.IsAnthropicOAuthOrSetupToken() { - continue - } - if account.GetWindowCostLimit() <= 0 { - continue - } - accountByID[account.ID] = account - accountIDs = append(accountIDs, account.ID) - } - if len(accountIDs) == 0 { - return ctx - } - - costs := make(map[int64]float64, len(accountIDs)) - cacheValues, err := s.sessionLimitCache.GetWindowCostBatch(ctx, accountIDs) - if err == nil { - for accountID, cost := range cacheValues { - costs[accountID] = cost - } - windowCostPrefetchCacheHitTotal.Add(int64(len(cacheValues))) - } else { - windowCostPrefetchErrorTotal.Add(1) - logger.LegacyPrintf("service.gateway", "window_cost batch cache read failed: %v", err) - } - cacheMissCount := len(accountIDs) - len(costs) - if cacheMissCount < 0 { - cacheMissCount = 0 - } - windowCostPrefetchCacheMissTotal.Add(int64(cacheMissCount)) - - missingByStart := make(map[int64][]int64) - startTimes := make(map[int64]time.Time) - for _, accountID := range accountIDs { - if _, ok := costs[accountID]; ok { - continue - } - account := accountByID[accountID] - if account == nil { - continue - } - startTime := account.GetCurrentWindowStartTime() - startKey := startTime.Unix() - missingByStart[startKey] = append(missingByStart[startKey], accountID) - startTimes[startKey] = startTime - } - if len(missingByStart) == 0 { - return context.WithValue(ctx, windowCostPrefetchContextKey, costs) - } - - batchReader, hasBatch := s.usageLogRepo.(usageLogWindowStatsBatchProvider) - for startKey, ids := range missingByStart { - startTime := startTimes[startKey] - - if hasBatch { - windowCostPrefetchBatchSQLTotal.Add(1) - queryStart := time.Now() 
- statsByAccount, err := batchReader.GetAccountWindowStatsBatch(ctx, ids, startTime) - if err == nil { - slog.Debug("window_cost_batch_query_ok", - "accounts", len(ids), - "window_start", startTime.Format(time.RFC3339), - "duration_ms", time.Since(queryStart).Milliseconds()) - for _, accountID := range ids { - stats := statsByAccount[accountID] - cost := 0.0 - if stats != nil { - cost = stats.StandardCost - } - costs[accountID] = cost - _ = s.sessionLimitCache.SetWindowCost(ctx, accountID, cost) - } - continue - } - windowCostPrefetchErrorTotal.Add(1) - logger.LegacyPrintf("service.gateway", "window_cost batch db query failed: start=%s err=%v", startTime.Format(time.RFC3339), err) - } - - // 回退路径:缺少批量仓储能力或批量查询失败时,按账号单查(失败开放)。 - windowCostPrefetchFallbackTotal.Add(int64(len(ids))) - for _, accountID := range ids { - stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, accountID, startTime) - if err != nil { - windowCostPrefetchErrorTotal.Add(1) - continue - } - cost := stats.StandardCost - costs[accountID] = cost - _ = s.sessionLimitCache.SetWindowCost(ctx, accountID, cost) - } - } - - return context.WithValue(ctx, windowCostPrefetchContextKey, costs) -} - -// isAccountSchedulableForQuota 检查账号是否在配额限制内 -// 适用于配置了 quota_limit 的 apikey 和 bedrock 类型账号 -func (s *GatewayService) isAccountSchedulableForQuota(account *Account) bool { - if !account.IsAPIKeyOrBedrock() { - return true - } - return !account.IsQuotaExceeded() -} - -// isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度 -// 仅适用于 Anthropic OAuth/SetupToken 账号 -// 返回 true 表示可调度,false 表示不可调度 -func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context, account *Account, isSticky bool) bool { - // 只检查 Anthropic OAuth/SetupToken 账号 - if !account.IsAnthropicOAuthOrSetupToken() { - return true - } - - limit := account.GetWindowCostLimit() - if limit <= 0 { - return true // 未启用窗口费用限制 - } - - // 尝试从缓存获取窗口费用 - var currentCost float64 - if cost, ok := windowCostFromPrefetchContext(ctx, account.ID); ok 
{ - currentCost = cost - goto checkSchedulability - } - if s.sessionLimitCache != nil { - if cost, hit, err := s.sessionLimitCache.GetWindowCost(ctx, account.ID); err == nil && hit { - currentCost = cost - goto checkSchedulability - } - } - - // 缓存未命中,从数据库查询 - { - // 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况) - startTime := account.GetCurrentWindowStartTime() - - stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime) - if err != nil { - // 失败开放:查询失败时允许调度 - return true - } - - // 使用标准费用(不含账号倍率) - currentCost = stats.StandardCost - - // 设置缓存(忽略错误) - if s.sessionLimitCache != nil { - _ = s.sessionLimitCache.SetWindowCost(ctx, account.ID, currentCost) - } - } - -checkSchedulability: - schedulability := account.CheckWindowCostSchedulability(currentCost) - - switch schedulability { - case WindowCostSchedulable: - return true - case WindowCostStickyOnly: - return isSticky - case WindowCostNotSchedulable: - return false - } - return true -} - -// rpmPrefetchContextKey is the context key for prefetched RPM counts. 
-type rpmPrefetchContextKeyType struct{} - -var rpmPrefetchContextKey = rpmPrefetchContextKeyType{} - -func rpmFromPrefetchContext(ctx context.Context, accountID int64) (int, bool) { - if v, ok := ctx.Value(rpmPrefetchContextKey).(map[int64]int); ok { - count, found := v[accountID] - return count, found - } - return 0, false -} - -// withRPMPrefetch 批量预取所有候选账号的 RPM 计数 -func (s *GatewayService) withRPMPrefetch(ctx context.Context, accounts []Account) context.Context { - if s.rpmCache == nil { - return ctx - } - - var ids []int64 - for i := range accounts { - if accounts[i].IsAnthropicOAuthOrSetupToken() && accounts[i].GetBaseRPM() > 0 { - ids = append(ids, accounts[i].ID) - } - } - if len(ids) == 0 { - return ctx - } - - counts, err := s.rpmCache.GetRPMBatch(ctx, ids) - if err != nil { - return ctx // 失败开放 - } - return context.WithValue(ctx, rpmPrefetchContextKey, counts) -} - -// isAccountSchedulableForRPM 检查账号是否可根据 RPM 进行调度 -// 仅适用于 Anthropic OAuth/SetupToken 账号 -func (s *GatewayService) isAccountSchedulableForRPM(ctx context.Context, account *Account, isSticky bool) bool { - if !account.IsAnthropicOAuthOrSetupToken() { - return true - } - baseRPM := account.GetBaseRPM() - if baseRPM <= 0 { - return true - } - - // 尝试从预取缓存获取 - var currentRPM int - if count, ok := rpmFromPrefetchContext(ctx, account.ID); ok { - currentRPM = count - } else if s.rpmCache != nil { - if count, err := s.rpmCache.GetRPM(ctx, account.ID); err == nil { - currentRPM = count - } - // 失败开放:GetRPM 错误时允许调度 - } - - schedulability := account.CheckRPMSchedulability(currentRPM) - switch schedulability { - case WindowCostSchedulable: - return true - case WindowCostStickyOnly: - return isSticky - case WindowCostNotSchedulable: - return false - } - return true -} - -// IncrementAccountRPM increments the RPM counter for the given account. 
-// 已知 TOCTOU 竞态:调度时读取 RPM 计数与此处递增之间存在时间窗口, -// 高并发下可能短暂超出 RPM 限制。这是与 WindowCost 一致的 soft-limit -// 设计权衡——可接受的少量超额优于加锁带来的延迟和复杂度。 -func (s *GatewayService) IncrementAccountRPM(ctx context.Context, accountID int64) error { - if s.rpmCache == nil { - return nil - } - _, err := s.rpmCache.IncrementRPM(ctx, accountID) - return err -} - -// checkAndRegisterSession 检查并注册会话,用于会话数量限制 -// 仅适用于 Anthropic OAuth/SetupToken 账号 -// sessionID: 会话标识符(使用粘性会话的 hash) -// 返回 true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话) -func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionID string) bool { - // 只检查 Anthropic OAuth/SetupToken 账号 - if !account.IsAnthropicOAuthOrSetupToken() { - return true - } - - maxSessions := account.GetMaxSessions() - if maxSessions <= 0 || sessionID == "" { - return true // 未启用会话限制或无会话ID - } - - if s.sessionLimitCache == nil { - return true // 缓存不可用时允许通过 - } - - idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute - - allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionID, maxSessions, idleTimeout) - if err != nil { - // 失败开放:缓存错误时允许通过 - return true - } - return allowed -} - -func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) { - if s.schedulerSnapshot != nil { - return s.schedulerSnapshot.GetAccount(ctx, accountID) - } - return s.accountRepo.GetByID(ctx, accountID) -} - -func (s *GatewayService) hydrateSelectedAccount(ctx context.Context, account *Account) (*Account, error) { - if account == nil || s.schedulerSnapshot == nil { - return account, nil - } - hydrated, err := s.schedulerSnapshot.GetAccount(ctx, account.ID) - if err != nil { - return nil, err - } - if hydrated == nil { - return nil, fmt.Errorf("selected gateway account %d not found during hydration", account.ID) - } - return hydrated, nil -} - -func (s *GatewayService) newSelectionResult(ctx context.Context, account *Account, acquired bool, release func(), 
waitPlan *AccountWaitPlan) (*AccountSelectionResult, error) { - hydrated, err := s.hydrateSelectedAccount(ctx, account) - if err != nil { - return nil, err - } - return &AccountSelectionResult{ - Account: hydrated, - Acquired: acquired, - ReleaseFunc: release, - WaitPlan: waitPlan, - }, nil -} - -// filterByMinPriority 过滤出优先级最小的账号集合 -func filterByMinPriority(accounts []accountWithLoad) []accountWithLoad { - if len(accounts) == 0 { - return accounts - } - minPriority := accounts[0].account.Priority - for _, acc := range accounts[1:] { - if acc.account.Priority < minPriority { - minPriority = acc.account.Priority - } - } - result := make([]accountWithLoad, 0, len(accounts)) - for _, acc := range accounts { - if acc.account.Priority == minPriority { - result = append(result, acc) - } - } - return result -} - -// filterByMinLoadRate 过滤出负载率最低的账号集合 -func filterByMinLoadRate(accounts []accountWithLoad) []accountWithLoad { - if len(accounts) == 0 { - return accounts - } - minLoadRate := accounts[0].loadInfo.LoadRate - for _, acc := range accounts[1:] { - if acc.loadInfo.LoadRate < minLoadRate { - minLoadRate = acc.loadInfo.LoadRate - } - } - result := make([]accountWithLoad, 0, len(accounts)) - for _, acc := range accounts { - if acc.loadInfo.LoadRate == minLoadRate { - result = append(result, acc) - } - } - return result -} - -// selectByLRU 从集合中选择最久未用的账号 -// 如果有多个账号具有相同的最小 LastUsedAt,则随机选择一个 -func selectByLRU(accounts []accountWithLoad, preferOAuth bool) *accountWithLoad { - if len(accounts) == 0 { - return nil - } - if len(accounts) == 1 { - return &accounts[0] - } - - // 1. 找到最小的 LastUsedAt(nil 被视为最小) - var minTime *time.Time - hasNil := false - for _, acc := range accounts { - if acc.account.LastUsedAt == nil { - hasNil = true - break - } - if minTime == nil || acc.account.LastUsedAt.Before(*minTime) { - minTime = acc.account.LastUsedAt - } - } - - // 2. 
收集所有具有最小 LastUsedAt 的账号索引 - var candidateIdxs []int - for i, acc := range accounts { - if hasNil { - if acc.account.LastUsedAt == nil { - candidateIdxs = append(candidateIdxs, i) - } - } else { - if acc.account.LastUsedAt != nil && acc.account.LastUsedAt.Equal(*minTime) { - candidateIdxs = append(candidateIdxs, i) - } - } - } - - // 3. 如果只有一个候选,直接返回 - if len(candidateIdxs) == 1 { - return &accounts[candidateIdxs[0]] - } - - // 4. 如果有多个候选且 preferOAuth,优先选择 OAuth 类型 - if preferOAuth { - var oauthIdxs []int - for _, idx := range candidateIdxs { - if accounts[idx].account.Type == AccountTypeOAuth { - oauthIdxs = append(oauthIdxs, idx) - } - } - if len(oauthIdxs) > 0 { - candidateIdxs = oauthIdxs - } - } - - // 5. 随机选择一个 - selectedIdx := candidateIdxs[mathrand.Intn(len(candidateIdxs))] - return &accounts[selectedIdx] -} - -func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { - sort.SliceStable(accounts, func(i, j int) bool { - a, b := accounts[i], accounts[j] - if a.Priority != b.Priority { - return a.Priority < b.Priority - } - switch { - case a.LastUsedAt == nil && b.LastUsedAt != nil: - return true - case a.LastUsedAt != nil && b.LastUsedAt == nil: - return false - case a.LastUsedAt == nil && b.LastUsedAt == nil: - if preferOAuth && a.Type != b.Type { - return a.Type == AccountTypeOAuth - } - return false - default: - return a.LastUsedAt.Before(*b.LastUsedAt) - } - }) - shuffleWithinPriorityAndLastUsed(accounts, preferOAuth) -} - -// shuffleWithinSortGroups 对排序后的 accountWithLoad 切片,按 (Priority, LoadRate, LastUsedAt) 分组后组内随机打乱。 -// 防止并发请求读取同一快照时,确定性排序导致所有请求命中相同账号。 -func shuffleWithinSortGroups(accounts []accountWithLoad) { - if len(accounts) <= 1 { - return - } - i := 0 - for i < len(accounts) { - j := i + 1 - for j < len(accounts) && sameAccountWithLoadGroup(accounts[i], accounts[j]) { - j++ - } - if j-i > 1 { - mathrand.Shuffle(j-i, func(a, b int) { - accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a] - }) - } - i = j - } -} - -// 
sameAccountWithLoadGroup 判断两个 accountWithLoad 是否属于同一排序组 -func sameAccountWithLoadGroup(a, b accountWithLoad) bool { - if a.account.Priority != b.account.Priority { - return false - } - if a.loadInfo.LoadRate != b.loadInfo.LoadRate { - return false - } - return sameLastUsedAt(a.account.LastUsedAt, b.account.LastUsedAt) -} - -// shuffleWithinPriorityAndLastUsed 对排序后的 []*Account 切片,按 (Priority, LastUsedAt) 分组后组内随机打乱。 -// -// 注意:当 preferOAuth=true 时,需要保证 OAuth 账号在同组内仍然优先,否则会把排序时的偏好打散掉。 -// 因此这里采用"组内分区 + 分区内 shuffle"的方式: -// - 先把同组账号按 (OAuth / 非 OAuth) 拆成两段,保持 OAuth 段在前; -// - 再分别在各段内随机打散,避免热点。 -func shuffleWithinPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { - if len(accounts) <= 1 { - return - } - i := 0 - for i < len(accounts) { - j := i + 1 - for j < len(accounts) && sameAccountGroup(accounts[i], accounts[j]) { - j++ - } - if j-i > 1 { - if preferOAuth { - oauth := make([]*Account, 0, j-i) - others := make([]*Account, 0, j-i) - for _, acc := range accounts[i:j] { - if acc.Type == AccountTypeOAuth { - oauth = append(oauth, acc) - } else { - others = append(others, acc) - } - } - if len(oauth) > 1 { - mathrand.Shuffle(len(oauth), func(a, b int) { oauth[a], oauth[b] = oauth[b], oauth[a] }) - } - if len(others) > 1 { - mathrand.Shuffle(len(others), func(a, b int) { others[a], others[b] = others[b], others[a] }) - } - copy(accounts[i:], oauth) - copy(accounts[i+len(oauth):], others) - } else { - mathrand.Shuffle(j-i, func(a, b int) { - accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a] - }) - } - } - i = j - } -} - -// sameAccountGroup 判断两个 Account 是否属于同一排序组(Priority + LastUsedAt) -func sameAccountGroup(a, b *Account) bool { - if a.Priority != b.Priority { - return false - } - return sameLastUsedAt(a.LastUsedAt, b.LastUsedAt) -} - -// sameLastUsedAt 判断两个 LastUsedAt 是否相同(精度到秒) -func sameLastUsedAt(a, b *time.Time) bool { - switch { - case a == nil && b == nil: - return true - case a == nil || b == nil: - return false - default: - return a.Unix() == 
b.Unix() - } -} - -// sortCandidatesForFallback 根据配置选择排序策略 -// mode: "last_used"(按最后使用时间) 或 "random"(随机) -func (s *GatewayService) sortCandidatesForFallback(accounts []*Account, preferOAuth bool, mode string) { - if mode == "random" { - // 先按优先级排序,然后在同优先级内随机打乱 - sortAccountsByPriorityOnly(accounts, preferOAuth) - shuffleWithinPriority(accounts) - } else { - // 默认按最后使用时间排序 - sortAccountsByPriorityAndLastUsed(accounts, preferOAuth) - } -} - -// sortAccountsByPriorityOnly 仅按优先级排序 -func sortAccountsByPriorityOnly(accounts []*Account, preferOAuth bool) { - sort.SliceStable(accounts, func(i, j int) bool { - a, b := accounts[i], accounts[j] - if a.Priority != b.Priority { - return a.Priority < b.Priority - } - if preferOAuth && a.Type != b.Type { - return a.Type == AccountTypeOAuth - } - return false - }) -} - -// shuffleWithinPriority 在同优先级内随机打乱顺序 -func shuffleWithinPriority(accounts []*Account) { - if len(accounts) <= 1 { - return - } - r := mathrand.New(mathrand.NewSource(time.Now().UnixNano())) - start := 0 - for start < len(accounts) { - priority := accounts[start].Priority - end := start + 1 - for end < len(accounts) && accounts[end].Priority == priority { - end++ - } - // 对 [start, end) 范围内的账户随机打乱 - if end-start > 1 { - r.Shuffle(end-start, func(i, j int) { - accounts[start+i], accounts[start+j] = accounts[start+j], accounts[start+i] - }) - } - start = end - } -} - -// selectAccountForModelWithPlatform 选择单平台账户(完全隔离) -func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) { - preferOAuth := platform == PlatformGemini - routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform) - - // require_privacy_set: 获取分组信息 - var schedGroup *Group - if groupID != nil && s.groupRepo != nil { - schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID) - } - - var accounts []Account - accountsLoaded := false 
	// ============ Model Routing (legacy path): apply before sticky session ============
	// When load-awareness is disabled (e.g. concurrency service not configured), we still honor model routing
	// so switching model can switch upstream account within the same sticky session.
	if len(routingAccountIDs) > 0 {
		if s.debugModelRoutingEnabled() {
			logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v",
				derefGroupID(groupID), requestedModel, platform, shortSessionHash(sessionHash), routingAccountIDs)
		}
		// 1) Sticky session only applies if the bound account is within the routing set.
		if sessionHash != "" && s.cache != nil {
			accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
			if err == nil && accountID > 0 && containsInt64(routingAccountIDs, accountID) {
				if _, excluded := excludedIDs[accountID]; !excluded {
					account, err := s.getSchedulableAccount(ctx, accountID)
					// Verify group membership and platform match (sticky sessions must not
					// cross groups or platforms).
					if err == nil {
						clearSticky := shouldClearStickySession(account, requestedModel)
						if clearSticky {
							_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
						}
						if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) && !s.isStickyAccountUpstreamRestricted(ctx, groupID, account, requestedModel) {
							if s.debugModelRoutingEnabled() {
								logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
							}
							return account, nil
						}
					}
				}
			}
		}

		// 2) Select an account from the routed candidates.
		forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
		if hasForcePlatform && forcePlatform == "" {
			hasForcePlatform = false
		}
		var err error
		accounts, _, err = s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
		if err != nil {
			return nil, fmt.Errorf("query accounts failed: %w", err)
		}
		accountsLoaded = true

		// Prefetch window-cost + RPM counters upfront so the scheduling checks inside the
		// routing section hit the cache instead of querying per account.
		ctx = s.withWindowCostPrefetch(ctx, accounts)
		ctx = s.withRPMPrefetch(ctx, accounts)

		routingSet := make(map[int64]struct{}, len(routingAccountIDs))
		for _, id := range routingAccountIDs {
			if id > 0 {
				routingSet[id] = struct{}{}
			}
		}

		var selected *Account
		for i := range accounts {
			acc := &accounts[i]
			if _, ok := routingSet[acc.ID]; !ok {
				continue
			}
			if _, excluded := excludedIDs[acc.ID]; excluded {
				continue
			}
			// Scheduler snapshots can be temporarily stale; re-check schedulability here to
			// avoid selecting accounts that were recently rate-limited/overloaded.
			if !s.isAccountSchedulableForSelection(acc) {
				continue
			}
			// require_privacy_set: skip accounts whose privacy is not set and mark them errored.
			if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
				_ = s.accountRepo.SetError(ctx, acc.ID,
					fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
				continue
			}
			if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
				continue
			}
			if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
				continue
			}
			if !s.isAccountSchedulableForQuota(acc) {
				continue
			}
			if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
				continue
			}
			if !s.isAccountSchedulableForRPM(ctx, acc, false) {
				continue
			}
			if selected == nil {
				selected = acc
				continue
			}
			// Lower Priority value wins; ties broken by least-recently-used, with
			// never-used accounts preferred (and OAuth preferred among never-used when
			// preferOAuth is set).
			if acc.Priority < selected.Priority {
				selected = acc
			} else if acc.Priority == selected.Priority {
				switch {
				case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
					selected = acc
				case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
					// keep selected (never used is preferred)
				case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
					if preferOAuth && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
						selected = acc
					}
				default:
					if acc.LastUsedAt.Before(*selected.LastUsedAt) {
						selected = acc
					}
				}
			}
		}

		if selected != nil {
			if sessionHash != "" && s.cache != nil {
				if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
					logger.LegacyPrintf("service.gateway", "set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
				}
			}
			if s.debugModelRoutingEnabled() {
				logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID)
			}
			return selected, nil
		}
		logger.LegacyPrintf("service.gateway",
			"[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel)
	}

	// 1. Look up the sticky session.
	if sessionHash != "" && s.cache != nil {
		accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
		if err == nil && accountID > 0 {
			if _, excluded := excludedIDs[accountID]; !excluded {
				account, err := s.getSchedulableAccount(ctx, accountID)
				// Verify group membership and platform match (sticky sessions must not
				// cross groups or platforms).
				if err == nil {
					clearSticky := shouldClearStickySession(account, requestedModel)
					if clearSticky {
						_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
					}
					if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
						return account, nil
					}
				}
			}
		}
	}

	// 2. Load the schedulable account list (single platform).
	if !accountsLoaded {
		forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
		if hasForcePlatform && forcePlatform == "" {
			hasForcePlatform = false
		}
		var err error
		accounts, _, err = s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
		if err != nil {
			return nil, fmt.Errorf("query accounts failed: %w", err)
		}
	}

	// Batch-prefetch window-cost + RPM counters to avoid per-account queries (N+1).
	ctx = s.withWindowCostPrefetch(ctx, accounts)
	ctx = s.withRPMPrefetch(ctx, accounts)

	// 3. Pick by priority + least-recently-used (considering model support).
	// needsUpstreamCheck is only used in the main selection loop; sticky-session hits skip
	// this check because sticky sessions prioritize connection consistency and the upstream
	// billing baseline is rarely used.
	needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
	var selected *Account
	for i := range accounts {
		acc := &accounts[i]
		if _, excluded := excludedIDs[acc.ID]; excluded {
			continue
		}
		// Scheduler snapshots can be temporarily stale; re-check schedulability here to
		// avoid selecting accounts that were recently rate-limited/overloaded.
		if !s.isAccountSchedulableForSelection(acc) {
			continue
		}
		// require_privacy_set: skip accounts whose privacy is not set and mark them errored.
		if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
			_ = s.accountRepo.SetError(ctx, acc.ID,
				fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
			continue
		}
		if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
			continue
		}
		if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) {
			continue
		}
		if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
			continue
		}
		if !s.isAccountSchedulableForQuota(acc) {
			continue
		}
		if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
			continue
		}
		if !s.isAccountSchedulableForRPM(ctx, acc, false) {
			continue
		}
		if selected == nil {
			selected = acc
			continue
		}
		if acc.Priority < selected.Priority {
			selected = acc
		} else if acc.Priority == selected.Priority {
			switch {
			case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
				selected = acc
			case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
				// keep selected (never used is preferred)
			case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
				if preferOAuth && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
					selected = acc
				}
			default:
				if acc.LastUsedAt.Before(*selected.LastUsedAt) {
					selected = acc
				}
			}
		}
	}

	if selected == nil {
		stats :=
s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, platform, accounts, excludedIDs, false)
		if requestedModel != "" {
			return nil, fmt.Errorf("%w supporting model: %s (%s)", ErrNoAvailableAccounts, requestedModel, summarizeSelectionFailureStats(stats))
		}
		return nil, ErrNoAvailableAccounts
	}

	// 4. Establish the sticky binding.
	if sessionHash != "" && s.cache != nil {
		if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
			logger.LegacyPrintf("service.gateway", "set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
		}
	}

	return selected, nil
}

// selectAccountWithMixedScheduling selects an account with mixed-scheduling support:
// it considers native-platform accounts plus antigravity accounts that have
// mixed_scheduling enabled.
func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) {
	preferOAuth := nativePlatform == PlatformGemini
	routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform)

	// require_privacy_set: fetch the group configuration.
	var schedGroup *Group
	if groupID != nil && s.groupRepo != nil {
		schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID)
	}

	var accounts []Account
	accountsLoaded := false

	// ============ Model Routing (legacy path): apply before sticky session ============
	if len(routingAccountIDs) > 0 {
		if s.debugModelRoutingEnabled() {
			logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v",
				derefGroupID(groupID), requestedModel, nativePlatform, shortSessionHash(sessionHash), routingAccountIDs)
		}
		// 1) Sticky session only applies if the bound account is within the routing set.
		if sessionHash != "" && s.cache != nil {
			accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
			if err == nil && accountID > 0 && containsInt64(routingAccountIDs, accountID) {
				if _, excluded := excludedIDs[accountID]; !excluded {
					account, err := s.getSchedulableAccount(ctx, accountID)
					// Verify group membership and validity: a native-platform account matches
					// directly; an antigravity account requires mixed scheduling enabled.
					// NOTE(review): unlike the non-routed sticky path below, this path does
					// not call isStickyAccountUpstreamRestricted — confirm the asymmetry is
					// intentional.
					if err == nil {
						clearSticky := shouldClearStickySession(account, requestedModel)
						if clearSticky {
							_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
						}
						if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
							if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
								if s.debugModelRoutingEnabled() {
									logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
								}
								return account, nil
							}
						}
					}
				}
			}
		}

		// 2) Select an account from the routed candidates.
		var err error
		accounts, _, err = s.listSchedulableAccounts(ctx, groupID, nativePlatform, false)
		if err != nil {
			return nil, fmt.Errorf("query accounts failed: %w", err)
		}
		accountsLoaded = true

		// Prefetch window-cost + RPM counters upfront so the scheduling checks inside the
		// routing section hit the cache instead of querying per account.
		ctx = s.withWindowCostPrefetch(ctx, accounts)
		ctx = s.withRPMPrefetch(ctx, accounts)

		routingSet := make(map[int64]struct{}, len(routingAccountIDs))
		for _, id := range routingAccountIDs {
			if id > 0 {
				routingSet[id] = struct{}{}
			}
		}

		var selected *Account
		for i := range accounts {
			acc := &accounts[i]
			if _, ok := routingSet[acc.ID]; !ok {
				continue
			}
			if _, excluded := excludedIDs[acc.ID]; excluded {
				continue
			}
			// Scheduler snapshots can be temporarily stale; re-check schedulability here to
			// avoid selecting accounts that were recently rate-limited/overloaded.
			if !s.isAccountSchedulableForSelection(acc) {
				continue
			}
			// require_privacy_set: skip accounts whose privacy is not set and mark them errored.
			if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
				_ = s.accountRepo.SetError(ctx, acc.ID,
					fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
				continue
			}
			// Filter: native-platform accounts pass directly; antigravity accounts require
			// mixed scheduling to be enabled.
			if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
				continue
			}
			if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
				continue
			}
			if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
				continue
			}
			if !s.isAccountSchedulableForQuota(acc) {
				continue
			}
			if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
				continue
			}
			if !s.isAccountSchedulableForRPM(ctx, acc, false) {
				continue
			}
			if selected == nil {
				selected = acc
				continue
			}
			// Lower Priority value wins; ties broken by least-recently-used, with
			// never-used accounts preferred (OAuth preferred among never-used Gemini
			// accounts when preferOAuth is set).
			if acc.Priority < selected.Priority {
				selected = acc
			} else if acc.Priority == selected.Priority {
				switch {
				case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
					selected = acc
				case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
					// keep selected (never used is preferred)
				case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
					if preferOAuth && acc.Platform == PlatformGemini && selected.Platform == PlatformGemini && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
						selected = acc
					}
				default:
					if acc.LastUsedAt.Before(*selected.LastUsedAt) {
						selected = acc
					}
				}
			}
		}

		if selected != nil {
			if sessionHash != "" && s.cache != nil {
				if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
					logger.LegacyPrintf("service.gateway", "set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
				}
			}
			if s.debugModelRoutingEnabled() {
				logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID)
			}
			return selected, nil
		}
		logger.LegacyPrintf("service.gateway", "[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel)
	}

	// 1. Look up the sticky session.
	if sessionHash != "" && s.cache != nil {
		accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
		if err == nil && accountID > 0 {
			if _, excluded := excludedIDs[accountID]; !excluded {
				account, err := s.getSchedulableAccount(ctx, accountID)
				// Verify group membership and validity: a native-platform account matches
				// directly; an antigravity account requires mixed scheduling enabled.
				if err == nil {
					clearSticky := shouldClearStickySession(account, requestedModel)
					if clearSticky {
						_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
					}
					if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) && !s.isStickyAccountUpstreamRestricted(ctx, groupID, account, requestedModel) {
						if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
							return account, nil
						}
					}
				}
			}
		}
	}

	// 2. Load the schedulable account list.
	if !accountsLoaded {
		var err error
		accounts, _, err = s.listSchedulableAccounts(ctx, groupID, nativePlatform, false)
		if err != nil {
			return nil, fmt.Errorf("query accounts failed: %w", err)
		}
	}

	// Batch-prefetch window-cost + RPM counters to avoid per-account queries (N+1).
	ctx = s.withWindowCostPrefetch(ctx, accounts)
	ctx = s.withRPMPrefetch(ctx, accounts)

	// 3. Pick by priority + least-recently-used (considering model support and mixed scheduling).
	// needsUpstreamCheck is only used in the main selection loop; sticky-session hits skip it.
	needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
	var selected *Account
	for i := range accounts {
		acc := &accounts[i]
		if _, excluded := excludedIDs[acc.ID]; excluded {
			continue
		}
		// Scheduler snapshots can be temporarily stale; re-check schedulability here to
		// avoid selecting accounts that were recently rate-limited/overloaded.
		if !s.isAccountSchedulableForSelection(acc) {
			continue
		}
		// require_privacy_set: skip accounts whose privacy is not set and mark them errored.
		if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
			_ = s.accountRepo.SetError(ctx, acc.ID,
				fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
			continue
		}
		// Filter: native-platform accounts pass directly; antigravity accounts require
		// mixed scheduling to be enabled.
		if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
			continue
		}
		if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
			continue
		}
		if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) {
			continue
		}
		if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
			continue
		}
		if !s.isAccountSchedulableForQuota(acc) {
			continue
		}
		if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
			continue
		}
		if !s.isAccountSchedulableForRPM(ctx, acc, false) {
			continue
		}
		if selected == nil {
			selected = acc
			continue
		}
		if acc.Priority < selected.Priority {
			selected = acc
		} else if acc.Priority == selected.Priority {
			switch {
			case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
				selected = acc
			case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
				// keep selected (never used is preferred)
			case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
				if preferOAuth && acc.Platform == PlatformGemini && selected.Platform == PlatformGemini && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
					selected = acc
				}
			default:
				if acc.LastUsedAt.Before(*selected.LastUsedAt) {
					selected = acc
				}
			}
		}
	}

	if selected == nil {
		stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, nativePlatform, accounts, excludedIDs, true)
		if requestedModel != "" {
			return nil, fmt.Errorf("%w supporting model: %s (%s)", ErrNoAvailableAccounts, requestedModel, summarizeSelectionFailureStats(stats))
		}
		return nil, ErrNoAvailableAccounts
	}

	// 4. Establish the sticky binding.
	if sessionHash != "" && s.cache != nil {
		if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
			logger.LegacyPrintf("service.gateway", "set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
		}
	}

	return selected, nil
}

// selectionFailureStats aggregates per-category counts (and small ID samples) describing
// why no account could be selected; used for diagnostics logging and error messages.
type selectionFailureStats struct {
	Total              int
	Eligible           int
	Excluded           int
	Unschedulable      int
	PlatformFiltered   int
	ModelUnsupported   int
	ModelRateLimited   int
	SamplePlatformIDs  []int64
	SampleMappingIDs   []int64
	SampleRateLimitIDs []string
}

// selectionFailureDiagnosis classifies a single account's rejection reason.
type selectionFailureDiagnosis struct {
	Category string
	Detail   string
}

// logDetailedSelectionFailure collects selection-failure statistics for the candidate
// accounts and emits a single detailed log line; it returns the stats so callers can
// embed a summary in the returned error.
func (s *GatewayService) logDetailedSelectionFailure(
	ctx context.Context,
	groupID *int64,
	sessionHash string,
	requestedModel string,
	platform string,
	accounts []Account,
	excludedIDs map[int64]struct{},
	allowMixedScheduling bool,
) selectionFailureStats {
	stats := s.collectSelectionFailureStats(ctx, accounts, requestedModel, platform, excludedIDs, allowMixedScheduling)
	logger.LegacyPrintf(
		"service.gateway",
		"[SelectAccountDetailed] group_id=%v model=%s platform=%s session=%s total=%d eligible=%d excluded=%d unschedulable=%d platform_filtered=%d model_unsupported=%d model_rate_limited=%d sample_platform_filtered=%v sample_model_unsupported=%v sample_model_rate_limited=%v",
		derefGroupID(groupID),
		requestedModel,
		platform,
		shortSessionHash(sessionHash),
		stats.Total,
		stats.Eligible,
		stats.Excluded,
		stats.Unschedulable,
		stats.PlatformFiltered,
		stats.ModelUnsupported,
		stats.ModelRateLimited,
		stats.SamplePlatformIDs,
		stats.SampleMappingIDs,
		stats.SampleRateLimitIDs,
	)
	return stats
}

// collectSelectionFailureStats diagnoses every candidate account and tallies the
// rejection categories into a selectionFailureStats.
func (s *GatewayService) collectSelectionFailureStats(
	ctx context.Context,
	accounts []Account,
	requestedModel string,
	platform string,
	excludedIDs map[int64]struct{},
	allowMixedScheduling bool,
)
selectionFailureStats {
	stats := selectionFailureStats{
		Total: len(accounts),
	}

	for i := range accounts {
		acc := &accounts[i]
		diagnosis := s.diagnoseSelectionFailure(ctx, acc, requestedModel, platform, excludedIDs, allowMixedScheduling)
		switch diagnosis.Category {
		case "excluded":
			stats.Excluded++
		case "unschedulable":
			stats.Unschedulable++
		case "platform_filtered":
			stats.PlatformFiltered++
			stats.SamplePlatformIDs = appendSelectionFailureSampleID(stats.SamplePlatformIDs, acc.ID)
		case "model_unsupported":
			stats.ModelUnsupported++
			stats.SampleMappingIDs = appendSelectionFailureSampleID(stats.SampleMappingIDs, acc.ID)
		case "model_rate_limited":
			stats.ModelRateLimited++
			remaining := acc.GetRateLimitRemainingTimeWithContext(ctx, requestedModel).Truncate(time.Second)
			stats.SampleRateLimitIDs = appendSelectionFailureRateSample(stats.SampleRateLimitIDs, acc.ID, remaining)
		default:
			stats.Eligible++
		}
	}

	return stats
}

// diagnoseSelectionFailure classifies why a single account was (or would be) rejected
// during selection. Checks run in the same order as the selection loop, so the first
// failing check determines the category; "eligible" means no check failed.
func (s *GatewayService) diagnoseSelectionFailure(
	ctx context.Context,
	acc *Account,
	requestedModel string,
	platform string,
	excludedIDs map[int64]struct{},
	allowMixedScheduling bool,
) selectionFailureDiagnosis {
	if acc == nil {
		return selectionFailureDiagnosis{Category: "unschedulable", Detail: "account_nil"}
	}
	if _, excluded := excludedIDs[acc.ID]; excluded {
		return selectionFailureDiagnosis{Category: "excluded"}
	}
	if !s.isAccountSchedulableForSelection(acc) {
		return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"}
	}
	if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) {
		return selectionFailureDiagnosis{
			Category: "platform_filtered",
			Detail:   fmt.Sprintf("account_platform=%s requested_platform=%s", acc.Platform, strings.TrimSpace(platform)),
		}
	}
	if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
		return selectionFailureDiagnosis{
			Category: "model_unsupported",
			Detail:   fmt.Sprintf("model=%s", requestedModel),
		}
	}
	if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
		remaining := acc.GetRateLimitRemainingTimeWithContext(ctx, requestedModel).Truncate(time.Second)
		return selectionFailureDiagnosis{
			Category: "model_rate_limited",
			Detail:   fmt.Sprintf("remaining=%s", remaining),
		}
	}
	return selectionFailureDiagnosis{Category: "eligible"}
}

// isPlatformFilteredForSelection reports whether the account would be filtered out on
// platform grounds. With mixed scheduling allowed, antigravity accounts pass only when
// mixed scheduling is enabled; otherwise the platform must match exactly (an empty
// requested platform matches everything in the non-mixed case).
func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool {
	if acc == nil {
		return true
	}
	if allowMixedScheduling {
		if acc.Platform == PlatformAntigravity {
			return !acc.IsMixedSchedulingEnabled()
		}
		return acc.Platform != platform
	}
	if strings.TrimSpace(platform) == "" {
		return false
	}
	return acc.Platform != platform
}

// appendSelectionFailureSampleID appends an account ID to a diagnostics sample list,
// capped at 5 entries to keep log lines bounded.
func appendSelectionFailureSampleID(samples []int64, id int64) []int64 {
	const limit = 5
	if len(samples) >= limit {
		return samples
	}
	return append(samples, id)
}

// appendSelectionFailureRateSample appends an "id(remaining)" rate-limit sample,
// capped at 5 entries to keep log lines bounded.
func appendSelectionFailureRateSample(samples []string, accountID int64, remaining time.Duration) []string {
	const limit = 5
	if len(samples) >= limit {
		return samples
	}
	return append(samples, fmt.Sprintf("%d(%s)", accountID, remaining))
}

// summarizeSelectionFailureStats renders the stats counters as a compact single-line
// summary for embedding in error messages.
func summarizeSelectionFailureStats(stats selectionFailureStats) string {
	return fmt.Sprintf(
		"total=%d eligible=%d excluded=%d unschedulable=%d platform_filtered=%d model_unsupported=%d model_rate_limited=%d",
		stats.Total,
		stats.Eligible,
		stats.Excluded,
		stats.Unschedulable,
		stats.PlatformFiltered,
		stats.ModelUnsupported,
		stats.ModelRateLimited,
	)
}

// isModelSupportedByAccountWithContext checks model support based on the account's
// platform (context-aware variant).
// For the Antigravity platform it first resolves the mapped final model name
// (including any thinking suffix) and then checks support.
func (s *GatewayService) isModelSupportedByAccountWithContext(ctx context.Context, account *Account, requestedModel string) bool {
	if account.Platform == PlatformAntigravity {
		if strings.TrimSpace(requestedModel) == "" {
			return
true - } - // 使用与转发阶段一致的映射逻辑:自定义映射优先 → 默认映射兜底 - mapped := mapAntigravityModel(account, requestedModel) - if mapped == "" { - return false - } - // 应用 thinking 后缀后检查最终模型是否在账号映射中 - if enabled, ok := ThinkingEnabledFromContext(ctx); ok { - finalModel := applyThinkingModelSuffix(mapped, enabled) - if finalModel == mapped { - return true // thinking 后缀未改变模型名,映射已通过 - } - return account.IsModelSupported(finalModel) - } - return true - } - return s.isModelSupportedByAccount(account, requestedModel) -} - -// isModelSupportedByAccount 根据账户平台检查模型支持(无 context,用于非 Antigravity 平台) -func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedModel string) bool { - if account.Platform == PlatformAntigravity { - if strings.TrimSpace(requestedModel) == "" { - return true - } - return mapAntigravityModel(account, requestedModel) != "" - } - if account.IsBedrock() { - _, ok := ResolveBedrockModelID(account, requestedModel) - return ok - } - // OpenAI 透传模式:仅替换认证,允许所有模型 - if account.Platform == PlatformOpenAI && account.IsOpenAIPassthroughEnabled() { - return true - } - // OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID) - if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { - if account.Type == AccountTypeServiceAccount { - requestedModel = normalizeVertexAnthropicModelID(claude.NormalizeModelID(requestedModel)) - } else { - requestedModel = claude.NormalizeModelID(requestedModel) - } - } - // 其他平台使用账户的模型支持检查 - return account.IsModelSupported(requestedModel) -} - -// GetAccessToken 获取账号凭证 -func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { - switch account.Type { - case AccountTypeOAuth, AccountTypeSetupToken: - // Both oauth and setup-token use OAuth token flow - return s.getOAuthToken(ctx, account) - case AccountTypeAPIKey: - apiKey := account.GetCredential("api_key") - if apiKey == "" { - return "", "", errors.New("api_key not found in credentials") - } - return apiKey, "apikey", nil - 
	case AccountTypeBedrock:
		return "", "bedrock", nil // Bedrock uses SigV4 signing or an API key, handled by forwardBedrock
	case AccountTypeServiceAccount:
		if account.Platform != PlatformAnthropic {
			return "", "", fmt.Errorf("unsupported service account platform: %s", account.Platform)
		}
		if s.claudeTokenProvider == nil {
			return "", "", errors.New("claude token provider not configured")
		}
		accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account)
		if err != nil {
			return "", "", err
		}
		return accessToken, "service_account", nil
	default:
		return "", "", fmt.Errorf("unsupported account type: %s", account.Type)
	}
}

// getOAuthToken returns the OAuth access token for an account, preferring the cached
// token from the ClaudeTokenProvider for Anthropic OAuth accounts.
func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (string, string, error) {
	// For Anthropic OAuth accounts, use the ClaudeTokenProvider to obtain the cached token.
	if account.Platform == PlatformAnthropic && account.Type == AccountTypeOAuth && s.claudeTokenProvider != nil {
		accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account)
		if err != nil {
			return "", "", err
		}
		return accessToken, "oauth", nil
	}

	// Otherwise (Gemini has its own TokenProvider, setup-token type, etc.) read the
	// token straight from the account credentials.
	accessToken := account.GetCredential("access_token")
	if accessToken == "" {
		return "", "", errors.New("access_token not found in credentials")
	}
	// Token refresh is handled by the background TokenRefreshService; only the current
	// token is returned here.
	return accessToken, "oauth", nil
}

// Retry-related constants.
const (
	// Maximum attempts (including the first request). Too many retries cause request
	// pile-up and resource exhaustion.
	maxRetryAttempts = 5

	// Exponential backoff: wait after the N-th failure = retryBaseDelay * 2^(N-1),
	// capped at retryMaxDelay.
	retryBaseDelay = 300 * time.Millisecond
	retryMaxDelay  = 3 * time.Second

	// Maximum total retry time (request duration + backoff waits).
	// Guards against long-lived goroutine pile-up exhausting resources in extreme cases.
	maxRetryElapsed = 10 * time.Second
)

// shouldRetryUpstreamError reports whether an upstream error status is retryable for
// the given account.
func (s *GatewayService) shouldRetryUpstreamError(account *Account, statusCode int) bool {
	// OAuth/Setup Token accounts: retry only on 403.
	if account.IsOAuth() {
		return statusCode == 403
	}

	// API-key accounts: retry on error codes the account is not configured to handle.
	return !account.ShouldHandleErrorCode(statusCode)
}

// shouldFailoverUpstreamError determines whether an upstream error should trigger account failover.
func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool {
	switch statusCode {
	case 401, 403, 429, 529:
		return true
	default:
		return statusCode >= 500
	}
}

// retryBackoffDelay returns the exponential-backoff wait after the attempt-th failure,
// capped at retryMaxDelay.
func retryBackoffDelay(attempt int) time.Duration {
	// attempt starts at 1: the attempt-th request just failed, so wait before issuing
	// request attempt+1.
	if attempt <= 0 {
		return retryBaseDelay
	}
	delay := retryBaseDelay * time.Duration(1<<(attempt-1))
	if delay > retryMaxDelay {
		return retryMaxDelay
	}
	return delay
}

// sleepWithContext sleeps for d or until ctx is cancelled, returning ctx.Err() on
// cancellation. The deferred Stop+drain prevents a leaked timer value.
func sleepWithContext(ctx context.Context, d time.Duration) error {
	if d <= 0 {
		return nil
	}
	timer := time.NewTimer(d)
	defer func() {
		if !timer.Stop() {
			select {
			case <-timer.C:
			default:
			}
		}
	}()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-timer.C:
		return nil
	}
}

// isClaudeCodeClient determines whether a request comes from a genuine Claude Code client.
// Criteria:
// 1. User-Agent matches claude-cli/X.Y.Z (case-insensitive)
// 2.
metadata.user_id 符合 Claude Code 格式(legacy 或 JSON 格式) -// -// 只检查 metadata.user_id 非空不够严格:第三方工具(opencode 等)可能伪造 UA -// 并附带任意 metadata.user_id 字符串,从而绕过 mimicry。必须通过 ParseMetadataUserID -// 验证格式才能确认是真正的 Claude Code 客户端。 -func isClaudeCodeClient(userAgent string, metadataUserID string) bool { - if !claudeCliUserAgentRe.MatchString(userAgent) { - return false - } - return ParseMetadataUserID(metadataUserID) != nil -} - -// normalizeSystemParam 将 json.RawMessage 类型的 system 参数转为标准 Go 类型(string / []any / nil), -// 避免 type switch 中 json.RawMessage(底层 []byte)无法匹配 case string / case []any / case nil 的问题。 -// 这是 Go 的 typed nil 陷阱:(json.RawMessage, nil) ≠ (nil, nil)。 -func normalizeSystemParam(system any) any { - raw, ok := system.(json.RawMessage) - if !ok { - return system - } - if len(raw) == 0 { - return nil - } - var parsed any - if err := json.Unmarshal(raw, &parsed); err != nil { - return nil - } - return parsed -} - -// systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词 -// 使用前缀匹配支持多种变体(标准版、Agent SDK 版等) -func systemIncludesClaudeCodePrompt(system any) bool { - system = normalizeSystemParam(system) - switch v := system.(type) { - case string: - return hasClaudeCodePrefix(v) - case []any: - for _, item := range v { - if m, ok := item.(map[string]any); ok { - if text, ok := m["text"].(string); ok && hasClaudeCodePrefix(text) { - return true - } - } - } - } - return false -} - -// hasClaudeCodePrefix 检查文本是否以 Claude Code 提示词的特征前缀开头 -func hasClaudeCodePrefix(text string) bool { - for _, prefix := range claudeCodePromptPrefixes { - if strings.HasPrefix(text, prefix) { - return true - } - } - return false -} - -// injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词 -// 处理 null、字符串、数组三种格式 -func injectClaudeCodePrompt(body []byte, system any) []byte { - system = normalizeSystemParam(system) - claudeCodeBlock, err := marshalAnthropicSystemTextBlock(claudeCodeSystemPrompt, true) - if err != nil { - logger.LegacyPrintf("service.gateway", "Warning: failed to build 
 Claude Code prompt block: %v", err)
		return body
	}
	// Opencode plugin applies an extra safeguard: it not only prepends the Claude Code
	// banner, it also prefixes the next system instruction with the same banner plus
	// a blank line. This helps when upstream concatenates system instructions.
	claudeCodePrefix := strings.TrimSpace(claudeCodeSystemPrompt)

	var items [][]byte

	switch v := system.(type) {
	case nil:
		items = [][]byte{claudeCodeBlock}
	case string:
		// Be tolerant of older/newer clients that may differ only by trailing whitespace/newlines.
		if strings.TrimSpace(v) == "" || strings.TrimSpace(v) == strings.TrimSpace(claudeCodeSystemPrompt) {
			items = [][]byte{claudeCodeBlock}
		} else {
			// Mirror opencode behavior: keep the banner as a separate system entry,
			// but also prefix the next system text with the banner.
			merged := v
			if !strings.HasPrefix(v, claudeCodePrefix) {
				merged = claudeCodePrefix + "\n\n" + v
			}
			nextBlock, buildErr := marshalAnthropicSystemTextBlock(merged, false)
			if buildErr != nil {
				logger.LegacyPrintf("service.gateway", "Warning: failed to build prefixed Claude Code system block: %v", buildErr)
				return body
			}
			items = [][]byte{claudeCodeBlock, nextBlock}
		}
	case []any:
		items = make([][]byte, 0, len(v)+1)
		items = append(items, claudeCodeBlock)
		prefixedNext := false
		// Prefer iterating the raw JSON (gjson) to preserve each block byte-for-byte;
		// fall back to the decoded []any form when "system" is not a raw array.
		systemResult := gjson.GetBytes(body, "system")
		if systemResult.IsArray() {
			systemResult.ForEach(func(_, item gjson.Result) bool {
				textResult := item.Get("text")
				// Drop blocks that duplicate the Claude Code prompt verbatim.
				if textResult.Exists() && textResult.Type == gjson.String &&
					strings.TrimSpace(textResult.String()) == strings.TrimSpace(claudeCodeSystemPrompt) {
					return true
				}

				raw := []byte(item.Raw)
				// Prefix the first subsequent text system block once.
				if !prefixedNext && item.Get("type").String() == "text" && textResult.Exists() && textResult.Type == gjson.String {
					text := textResult.String()
					if strings.TrimSpace(text) != "" && !strings.HasPrefix(text, claudeCodePrefix) {
						next, setErr := sjson.SetBytes(raw, "text", claudeCodePrefix+"\n\n"+text)
						if setErr == nil {
							raw = next
							prefixedNext = true
						}
					}
				}
				items = append(items, raw)
				return true
			})
		} else {
			for _, item := range v {
				m, ok := item.(map[string]any)
				if !ok {
					raw, marshalErr := json.Marshal(item)
					if marshalErr == nil {
						items = append(items, raw)
					}
					continue
				}
				if text, ok := m["text"].(string); ok && strings.TrimSpace(text) == strings.TrimSpace(claudeCodeSystemPrompt) {
					continue
				}
				if !prefixedNext {
					if blockType, _ := m["type"].(string); blockType == "text" {
						if text, ok := m["text"].(string); ok && strings.TrimSpace(text) != "" && !strings.HasPrefix(text, claudeCodePrefix) {
							m["text"] = claudeCodePrefix + "\n\n" + text
							prefixedNext = true
						}
					}
				}
				raw, marshalErr := json.Marshal(m)
				if marshalErr == nil {
					items = append(items, raw)
				}
			}
		}
	default:
		items = [][]byte{claudeCodeBlock}
	}

	result, ok := setJSONRawBytes(body, "system", buildJSONArrayRaw(items))
	if !ok {
		logger.LegacyPrintf("service.gateway", "Warning: failed to inject Claude Code prompt")
		return body
	}
	return result
}

// rewriteSystemForNonClaudeCode migrates the system prompt of non-Claude-Code clients
// into messages; the system field keeps only the Claude Code identification prompt.
// Anthropic detects third-party apps from the content of the system parameter; merely
// prepending the Claude Code prompt does not pass detection because the remaining
// content is still in non-Claude-Code format.
// Strategy: extract the original system prompt and inject it as a user/assistant
// message pair; system keeps only the Claude Code identity.
func rewriteSystemForNonClaudeCode(body []byte, system any) []byte {
	system = normalizeSystemParam(system)

	// 1.
提取原始 system prompt 文本 - var originalSystemText string - switch v := system.(type) { - case string: - originalSystemText = strings.TrimSpace(v) - case []any: - var parts []string - for _, item := range v { - if m, ok := item.(map[string]any); ok { - if text, ok := m["text"].(string); ok && strings.TrimSpace(text) != "" { - parts = append(parts, text) - } - } - } - originalSystemText = strings.Join(parts, "\n\n") - } - - // 2. 构造 system 数组,对齐真实 Claude Code CLI 的 2-block 形态: - // [0] billing attribution block(cc_version={cliVer}.{fp}; cc_entrypoint=cli; cch=00000;) - // [1] "You are Claude Code..." prompt block(带 cache_control 作为稳定缓存断点) - // - // billing block 的 cch=00000 是占位符,会被 buildUpstreamRequest 里的 - // signBillingHeaderCCH 替换成 xxhash64 签名。缺失 billing block 的系统 payload - // 是 Anthropic 判定第三方的关键信号之一(真实 CLI 每个请求都带)。 - billingBlock, billingErr := buildBillingAttributionBlockJSON(body, claude.CLICurrentVersion) - ccPromptBlock, ccErr := marshalAnthropicSystemTextBlock(claudeCodeSystemPrompt, true) - if billingErr != nil || ccErr != nil { - logger.LegacyPrintf("service.gateway", "Warning: failed to build system blocks (billing=%v, cc=%v)", billingErr, ccErr) - return body - } - out, ok := setJSONRawBytes(body, "system", buildJSONArrayRaw([][]byte{billingBlock, ccPromptBlock})) - if !ok { - logger.LegacyPrintf("service.gateway", "Warning: failed to set Claude Code system prompt") - return body - } - - // 3. 
将原始 system prompt 作为 user/assistant 消息对注入到 messages 开头 - // 模型仍通过 messages 接收完整指令,保留客户端功能 - ccPromptTrimmed := strings.TrimSpace(claudeCodeSystemPrompt) - if originalSystemText != "" && originalSystemText != ccPromptTrimmed && !hasClaudeCodePrefix(originalSystemText) { - instrMsg, err1 := json.Marshal(map[string]any{ - "role": "user", - "content": []map[string]any{ - {"type": "text", "text": "[System Instructions]\n" + originalSystemText}, - }, - }) - ackMsg, err2 := json.Marshal(map[string]any{ - "role": "assistant", - "content": []map[string]any{ - {"type": "text", "text": "Understood. I will follow these instructions."}, - }, - }) - if err1 != nil || err2 != nil { - logger.LegacyPrintf("service.gateway", "Warning: failed to marshal system-to-messages injection") - return out - } - - // 重建 messages 数组:[instruction, ack, ...originalMessages] - items := [][]byte{instrMsg, ackMsg} - messagesResult := gjson.GetBytes(out, "messages") - if messagesResult.IsArray() { - messagesResult.ForEach(func(_, msg gjson.Result) bool { - items = append(items, []byte(msg.Raw)) - return true - }) - } - - if next, setOk := setJSONRawBytes(out, "messages", buildJSONArrayRaw(items)); setOk { - out = next - } - } - - return out -} - -type cacheControlPath struct { - path string - log string -} - -func collectCacheControlPaths(body []byte) (invalidThinking []cacheControlPath, messagePaths []string, toolPaths []string, systemPaths []string) { - system := gjson.GetBytes(body, "system") - if system.IsArray() { - sysIndex := 0 - system.ForEach(func(_, item gjson.Result) bool { - if item.Get("cache_control").Exists() { - path := fmt.Sprintf("system.%d.cache_control", sysIndex) - if item.Get("type").String() == "thinking" { - invalidThinking = append(invalidThinking, cacheControlPath{ - path: path, - log: "[Warning] Removed illegal cache_control from thinking block in system", - }) - } else { - systemPaths = append(systemPaths, path) - } - } - sysIndex++ - return true - }) - } - - messages := 
	gjson.GetBytes(body, "messages")
	if messages.IsArray() {
		msgIndex := 0
		messages.ForEach(func(_, msg gjson.Result) bool {
			content := msg.Get("content")
			if content.IsArray() {
				contentIndex := 0
				content.ForEach(func(_, item gjson.Result) bool {
					if item.Get("cache_control").Exists() {
						path := fmt.Sprintf("messages.%d.content.%d.cache_control", msgIndex, contentIndex)
						if item.Get("type").String() == "thinking" {
							invalidThinking = append(invalidThinking, cacheControlPath{
								path: path,
								log:  fmt.Sprintf("[Warning] Removed illegal cache_control from thinking block in messages[%d].content[%d]", msgIndex, contentIndex),
							})
						} else {
							messagePaths = append(messagePaths, path)
						}
					}
					contentIndex++
					return true
				})
			}
			msgIndex++
			return true
		})
	}

	tools := gjson.GetBytes(body, "tools")
	if tools.IsArray() {
		toolIndex := 0
		tools.ForEach(func(_, tool gjson.Result) bool {
			if tool.Get("cache_control").Exists() {
				toolPaths = append(toolPaths, fmt.Sprintf("tools.%d.cache_control", toolIndex))
			}
			toolIndex++
			return true
		})
	}

	return invalidThinking, messagePaths, toolPaths, systemPaths
}

// enforceCacheControlLimit enforces the cache_control block count limit
// (at most maxCacheControlBlocks, i.e. 4). When over the limit it removes
// tool breakpoints first, then message breakpoints, and only then system
// breakpoints. Returns the original body untouched when nothing changed.
func enforceCacheControlLimit(body []byte) []byte {
	if len(body) == 0 {
		return body
	}

	invalidThinking, messagePaths, toolPaths, systemPaths := collectCacheControlPaths(body)
	out := body
	modified := false

	// First strip illegal cache_control from thinking blocks (thinking blocks
	// do not support that field at all).
	for _, item := range invalidThinking {
		if !gjson.GetBytes(out, item.path).Exists() {
			continue
		}
		next, ok := deleteJSONPathBytes(out, item.path)
		if !ok {
			continue
		}
		out = next
		modified = true
		logger.LegacyPrintf("service.gateway", "%s", item.log)
	}

	count := len(messagePaths) + len(toolPaths) + len(systemPaths)
	if count <= maxCacheControlBlocks {
		if modified {
			return out
		}
		return body
	}

	// Over the limit: remove from tools first, then messages, and only then
	// from system. Tool/system paths are walked in reverse (last breakpoint
	// dropped first); message paths are walked front to back.
	remaining := count - maxCacheControlBlocks
	for i := len(toolPaths) - 1; i >= 0 && remaining > 0; i-- {
		path := toolPaths[i]
		if !gjson.GetBytes(out, path).Exists() {
			continue
		}
		next, ok := deleteJSONPathBytes(out, path)
		if !ok {
			continue
		}
		out = next
		modified = true
		remaining--
	}

	for _, path := range messagePaths {
		if remaining <= 0 {
			break
		}
		if !gjson.GetBytes(out, path).Exists() {
			continue
		}
		next, ok := deleteJSONPathBytes(out, path)
		if !ok {
			continue
		}
		out = next
		modified = true
		remaining--
	}

	for i := len(systemPaths) - 1; i >= 0 && remaining > 0; i-- {
		path := systemPaths[i]
		if !gjson.GetBytes(out, path).Exists() {
			continue
		}
		next, ok := deleteJSONPathBytes(out, path)
		if !ok {
			continue
		}
		out = next
		modified = true
		remaining--
	}

	if modified {
		return out
	}
	return body
}

// injectAnthropicCacheControlTTL1h forces the ttl of existing ephemeral
// cache_control blocks to 1h. It only rewrites cache_control entries that
// already exist and never adds new cache breakpoints.
func injectAnthropicCacheControlTTL1h(body []byte) []byte {
	return forceEphemeralCacheControlTTL(body, cacheTTLTarget1h)
}

// forceEphemeralCacheControlTTL collects every existing ephemeral
// cache_control entry (top level, system, messages content, tools) whose ttl
// differs from the target, then rewrites them all to ttl via sjson.
func forceEphemeralCacheControlTTL(body []byte, ttl string) []byte {
	if len(body) == 0 || ttl == "" {
		return body
	}
	out := body
	var paths []string
	addPath := func(path string, value gjson.Result) {
		cc := value.Get("cache_control")
		if !cc.Exists() || cc.Get("type").String() != "ephemeral" {
			return
		}
		if cc.Get("ttl").String() == ttl {
			// Already at the target TTL; nothing to rewrite.
			return
		}
		paths = append(paths, path+".cache_control.ttl")
	}

	if topCC := gjson.GetBytes(body, "cache_control"); topCC.Exists() && topCC.Get("type").String() == "ephemeral" && topCC.Get("ttl").String() != ttl {
		paths = append(paths, "cache_control.ttl")
	}

	system := gjson.GetBytes(body, "system")
	if system.IsArray() {
		idx := -1
		system.ForEach(func(_, block gjson.Result) bool {
			idx++
			addPath(fmt.Sprintf("system.%d", idx), block)
			return true
		})
	}

	messages := gjson.GetBytes(body, "messages")
	if messages.IsArray() {
		msgIdx := -1
		messages.ForEach(func(_, msg gjson.Result) bool {
			msgIdx++
			content := msg.Get("content")
			if !content.IsArray() {
				return true
			}
			contentIdx := -1
			content.ForEach(func(_, block gjson.Result) bool {
				contentIdx++
				addPath(fmt.Sprintf("messages.%d.content.%d", msgIdx, contentIdx), block)
				return true
			})
			return true
		})
	}

	tools := gjson.GetBytes(body, "tools")
	if tools.IsArray() {
		idx := -1
		tools.ForEach(func(_, tool gjson.Result) bool {
			idx++
			addPath(fmt.Sprintf("tools.%d", idx), tool)
			return true
		})
	}

	// Apply the collected rewrites; individual sjson failures are skipped.
	for _, path := range paths {
		if next, err := sjson.SetBytes(out, path, ttl); err == nil {
			out = next
		}
	}
	return out
}

// shouldInjectAnthropicCacheTTL1h reports whether the 1h cache TTL rewrite is
// enabled for this account: Anthropic OAuth/setup-token accounts only, gated
// by the runtime setting service.
func (s *GatewayService) shouldInjectAnthropicCacheTTL1h(ctx context.Context, account *Account) bool {
	if account == nil || !account.IsAnthropicOAuthOrSetupToken() || s == nil || s.settingService == nil {
		return false
	}
	return s.settingService.IsAnthropicCacheTTL1hInjectionEnabled(ctx)
}

// Forward forwards the request to the Claude API.
//
// Pipeline: web-search emulation / API-key passthrough / Bedrock short
// circuits, then beta-policy evaluation, optional Claude Code mimicry
// (system rewrite + OAuth body normalization), cache_control limiting, model
// mapping, TTL injection, and finally the upstream retry loop with
// signature-error and budget-constraint rectification plus failover handling.
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
	startTime := time.Now()
	if parsed == nil {
		return nil, fmt.Errorf("parse request: empty request")
	}

	// Web Search emulation: for pure web_search requests, call the search API
	// directly and construct the response locally.
	if account != nil && s.shouldEmulateWebSearch(ctx, account, parsed.GroupID, parsed.Body) {
		return s.handleWebSearchEmulation(ctx, c, account, parsed)
	}

	if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() {
		passthroughBody := parsed.Body
		passthroughModel := parsed.Model
		if passthroughModel != "" {
			if mappedModel := account.GetMappedModel(passthroughModel); mappedModel != passthroughModel {
				passthroughBody = s.replaceModelInBody(passthroughBody, mappedModel)
				logger.LegacyPrintf("service.gateway", "Passthrough model mapping: %s -> %s (account: %s)",
					parsed.Model, mappedModel, account.Name)
				passthroughModel = mappedModel
			}
		}
		return s.forwardAnthropicAPIKeyPassthroughWithInput(ctx, c, account, anthropicPassthroughForwardInput{
			Body:          passthroughBody,
			RequestModel:  passthroughModel,
			OriginalModel: parsed.Model,
			RequestStream: parsed.Stream,
			StartTime:     startTime,
		})
	}

	if account != nil && account.IsBedrock() {
		return s.forwardBedrock(ctx, c, account, parsed, startTime)
	}

	// Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest.
	// Always overwrite the cache to prevent stale values from a previous retry with a different account.
	if account.Platform == PlatformAnthropic && c != nil {
		policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account, parsed.Model)
		if policy.blockErr != nil {
			return nil, policy.blockErr
		}
		filterSet := policy.filterSet
		if filterSet == nil {
			filterSet = map[string]struct{}{}
		}
		c.Set(betaPolicyFilterSetKey, filterSet)
	}

	body := parsed.Body
	reqModel := parsed.Model
	reqStream := parsed.Stream
	originalModel := reqModel

	// === DEBUG: snapshot the client's original request (headers + body digest). ===
	if c != nil {
		s.debugLogGatewaySnapshot("CLIENT_ORIGINAL", c.Request.Header, body, map[string]string{
			"account":      fmt.Sprintf("%d(%s)", account.ID, account.Name),
			"account_type": string(account.Type),
			"model":        reqModel,
			"stream":       strconv.FormatBool(reqStream),
		})
	}

	// Claude Code client detection: UA matches claude-cli/* and metadata.user_id
	// is present. A genuine Claude Code client already sends the full system
	// prompt, cache_control breakpoints and headers, so the proxy needs no
	// body-level mimicry; forcibly replacing them would break the client's
	// caching strategy (a long system prompt replaced by a ~45-token short
	// prompt falls below Anthropic's 1024-token minimum cache threshold,
	// defeating system-level caching).
	//
	// Non-Claude-Code third-party clients (opencode etc.) still get full mimicry.
	isClaudeCode := IsClaudeCodeClient(ctx) || isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID)
	shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode

	if shouldMimicClaudeCode {
		// Aligned with Parrot: OAuth accounts rewrite system unconditionally
		// (even when the client already sent a Claude Code style system
		// prompt). Reason: third-party tools (opencode etc.) send a "You are
		// Claude Code..." system prompt but lack the billing attribution
		// block, so Anthropic sees the "CC prompt without billing block"
		// inconsistency and classifies the request as third-party. Parrot's
		// transform_request never inspects client system content and simply
		// overwrites it.
		systemRewritten := false
		if !strings.Contains(strings.ToLower(reqModel), "haiku") {
			body = rewriteSystemForNonClaudeCode(body, parsed.System)
			systemRewritten = true
		}

		// When system was rewritten, keep the CC prompt's cache_control:
		// ephemeral (matches real Claude Code behavior); when it was not
		// (haiku / already CC-prefixed), strip client cache_control,
		// consistent with the previous behavior. In both cases
		// enforceCacheControlLimit caps the total as a backstop.
		normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: !systemRewritten}
		if s.identityService != nil {
			fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
			if err == nil && fp != nil {
				// Skip metadata injection when metadata passthrough is enabled.
				_, mimicMPT, _ := s.settingService.GetGatewayForwardingSettings(ctx)
				if !mimicMPT {
					if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" {
						normalizeOpts.injectMetadata = true
						normalizeOpts.metadataUserID = metadataUserID
					}
				}
			}
		}

		body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)

		// D/E/F: optional messages cache strategy + tool-name obfuscation +
		// tools[-1] breakpoint. Aligned with the forward_as_chat_completions /
		// forward_as_responses paths: the native /v1/messages path goes
		// through the same configurable field-level rewrites.
		body = s.rewriteMessageCacheControlIfEnabled(ctx, body)
		if rw := buildToolNameRewriteFromBody(body); rw != nil {
			body = applyToolNameRewriteToBody(body, rw)
			c.Set(toolNameRewriteKey, rw)
		} else {
			body = applyToolsLastCacheBreakpoint(body)
		}
	}

	// Enforce the cache_control block count limit (at most 4).
	body = enforceCacheControlLimit(body)

	// Apply model mapping:
	//   - APIKey accounts: use the account-level explicit mapping when
	//     configured, otherwise pass the original model name through.
	//   - OAuth/SetupToken accounts: use the standard Anthropic mapping
	//     (short ID -> long ID).
	mappedModel := reqModel
	mappingSource := ""
	if account.Type == AccountTypeAPIKey {
		mappedModel = account.GetMappedModel(reqModel)
		if mappedModel != reqModel {
			mappingSource = "account"
		}
	}
	if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
		if candidate, matched := account.ResolveMappedModel(reqModel); matched {
			mappedModel = candidate
			mappingSource = "account"
		} else {
			normalized := normalizeVertexAnthropicModelID(claude.NormalizeModelID(reqModel))
			if normalized != reqModel {
				mappedModel = normalized
				mappingSource = "vertex"
			}
		}
	}
	if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
		normalized := claude.NormalizeModelID(reqModel)
		if normalized != reqModel {
			mappedModel = normalized
			mappingSource = "prefix"
		}
	}
	if mappedModel != reqModel {
		// Replace the model name inside the request body.
		body = s.replaceModelInBody(body, mappedModel)
		reqModel = mappedModel
		logger.LegacyPrintf("service.gateway", "Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource)
	}

	if s.shouldInjectAnthropicCacheTTL1h(ctx, account) {
		body = injectAnthropicCacheControlTTL1h(body)
	}

	// Fetch credentials.
	token, tokenType, err := s.GetAccessToken(ctx, account)
	if err != nil {
		return nil, err
	}

	// Resolve the proxy URL (in custom base URL mode the proxy is passed as a
	// query parameter by buildCustomRelayURL instead).
	proxyURL := ""
	if account.ProxyID != nil && account.Proxy != nil {
		if !account.IsCustomBaseURLEnabled() || account.GetCustomBaseURL() == "" {
			proxyURL = account.Proxy.URL()
		}
	}

	// Resolve the TLS fingerprint profile once (constant within one request
	// lifetime; avoids re-resolving inside the retry loop).
	tlsProfile := s.tlsFPProfileService.ResolveTLSProfile(account)

	// Debug log: record the account about to be used for forwarding.
	logger.LegacyPrintf("service.gateway", "[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s",
		account.ID, account.Name, account.Platform, account.Type, tlsProfile, proxyURL)
	// Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400.
	body = StripEmptyTextBlocks(body)

	// Reuse the same request body across retries to avoid an extra
	// string(body) allocation each time.
	setOpsUpstreamRequestBody(c, body)

	// Retry loop.
	var resp *http.Response
	retryStart := time.Now()
	for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
		// Build the upstream request (rebuilt on every retry, since the
		// request body has to be re-read).
		upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream)
		upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
		releaseUpstreamCtx()
		if err != nil {
			return nil, err
		}

		// Send the request.
		resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, tlsProfile)
		if err != nil {
			if resp != nil && resp.Body != nil {
				_ = resp.Body.Close()
			}
			// Ensure the client receives an error response (handlers assume Forward writes on non-failover errors).
			safeErr := sanitizeUpstreamErrorMessage(err.Error())
			setOpsUpstreamError(c, 0, safeErr, "")
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: 0,
				UpstreamURL:        safeUpstreamURL(upstreamReq.URL.String()),
				Kind:               "request_error",
				Message:            safeErr,
			})
			c.JSON(http.StatusBadGateway, gin.H{
				"type": "error",
				"error": gin.H{
					"type":    "upstream_error",
					"message": "Upstream request failed",
				},
			})
			return nil, fmt.Errorf("upstream request failed: %s", safeErr)
		}

		// Detect thinking-block signature errors (400) first and retry once.
		if resp.StatusCode == 400 {
			respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
			if readErr == nil {
				_ = resp.Body.Close()

				if s.shouldRectifySignatureError(ctx, account, respBody) {
					appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
						Platform:           account.Platform,
						AccountID:          account.ID,
						AccountName:        account.Name,
						UpstreamStatusCode: resp.StatusCode,
						UpstreamRequestID:  resp.Header.Get("x-request-id"),
						UpstreamURL:        safeUpstreamURL(upstreamReq.URL.String()),
						Kind:               "signature_error",
						Message:            extractUpstreamErrorMessage(respBody),
						Detail: func() string {
							if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
								return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)
							}
							return ""
						}(),
					})

					looksLikeToolSignatureError := func(msg string) bool {
						m := strings.ToLower(msg)
						return strings.Contains(m, "tool_use") ||
							strings.Contains(m, "tool_result") ||
							strings.Contains(m, "functioncall") ||
							strings.Contains(m, "function_call") ||
							strings.Contains(m, "functionresponse") ||
							strings.Contains(m, "function_response")
					}

					// Avoid issuing extra requests when the retry budget is already exhausted.
					if time.Since(retryStart) >= maxRetryElapsed {
						resp.Body = io.NopCloser(bytes.NewReader(respBody))
						break
					}
					logger.LegacyPrintf("service.gateway", "[warn] Account %d: thinking blocks have invalid signature, retrying with filtered blocks", account.ID)

					// Conservative two-stage fallback:
					// 1) Disable thinking + thinking->text (preserve content)
					// 2) Only if upstream still errors AND error message points to tool/function signature issues:
					//    also downgrade tool_use/tool_result blocks to text.

					filteredBody := FilterThinkingBlocksForRetry(body)
					retryCtx, releaseRetryCtx := detachStreamUpstreamContext(ctx, reqStream)
					retryReq, buildErr := s.buildUpstreamRequest(retryCtx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
					releaseRetryCtx()
					if buildErr == nil {
						retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, tlsProfile)
						if retryErr == nil {
							if retryResp.StatusCode < 400 {
								logger.LegacyPrintf("service.gateway", "Account %d: thinking block retry succeeded (blocks downgraded)", account.ID)
								resp = retryResp
								break
							}

							retryRespBody, retryReadErr := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
							_ = retryResp.Body.Close()
							if retryReadErr == nil && retryResp.StatusCode == 400 && s.isSignatureErrorPattern(ctx, account, retryRespBody) {
								appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
									Platform:           account.Platform,
									AccountID:          account.ID,
									AccountName:        account.Name,
									UpstreamStatusCode: retryResp.StatusCode,
									UpstreamRequestID:  retryResp.Header.Get("x-request-id"),
									UpstreamURL:        safeUpstreamURL(retryReq.URL.String()),
									Kind:               "signature_retry_thinking",
									Message:            extractUpstreamErrorMessage(retryRespBody),
									Detail: func() string {
										if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
											return truncateString(string(retryRespBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)
										}
										return ""
									}(),
								})
								msg2 := extractUpstreamErrorMessage(retryRespBody)
								if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed {
									logger.LegacyPrintf("service.gateway", "Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID)
									filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
									retryCtx2, releaseRetryCtx2 := detachStreamUpstreamContext(ctx, reqStream)
									retryReq2, buildErr2 := s.buildUpstreamRequest(retryCtx2, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
									releaseRetryCtx2()
									if buildErr2 == nil {
										retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, tlsProfile)
										if retryErr2 == nil {
											resp = retryResp2
											break
										}
										if retryResp2 != nil && retryResp2.Body != nil {
											_ = retryResp2.Body.Close()
										}
										appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
											Platform:           account.Platform,
											AccountID:          account.ID,
											AccountName:        account.Name,
											UpstreamStatusCode: 0,
											UpstreamURL:        safeUpstreamURL(retryReq2.URL.String()),
											Kind:               "signature_retry_tools_request_error",
											Message:            sanitizeUpstreamErrorMessage(retryErr2.Error()),
										})
										logger.LegacyPrintf("service.gateway", "Account %d: tool-downgrade signature retry failed: %v", account.ID, retryErr2)
									} else {
										logger.LegacyPrintf("service.gateway", "Account %d: tool-downgrade signature retry build failed: %v", account.ID, buildErr2)
									}
								}
							}

							// Fall back to the original retry response context.
							resp = &http.Response{
								StatusCode: retryResp.StatusCode,
								Header:     retryResp.Header.Clone(),
								Body:       io.NopCloser(bytes.NewReader(retryRespBody)),
							}
							break
						}
						if retryResp != nil && retryResp.Body != nil {
							_ = retryResp.Body.Close()
						}
						logger.LegacyPrintf("service.gateway", "Account %d: signature error retry failed: %v", account.ID, retryErr)
					} else {
						logger.LegacyPrintf("service.gateway", "Account %d: signature error retry build request failed: %v", account.ID, buildErr)
					}

					// Retry failed: restore original response body and continue handling.
					resp.Body = io.NopCloser(bytes.NewReader(respBody))
					break
				}
				// Not a signature error (or the rectifier is off); check budget constraints next.
				errMsg := extractUpstreamErrorMessage(respBody)
				if isThinkingBudgetConstraintError(errMsg) && s.settingService.IsBudgetRectifierEnabled(ctx) {
					appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
						Platform:           account.Platform,
						AccountID:          account.ID,
						AccountName:        account.Name,
						UpstreamStatusCode: resp.StatusCode,
						UpstreamRequestID:  resp.Header.Get("x-request-id"),
						UpstreamURL:        safeUpstreamURL(upstreamReq.URL.String()),
						Kind:               "budget_constraint_error",
						Message:            errMsg,
						Detail: func() string {
							if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
								return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)
							}
							return ""
						}(),
					})

					rectifiedBody, applied := RectifyThinkingBudget(body)
					if applied && time.Since(retryStart) < maxRetryElapsed {
						logger.LegacyPrintf("service.gateway", "Account %d: detected budget_tokens constraint error, retrying with rectified budget (budget_tokens=%d, max_tokens=%d)", account.ID, BudgetRectifyBudgetTokens, BudgetRectifyMaxTokens)
						budgetRetryCtx, releaseBudgetRetryCtx := detachStreamUpstreamContext(ctx, reqStream)
						budgetRetryReq, buildErr := s.buildUpstreamRequest(budgetRetryCtx, c, account, rectifiedBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
						releaseBudgetRetryCtx()
						if buildErr == nil {
							budgetRetryResp, retryErr := s.httpUpstream.DoWithTLS(budgetRetryReq, proxyURL, account.ID, account.Concurrency, tlsProfile)
							if retryErr == nil {
								resp = budgetRetryResp
								break
							}
							if budgetRetryResp != nil && budgetRetryResp.Body != nil {
								_ = budgetRetryResp.Body.Close()
							}
							logger.LegacyPrintf("service.gateway", "Account %d: budget rectifier retry failed: %v", account.ID, retryErr)
						} else {
							logger.LegacyPrintf("service.gateway", "Account %d: budget rectifier retry build failed: %v", account.ID, buildErr)
						}
					}
				}

				resp.Body = io.NopCloser(bytes.NewReader(respBody))
			}
		}

		// Check whether a generic retry is needed (400 excluded — it was
		// already handled specially above).
		if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
			if attempt < maxRetryAttempts {
				elapsed := time.Since(retryStart)
				if elapsed >= maxRetryElapsed {
					break
				}

				delay := retryBackoffDelay(attempt)
				remaining := maxRetryElapsed - elapsed
				if delay > remaining {
					delay = remaining
				}
				if delay <= 0 {
					break
				}

				respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
				_ = resp.Body.Close()
				appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
					Platform:           account.Platform,
					AccountID:          account.ID,
					AccountName:        account.Name,
					UpstreamStatusCode: resp.StatusCode,
					UpstreamRequestID:  resp.Header.Get("x-request-id"),
					UpstreamURL:        safeUpstreamURL(upstreamReq.URL.String()),
					Kind:               "retry",
					Message:            extractUpstreamErrorMessage(respBody),
					Detail: func() string {
						if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
							return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)
						}
						return ""
					}(),
				})
				logger.LegacyPrintf("service.gateway", "Account %d: upstream error %d, retry %d/%d after %v (elapsed=%v/%v)",
					account.ID, resp.StatusCode, attempt, maxRetryAttempts, delay, elapsed, maxRetryElapsed)
				if err := sleepWithContext(ctx, delay); err != nil {
					return nil, err
				}
				continue
			}
			// The final attempt also failed; fall out to retry-exhausted handling.
			break
		}

		// No retry needed (success or non-retryable error); exit the loop.
		// DEBUG: dump response headers (to inspect rate limit information).
		if account.Platform == PlatformGemini && resp.StatusCode < 400 && s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders {
			logger.LegacyPrintf("service.gateway", "[DEBUG] Gemini API Response Headers for account %d:", account.ID)
			for k, v := range resp.Header {
				logger.LegacyPrintf("service.gateway", "[DEBUG] %s: %v", k, v)
			}
		}
		break
	}
	if resp == nil || resp.Body == nil {
		return nil, errors.New("upstream request failed: empty response")
	}
	defer func() { _ = resp.Body.Close() }()

	// Handle retry exhaustion.
	if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
		if s.shouldFailoverUpstreamError(resp.StatusCode) {
			respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
			_ = resp.Body.Close()
			resp.Body = io.NopCloser(bytes.NewReader(respBody))

			// Debug log: the error response after retries were exhausted.
			logger.LegacyPrintf("service.gateway", "[Forward] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
				account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000))

			s.handleRetryExhaustedSideEffects(ctx, resp, account)
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: resp.StatusCode,
				UpstreamRequestID:  resp.Header.Get("x-request-id"),
				Kind:               "retry_exhausted_failover",
				Message:            extractUpstreamErrorMessage(respBody),
				Detail: func() string {
					if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
						return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)
					}
					return ""
				}(),
			})
			return nil, &UpstreamFailoverError{
				StatusCode:             resp.StatusCode,
				ResponseBody:           respBody,
				RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
			}
		}
		return s.handleRetryExhaustedError(ctx, resp, c, account)
	}

	// Handle errors that allow switching to another account.
	if resp.StatusCode >= 400 && s.shouldFailoverUpstreamError(resp.StatusCode) {
		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
		_ = resp.Body.Close()
		resp.Body = io.NopCloser(bytes.NewReader(respBody))

		// Debug log: the upstream error response.
		logger.LegacyPrintf("service.gateway", "[Forward] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
			account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000))

		s.handleFailoverSideEffects(ctx, resp, account)
		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
			Platform:           account.Platform,
			AccountID:          account.ID,
			UpstreamStatusCode: resp.StatusCode,
			UpstreamRequestID:  resp.Header.Get("x-request-id"),
			Kind:               "failover",
			Message:            extractUpstreamErrorMessage(respBody),
			Detail: func() string {
				if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
					return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)
				}
				return ""
			}(),
		})
		return nil, &UpstreamFailoverError{
			StatusCode:             resp.StatusCode,
			ResponseBody:           respBody,
			RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
		}
	}
	if resp.StatusCode >= 400 {
		// Optional: trigger failover on some 400s (off by default to preserve semantics).
		if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
			respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
			if readErr != nil {
				// ReadAll failed, fall back to normal error handling without consuming the stream
				return s.handleErrorResponse(ctx, resp, c, account)
			}
			_ = resp.Body.Close()
			resp.Body = io.NopCloser(bytes.NewReader(respBody))

			if s.shouldFailoverOn400(respBody) {
				upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
				upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
				upstreamDetail := ""
				if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
					maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
					if maxBytes <= 0 {
						maxBytes = 2048
					}
					upstreamDetail = truncateString(string(respBody), maxBytes)
				}
				appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
					Platform:           account.Platform,
					AccountID:          account.ID,
					AccountName:        account.Name,
					UpstreamStatusCode: resp.StatusCode,
					UpstreamRequestID:  resp.Header.Get("x-request-id"),
					Kind:               "failover_on_400",
					Message:            upstreamMsg,
					Detail:             upstreamDetail,
				})

				if s.cfg.Gateway.LogUpstreamErrorBody {
					logger.LegacyPrintf("service.gateway",
						"Account %d: 400 error, attempting failover: %s",
						account.ID,
						truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
					)
				} else {
					logger.LegacyPrintf("service.gateway", "Account %d: 400 error, attempting failover", account.ID)
				}
				s.handleFailoverSideEffects(ctx, resp, account)
				return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
			}
		}
		return s.handleErrorResponse(ctx, resp, c, account)
	}

	// Handle the successful response.

	// Fire the upstream-accepted callback (release the serialization lock
	// early, without waiting for the stream to finish).
	if parsed.OnUpstreamAccepted != nil {
		parsed.OnUpstreamAccepted()
	}

	var usage *ClaudeUsage
	var firstTokenMs *int
	var clientDisconnect bool
	if reqStream {
		streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, shouldMimicClaudeCode)
		if err != nil {
			if err.Error() == "have error in stream" {
				return nil, &UpstreamFailoverError{
					StatusCode: 403,
				}
			}
			return nil, err
		}
		usage = streamResult.usage
		firstTokenMs = streamResult.firstTokenMs
		clientDisconnect = streamResult.clientDisconnect
	} else {
		usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
		if err != nil {
			return nil, err
		}
	}

	return &ForwardResult{
		RequestID:        resp.Header.Get("x-request-id"),
		Usage:            *usage,
		Model:            originalModel, // the original model is used for billing and logging
		UpstreamModel:    mappedModel,
		Stream:           reqStream,
		Duration:         time.Since(startTime),
		FirstTokenMs:     firstTokenMs,
		ClientDisconnect: clientDisconnect,
	}, nil
}

// anthropicPassthroughForwardInput bundles the arguments of the Anthropic
// API-key passthrough forwarding path.
type anthropicPassthroughForwardInput struct {
	Body          []byte
	RequestModel  string
	OriginalModel string
	RequestStream bool
	StartTime     time.Time
}

// forwardAnthropicAPIKeyPassthrough is a positional-argument wrapper around
// forwardAnthropicAPIKeyPassthroughWithInput.
func (s *GatewayService) forwardAnthropicAPIKeyPassthrough(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	body []byte,
	reqModel string,
	originalModel string,
	reqStream bool,
	startTime time.Time,
) (*ForwardResult, error) {
	return s.forwardAnthropicAPIKeyPassthroughWithInput(ctx, c, account, anthropicPassthroughForwardInput{
		Body: body,
RequestModel: reqModel, - OriginalModel: originalModel, - RequestStream: reqStream, - StartTime: startTime, - }) -} - -func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput( - ctx context.Context, - c *gin.Context, - account *Account, - input anthropicPassthroughForwardInput, -) (*ForwardResult, error) { - token, tokenType, err := s.GetAccessToken(ctx, account) - if err != nil { - return nil, err - } - if tokenType != "apikey" { - return nil, fmt.Errorf("anthropic api key passthrough requires apikey token, got: %s", tokenType) - } - - proxyURL := "" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - logger.LegacyPrintf("service.gateway", "[Anthropic 自动透传] 命中 API Key 透传分支: account=%d name=%s model=%s stream=%v", - account.ID, account.Name, input.RequestModel, input.RequestStream) - - if c != nil { - c.Set("anthropic_passthrough", true) - } - // Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400. 
- input.Body = StripEmptyTextBlocks(input.Body) - - // 重试间复用同一请求体,避免每次 string(body) 产生额外分配。 - setOpsUpstreamRequestBody(c, input.Body) - - var resp *http.Response - retryStart := time.Now() - for attempt := 1; attempt <= maxRetryAttempts; attempt++ { - upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, input.RequestStream) - upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(upstreamCtx, c, account, input.Body, token) - releaseUpstreamCtx() - if err != nil { - return nil, err - } - - resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account)) - if err != nil { - if resp != nil && resp.Body != nil { - _ = resp.Body.Close() - } - safeErr := sanitizeUpstreamErrorMessage(err.Error()) - setOpsUpstreamError(c, 0, safeErr, "") - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: 0, - UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), - Passthrough: true, - Kind: "request_error", - Message: safeErr, - }) - c.JSON(http.StatusBadGateway, gin.H{ - "type": "error", - "error": gin.H{ - "type": "upstream_error", - "message": "Upstream request failed", - }, - }) - return nil, fmt.Errorf("upstream request failed: %s", safeErr) - } - - // 透传分支禁止 400 请求体降级重试(该重试会改写请求体) - if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { - if attempt < maxRetryAttempts { - elapsed := time.Since(retryStart) - if elapsed >= maxRetryElapsed { - break - } - - delay := retryBackoffDelay(attempt) - remaining := maxRetryElapsed - elapsed - if delay > remaining { - delay = remaining - } - if delay <= 0 { - break - } - - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: 
account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: resp.Header.Get("x-request-id"), - UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), - Passthrough: true, - Kind: "retry", - Message: extractUpstreamErrorMessage(respBody), - Detail: func() string { - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) - } - return "" - }(), - }) - logger.LegacyPrintf("service.gateway", "Anthropic passthrough account %d: upstream error %d, retry %d/%d after %v (elapsed=%v/%v)", - account.ID, resp.StatusCode, attempt, maxRetryAttempts, delay, elapsed, maxRetryElapsed) - if err := sleepWithContext(ctx, delay); err != nil { - return nil, err - } - continue - } - break - } - - break - } - if resp == nil || resp.Body == nil { - return nil, errors.New("upstream request failed: empty response") - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { - if s.shouldFailoverUpstreamError(resp.StatusCode) { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - resp.Body = io.NopCloser(bytes.NewReader(respBody)) - - logger.LegacyPrintf("service.gateway", "[Anthropic Passthrough] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", - account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) - - s.handleRetryExhaustedSideEffects(ctx, resp, account) - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: resp.Header.Get("x-request-id"), - Passthrough: true, - Kind: "retry_exhausted_failover", - Message: extractUpstreamErrorMessage(respBody), - Detail: func() string { - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - return 
truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) - } - return "" - }(), - }) - return nil, &UpstreamFailoverError{ - StatusCode: resp.StatusCode, - ResponseBody: respBody, - RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), - } - } - return s.handleRetryExhaustedError(ctx, resp, c, account) - } - - if resp.StatusCode >= 400 && s.shouldFailoverUpstreamError(resp.StatusCode) { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - resp.Body = io.NopCloser(bytes.NewReader(respBody)) - - logger.LegacyPrintf("service.gateway", "[Anthropic Passthrough] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", - account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) - - s.handleFailoverSideEffects(ctx, resp, account) - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: resp.Header.Get("x-request-id"), - Passthrough: true, - Kind: "failover", - Message: extractUpstreamErrorMessage(respBody), - Detail: func() string { - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) - } - return "" - }(), - }) - return nil, &UpstreamFailoverError{ - StatusCode: resp.StatusCode, - ResponseBody: respBody, - RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), - } - } - - if resp.StatusCode >= 400 { - return s.handleErrorResponse(ctx, resp, c, account) - } - - var usage *ClaudeUsage - var firstTokenMs *int - var clientDisconnect bool - if input.RequestStream { - streamResult, err := s.handleStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account, input.StartTime, input.RequestModel) - if err != nil { - return nil, err - } - usage = 
streamResult.usage - firstTokenMs = streamResult.firstTokenMs - clientDisconnect = streamResult.clientDisconnect - } else { - usage, err = s.handleNonStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account) - if err != nil { - return nil, err - } - } - if usage == nil { - usage = &ClaudeUsage{} - } - - return &ForwardResult{ - RequestID: resp.Header.Get("x-request-id"), - Usage: *usage, - Model: input.OriginalModel, - UpstreamModel: input.RequestModel, - Stream: input.RequestStream, - Duration: time.Since(input.StartTime), - FirstTokenMs: firstTokenMs, - ClientDisconnect: clientDisconnect, - }, nil -} - -func (s *GatewayService) buildUpstreamRequestAnthropicAPIKeyPassthrough( - ctx context.Context, - c *gin.Context, - account *Account, - body []byte, - token string, -) (*http.Request, error) { - targetURL := claudeAPIURL - baseURL := account.GetBaseURL() - if baseURL != "" { - validatedURL, err := s.validateUpstreamBaseURL(baseURL) - if err != nil { - return nil, err - } - targetURL = validatedURL + "/v1/messages?beta=true" - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) - if err != nil { - return nil, err - } - - if c != nil && c.Request != nil { - for key, values := range c.Request.Header { - lowerKey := strings.ToLower(strings.TrimSpace(key)) - if !allowedHeaders[lowerKey] { - continue - } - wireKey := resolveWireCasing(key) - for _, v := range values { - addHeaderRaw(req.Header, wireKey, v) - } - } - } - - // 覆盖入站鉴权残留,并注入上游认证 - req.Header.Del("authorization") - req.Header.Del("x-api-key") - req.Header.Del("x-goog-api-key") - req.Header.Del("cookie") - setHeaderRaw(req.Header, "x-api-key", token) - - if getHeaderRaw(req.Header, "content-type") == "" { - setHeaderRaw(req.Header, "content-type", "application/json") - } - if getHeaderRaw(req.Header, "anthropic-version") == "" { - setHeaderRaw(req.Header, "anthropic-version", "2023-06-01") - } - - return req, nil -} - -func (s *GatewayService) 
handleStreamingResponseAnthropicAPIKeyPassthrough( - ctx context.Context, - resp *http.Response, - c *gin.Context, - account *Account, - startTime time.Time, - model string, -) (*streamingResult, error) { - if s.rateLimitService != nil { - s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) - } - - writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) - - contentType := strings.TrimSpace(resp.Header.Get("Content-Type")) - if contentType == "" { - contentType = "text/event-stream" - } - c.Header("Content-Type", contentType) - if c.Writer.Header().Get("Cache-Control") == "" { - c.Header("Cache-Control", "no-cache") - } - if c.Writer.Header().Get("Connection") == "" { - c.Header("Connection", "keep-alive") - } - c.Header("X-Accel-Buffering", "no") - if v := resp.Header.Get("x-request-id"); v != "" { - c.Header("x-request-id", v) - } - - w := c.Writer - flusher, ok := w.(http.Flusher) - if !ok { - return nil, errors.New("streaming not supported") - } - - usage := &ClaudeUsage{} - var firstTokenMs *int - clientDisconnected := false - sawTerminalEvent := false - - scanner := bufio.NewScanner(resp.Body) - maxLineSize := defaultMaxLineSize - if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { - maxLineSize = s.cfg.Gateway.MaxLineSize - } - scanBuf := getSSEScannerBuf64K() - scanner.Buffer(scanBuf[:0], maxLineSize) - - type scanEvent struct { - line string - err error - } - events := make(chan scanEvent, 16) - done := make(chan struct{}) - sendEvent := func(ev scanEvent) bool { - select { - case events <- ev: - return true - case <-done: - return false - } - } - var lastReadAt int64 - atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func(scanBuf *sseScannerBuf64K) { - defer putSSEScannerBuf64K(scanBuf) - defer close(events) - for scanner.Scan() { - atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - if !sendEvent(scanEvent{line: scanner.Text()}) { - return - } - } - if err := scanner.Err(); err != nil { - 
_ = sendEvent(scanEvent{err: err}) - } - }(scanBuf) - defer close(done) - - streamInterval := time.Duration(0) - if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { - streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second - } - var intervalTicker *time.Ticker - if streamInterval > 0 { - intervalTicker = time.NewTicker(streamInterval) - defer intervalTicker.Stop() - } - var intervalCh <-chan time.Time - if intervalTicker != nil { - intervalCh = intervalTicker.C - } - - for { - select { - case ev, ok := <-events: - if !ok { - if !clientDisconnected { - // 兜底补刷,确保最后一个未以空行结尾的事件也能及时送达客户端。 - flusher.Flush() - } - if !sawTerminalEvent { - if clientDisconnected && streamInterval > 0 { - lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) - if time.Since(lastRead) >= streamInterval { - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after timeout") - } - } - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, fmt.Errorf("stream usage incomplete: missing terminal event") - } - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil - } - if ev.err != nil { - if sawTerminalEvent { - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil - } - if clientDisconnected { - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after disconnect: %w", ev.err) - } - if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete: %w", ev.err) - } - if errors.Is(ev.err, bufio.ErrTooLong) { - logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] SSE line too long: 
account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err - } - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err) - } - - line := ev.line - if data, ok := extractAnthropicSSEDataLine(line); ok { - trimmed := strings.TrimSpace(data) - if anthropicStreamEventIsTerminal("", trimmed) { - sawTerminalEvent = true - } - if firstTokenMs == nil && trimmed != "" && trimmed != "[DONE]" { - ms := int(time.Since(startTime).Milliseconds()) - firstTokenMs = &ms - } - s.parseSSEUsagePassthrough(data, usage) - } else { - trimmed := strings.TrimSpace(line) - if strings.HasPrefix(trimmed, "event:") && anthropicStreamEventIsTerminal(strings.TrimSpace(strings.TrimPrefix(trimmed, "event:")), "") { - sawTerminalEvent = true - } - } - - if !clientDisconnected { - restored := string(reverseToolNamesIfPresent(c, []byte(line))) - if _, err := io.WriteString(w, restored); err != nil { - clientDisconnected = true - logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID) - } else if _, err := io.WriteString(w, "\n"); err != nil { - clientDisconnected = true - logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID) - } else if line == "" { - // 按 SSE 事件边界刷出,减少每行 flush 带来的 syscall 开销。 - flusher.Flush() - } - } - - case <-intervalCh: - lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) - if time.Since(lastRead) < streamInterval { - continue - } - if clientDisconnected { - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after timeout") - } - logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Stream data interval timeout: account=%d model=%s 
interval=%s", account.ID, model, streamInterval) - if s.rateLimitService != nil { - s.rateLimitService.HandleStreamTimeout(ctx, account, model) - } - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") - } - } -} - -func extractAnthropicSSEDataLine(line string) (string, bool) { - if !strings.HasPrefix(line, "data:") { - return "", false - } - start := len("data:") - for start < len(line) { - if line[start] != ' ' && line[start] != '\t' { - break - } - start++ - } - return line[start:], true -} - -func (s *GatewayService) parseSSEUsagePassthrough(data string, usage *ClaudeUsage) { - if usage == nil || data == "" || data == "[DONE]" { - return - } - - parsed := gjson.Parse(data) - switch parsed.Get("type").String() { - case "message_start": - msgUsage := parsed.Get("message.usage") - if msgUsage.Exists() { - usage.InputTokens = int(msgUsage.Get("input_tokens").Int()) - usage.CacheCreationInputTokens = int(msgUsage.Get("cache_creation_input_tokens").Int()) - usage.CacheReadInputTokens = int(msgUsage.Get("cache_read_input_tokens").Int()) - - // 保持与通用解析一致:message_start 允许覆盖 5m/1h 明细(包括 0)。 - cc5m := msgUsage.Get("cache_creation.ephemeral_5m_input_tokens") - cc1h := msgUsage.Get("cache_creation.ephemeral_1h_input_tokens") - if cc5m.Exists() || cc1h.Exists() { - usage.CacheCreation5mTokens = int(cc5m.Int()) - usage.CacheCreation1hTokens = int(cc1h.Int()) - } - } - case "message_delta": - deltaUsage := parsed.Get("usage") - if deltaUsage.Exists() { - if v := deltaUsage.Get("input_tokens").Int(); v > 0 { - usage.InputTokens = int(v) - } - if v := deltaUsage.Get("output_tokens").Int(); v > 0 { - usage.OutputTokens = int(v) - } - if v := deltaUsage.Get("cache_creation_input_tokens").Int(); v > 0 { - usage.CacheCreationInputTokens = int(v) - } - if v := deltaUsage.Get("cache_read_input_tokens").Int(); v > 0 { - usage.CacheReadInputTokens = int(v) - } - - cc5m := deltaUsage.Get("cache_creation.ephemeral_5m_input_tokens") 
- cc1h := deltaUsage.Get("cache_creation.ephemeral_1h_input_tokens") - if cc5m.Exists() && cc5m.Int() > 0 { - usage.CacheCreation5mTokens = int(cc5m.Int()) - } - if cc1h.Exists() && cc1h.Int() > 0 { - usage.CacheCreation1hTokens = int(cc1h.Int()) - } - } - } - - if usage.CacheReadInputTokens == 0 { - if cached := parsed.Get("message.usage.cached_tokens").Int(); cached > 0 { - usage.CacheReadInputTokens = int(cached) - } - if cached := parsed.Get("usage.cached_tokens").Int(); usage.CacheReadInputTokens == 0 && cached > 0 { - usage.CacheReadInputTokens = int(cached) - } - } - if usage.CacheCreationInputTokens == 0 { - cc5m := parsed.Get("message.usage.cache_creation.ephemeral_5m_input_tokens").Int() - cc1h := parsed.Get("message.usage.cache_creation.ephemeral_1h_input_tokens").Int() - if cc5m == 0 && cc1h == 0 { - cc5m = parsed.Get("usage.cache_creation.ephemeral_5m_input_tokens").Int() - cc1h = parsed.Get("usage.cache_creation.ephemeral_1h_input_tokens").Int() - } - total := cc5m + cc1h - if total > 0 { - usage.CacheCreationInputTokens = int(total) - } - } -} - -func parseClaudeUsageFromResponseBody(body []byte) *ClaudeUsage { - usage := &ClaudeUsage{} - if len(body) == 0 { - return usage - } - - parsed := gjson.ParseBytes(body) - usageNode := parsed.Get("usage") - if !usageNode.Exists() { - return usage - } - - usage.InputTokens = int(usageNode.Get("input_tokens").Int()) - usage.OutputTokens = int(usageNode.Get("output_tokens").Int()) - usage.CacheCreationInputTokens = int(usageNode.Get("cache_creation_input_tokens").Int()) - usage.CacheReadInputTokens = int(usageNode.Get("cache_read_input_tokens").Int()) - - cc5m := usageNode.Get("cache_creation.ephemeral_5m_input_tokens").Int() - cc1h := usageNode.Get("cache_creation.ephemeral_1h_input_tokens").Int() - if cc5m > 0 || cc1h > 0 { - usage.CacheCreation5mTokens = int(cc5m) - usage.CacheCreation1hTokens = int(cc1h) - } - if usage.CacheCreationInputTokens == 0 && (cc5m > 0 || cc1h > 0) { - 
usage.CacheCreationInputTokens = int(cc5m + cc1h) - } - if usage.CacheReadInputTokens == 0 { - if cached := usageNode.Get("cached_tokens").Int(); cached > 0 { - usage.CacheReadInputTokens = int(cached) - } - } - return usage -} - -func (s *GatewayService) handleNonStreamingResponseAnthropicAPIKeyPassthrough( - ctx context.Context, - resp *http.Response, - c *gin.Context, - account *Account, -) (*ClaudeUsage, error) { - if s.rateLimitService != nil { - s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) - } - - body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, anthropicTooLargeError) - if err != nil { - return nil, err - } - - usage := parseClaudeUsageFromResponseBody(body) - - writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) - contentType := strings.TrimSpace(resp.Header.Get("Content-Type")) - if contentType == "" { - contentType = "application/json" - } - body = reverseToolNamesIfPresent(c, body) - c.Data(resp.StatusCode, contentType, body) - return usage, nil -} - -func writeAnthropicPassthroughResponseHeaders(dst http.Header, src http.Header, filter *responseheaders.CompiledHeaderFilter) { - if dst == nil || src == nil { - return - } - if filter != nil { - responseheaders.WriteFilteredHeaders(dst, src, filter) - return - } - if v := strings.TrimSpace(src.Get("Content-Type")); v != "" { - dst.Set("Content-Type", v) - } - if v := strings.TrimSpace(src.Get("x-request-id")); v != "" { - dst.Set("x-request-id", v) - } -} - -// forwardBedrock 转发请求到 AWS Bedrock -func (s *GatewayService) forwardBedrock( - ctx context.Context, - c *gin.Context, - account *Account, - parsed *ParsedRequest, - startTime time.Time, -) (*ForwardResult, error) { - reqModel := parsed.Model - reqStream := parsed.Stream - body := parsed.Body - - region := bedrockRuntimeRegion(account) - mappedModel, ok := ResolveBedrockModelID(account, reqModel) - if !ok { - return nil, fmt.Errorf("unsupported bedrock model: %s", reqModel) - } 
- if mappedModel != reqModel { - logger.LegacyPrintf("service.gateway", "[Bedrock] Model mapping: %s -> %s (account: %s)", reqModel, mappedModel, account.Name) - } - - betaHeader := "" - if c != nil && c.Request != nil { - betaHeader = c.GetHeader("anthropic-beta") - } - - // 准备请求体(注入 anthropic_version/anthropic_beta,移除 Bedrock 不支持的字段,清理 cache_control) - betaTokens, err := s.resolveBedrockBetaTokensForRequest(ctx, account, betaHeader, body, mappedModel) - if err != nil { - return nil, err - } - - bedrockBody, err := PrepareBedrockRequestBodyWithTokens(body, mappedModel, betaTokens) - if err != nil { - return nil, fmt.Errorf("prepare bedrock request body: %w", err) - } - - proxyURL := "" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - logger.LegacyPrintf("service.gateway", "[Bedrock] 命中 Bedrock 分支: account=%d name=%s model=%s->%s stream=%v", - account.ID, account.Name, reqModel, mappedModel, reqStream) - - // 根据账号类型选择认证方式 - var signer *BedrockSigner - var bedrockAPIKey string - if account.IsBedrockAPIKey() { - bedrockAPIKey = account.GetCredential("api_key") - if bedrockAPIKey == "" { - return nil, fmt.Errorf("api_key not found in bedrock credentials") - } - } else { - signer, err = NewBedrockSignerFromAccount(account) - if err != nil { - return nil, fmt.Errorf("create bedrock signer: %w", err) - } - } - - // 执行上游请求(含重试) - resp, err := s.executeBedrockUpstream(ctx, c, account, bedrockBody, mappedModel, region, reqStream, signer, bedrockAPIKey, proxyURL) - if err != nil { - return nil, err - } - defer func() { _ = resp.Body.Close() }() - - // 将 Bedrock 的 x-amzn-requestid 映射到 x-request-id, - // 使通用错误处理函数(handleErrorResponse、handleRetryExhaustedError)能正确提取 AWS request ID。 - if awsReqID := resp.Header.Get("x-amzn-requestid"); awsReqID != "" && resp.Header.Get("x-request-id") == "" { - resp.Header.Set("x-request-id", awsReqID) - } - - // 错误/failover 处理 - if resp.StatusCode >= 400 { - return s.handleBedrockUpstreamErrors(ctx, 
resp, c, account) - } - - // 响应处理 - var usage *ClaudeUsage - var firstTokenMs *int - var clientDisconnect bool - if reqStream { - streamResult, err := s.handleBedrockStreamingResponse(ctx, resp, c, account, startTime, reqModel) - if err != nil { - return nil, err - } - usage = streamResult.usage - firstTokenMs = streamResult.firstTokenMs - clientDisconnect = streamResult.clientDisconnect - } else { - usage, err = s.handleBedrockNonStreamingResponse(ctx, resp, c, account) - if err != nil { - return nil, err - } - } - if usage == nil { - usage = &ClaudeUsage{} - } - - return &ForwardResult{ - RequestID: resp.Header.Get("x-amzn-requestid"), - Usage: *usage, - Model: reqModel, - UpstreamModel: mappedModel, - Stream: reqStream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - ClientDisconnect: clientDisconnect, - }, nil -} - -// executeBedrockUpstream 执行 Bedrock 上游请求(含重试逻辑) -func (s *GatewayService) executeBedrockUpstream( - ctx context.Context, - c *gin.Context, - account *Account, - body []byte, - modelID string, - region string, - stream bool, - signer *BedrockSigner, - apiKey string, - proxyURL string, -) (*http.Response, error) { - var resp *http.Response - var err error - retryStart := time.Now() - for attempt := 1; attempt <= maxRetryAttempts; attempt++ { - var upstreamReq *http.Request - if account.IsBedrockAPIKey() { - upstreamReq, err = s.buildUpstreamRequestBedrockAPIKey(ctx, body, modelID, region, stream, apiKey) - } else { - upstreamReq, err = s.buildUpstreamRequestBedrock(ctx, body, modelID, region, stream, signer) - } - if err != nil { - return nil, err - } - - resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, nil) - if err != nil { - if resp != nil && resp.Body != nil { - _ = resp.Body.Close() - } - safeErr := sanitizeUpstreamErrorMessage(err.Error()) - setOpsUpstreamError(c, 0, safeErr, "") - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: 
account.ID, - AccountName: account.Name, - UpstreamStatusCode: 0, - UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), - Kind: "request_error", - Message: safeErr, - }) - c.JSON(http.StatusBadGateway, gin.H{ - "type": "error", - "error": gin.H{ - "type": "upstream_error", - "message": "Upstream request failed", - }, - }) - return nil, fmt.Errorf("upstream request failed: %s", safeErr) - } - - if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { - if attempt < maxRetryAttempts { - elapsed := time.Since(retryStart) - if elapsed >= maxRetryElapsed { - break - } - - delay := retryBackoffDelay(attempt) - remaining := maxRetryElapsed - elapsed - if delay > remaining { - delay = remaining - } - if delay <= 0 { - break - } - - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), - Kind: "retry", - Message: extractUpstreamErrorMessage(respBody), - Detail: func() string { - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) - } - return "" - }(), - }) - logger.LegacyPrintf("service.gateway", "[Bedrock] account %d: upstream error %d, retry %d/%d after %v", - account.ID, resp.StatusCode, attempt, maxRetryAttempts, delay) - if err := sleepWithContext(ctx, delay); err != nil { - return nil, err - } - continue - } - break - } - - break - } - if resp == nil || resp.Body == nil { - return nil, errors.New("upstream request failed: empty response") - } - return resp, nil -} - -// handleBedrockUpstreamErrors 处理 Bedrock 上游 4xx/5xx 错误(failover + 错误响应) -func (s *GatewayService) handleBedrockUpstreamErrors( - ctx context.Context, - resp *http.Response, - c *gin.Context, - account 
*Account, -) (*ForwardResult, error) { - // retry exhausted + failover - if s.shouldRetryUpstreamError(account, resp.StatusCode) { - if s.shouldFailoverUpstreamError(resp.StatusCode) { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - resp.Body = io.NopCloser(bytes.NewReader(respBody)) - - logger.LegacyPrintf("service.gateway", "[Bedrock] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d Body=%s", - account.ID, account.Name, resp.StatusCode, truncateString(string(respBody), 1000)) - - s.handleRetryExhaustedSideEffects(ctx, resp, account) - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - Kind: "retry_exhausted_failover", - Message: extractUpstreamErrorMessage(respBody), - }) - return nil, &UpstreamFailoverError{ - StatusCode: resp.StatusCode, - ResponseBody: respBody, - RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), - } - } - return s.handleRetryExhaustedError(ctx, resp, c, account) - } - - // non-retryable failover - if s.shouldFailoverUpstreamError(resp.StatusCode) { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - resp.Body = io.NopCloser(bytes.NewReader(respBody)) - - s.handleFailoverSideEffects(ctx, resp, account) - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - Kind: "failover", - Message: extractUpstreamErrorMessage(respBody), - }) - return nil, &UpstreamFailoverError{ - StatusCode: resp.StatusCode, - ResponseBody: respBody, - RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), - } - } - - // other errors - return s.handleErrorResponse(ctx, resp, c, account) -} - -// buildUpstreamRequestBedrock 构建 Bedrock 上游请求 -func (s 
*GatewayService) buildUpstreamRequestBedrock(
	ctx context.Context,
	body []byte,
	modelID string,
	region string,
	stream bool,
	signer *BedrockSigner,
) (*http.Request, error) {
	targetURL := BuildBedrockURL(region, modelID, stream)

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}

	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")

	// SigV4 signing.
	if err := signer.SignRequest(ctx, req, body); err != nil {
		return nil, fmt.Errorf("sign bedrock request: %w", err)
	}

	return req, nil
}

// buildUpstreamRequestBedrockAPIKey builds a Bedrock API Key (Bearer Token) upstream request.
func (s *GatewayService) buildUpstreamRequestBedrockAPIKey(
	ctx context.Context,
	body []byte,
	modelID string,
	region string,
	stream bool,
	apiKey string,
) (*http.Request, error) {
	targetURL := BuildBedrockURL(region, modelID, stream)

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}

	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	req.Header.Set("Authorization", "Bearer "+apiKey)

	return req, nil
}

// handleBedrockNonStreamingResponse handles a Bedrock non-streaming response.
// The body of a non-streaming Bedrock InvokeModel response is compatible with the Claude API format.
func (s *GatewayService) handleBedrockNonStreamingResponse(
	ctx context.Context,
	resp *http.Response,
	c *gin.Context,
	account *Account,
) (*ClaudeUsage, error) {
	body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, anthropicTooLargeError)
	if err != nil {
		return nil, err
	}

	// Convert the Bedrock-specific amazon-bedrock-invocationMetrics into the standard
	// Anthropic usage format and strip the field so it is not passed through to the client.
	body = transformBedrockInvocationMetrics(body)

	usage := parseClaudeUsageFromResponseBody(body)

	c.Header("Content-Type", "application/json")
	if v := resp.Header.Get("x-amzn-requestid"); v != "" {
		c.Header("x-request-id", v)
	}
	c.Data(resp.StatusCode, "application/json", body)
	return usage, nil
}

// buildUpstreamRequest assembles the upstream HTTP request for a /v1/messages call:
// target URL selection, optional fingerprint/metadata rewriting for OAuth accounts,
// auth headers, whitelisted header passthrough, and anthropic-beta handling.
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) {
	if account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
		return s.buildUpstreamRequestAnthropicVertex(ctx, c, account, body, token, modelID, reqStream)
	}

	// Determine the target URL.
	targetURL := claudeAPIURL
	if account.Type == AccountTypeAPIKey {
		baseURL := account.GetBaseURL()
		if baseURL != "" {
			validatedURL, err := s.validateUpstreamBaseURL(baseURL)
			if err != nil {
				return nil, err
			}
			targetURL = validatedURL + "/v1/messages?beta=true"
		}
	} else if account.IsCustomBaseURLEnabled() {
		customURL := account.GetCustomBaseURL()
		if customURL == "" {
			return nil, fmt.Errorf("custom_base_url is enabled but not configured for account %d", account.ID)
		}
		validatedURL, err := s.validateUpstreamBaseURL(customURL)
		if err != nil {
			return nil, err
		}
		targetURL = s.buildCustomRelayURL(validatedURL, "/v1/messages", account)
	}

	clientHeaders := http.Header{}
	if c != nil && c.Request != nil {
		clientHeaders = c.Request.Header
	}

	// OAuth accounts: apply the unified fingerprint and metadata rewrite
	// (controlled by settings toggles).
	var fingerprint *Fingerprint
	enableFP, enableMPT, enableCCH := true, false, false
	if s.settingService != nil {
		enableFP, enableMPT, enableCCH = s.settingService.GetGatewayForwardingSettings(ctx)
	}
	if account.IsOAuth() && s.identityService != nil {
		// 1. Get or create the fingerprint (includes a randomly generated ClientID).
		fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders)
		if err != nil {
			logger.LegacyPrintf("service.gateway", "Warning: failed to get fingerprint for account %d: %v", account.ID, err)
			// On failure, degrade to passing the original headers through.
		} else {
			if enableFP {
				fingerprint = fp
			}

			// 2. Rewrite metadata.user_id (needs the fingerprint's ClientID and the
			// account's account_uuid). If session ID masking is enabled, the session
			// part is replaced with a fixed value after the rewrite.
			// Skipped when metadata passthrough is enabled.
			if !enableMPT {
				accountUUID := account.GetExtraString("account_uuid")
				if accountUUID != "" && fp.ClientID != "" {
					if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID, fp.UserAgent); err == nil && len(newBody) > 0 {
						body = newBody
					}
				}
			}
		}
	}

	// Keep the billing header cc_version in sync with the User-Agent version actually sent.
	if fingerprint != nil {
		body = syncBillingHeaderVersion(body, fingerprint.UserAgent)
	}
	// CCH signing: replace the cch=00000 placeholder with the xxHash64 signature
	// (must run after all body modifications).
	if enableCCH {
		body = signBillingHeaderCCH(body)
	}

	req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}

	// Set the auth header (preserving original casing).
	if tokenType == "oauth" {
		setHeaderRaw(req.Header, "authorization", "Bearer "+token)
	} else {
		setHeaderRaw(req.Header, "x-api-key", token)
	}

	// Whitelisted header passthrough.
	// OAuth mimicry path: skip client header passthrough, aligned with Parrot.
	// Parrot's build_upstream_headers sends only 9 exact headers and passes through no
	// client headers. Passing client headers through introduces inconsistent
	// x-stainless-* / anthropic-beta / user-agent / x-claude-code-session-id values that
	// conflict with our injected mimicry headers and get the request judged third-party
	// by Anthropic.
	if tokenType != "oauth" || !mimicClaudeCode {
		for key, values := range clientHeaders {
			lowerKey := strings.ToLower(key)
			if allowedHeaders[lowerKey] {
				wireKey := resolveWireCasing(key)
				for _, v := range values {
					addHeaderRaw(req.Header, wireKey, v)
				}
			}
		}
	}

	// OAuth accounts: apply the cached fingerprint to the request headers
	// (overrides whatever the whitelist passthrough set).
	if fingerprint != nil {
		s.identityService.ApplyFingerprint(req, fingerprint)
	}

	// Ensure the required headers exist (preserving original casing).
	if getHeaderRaw(req.Header, "content-type") == "" {
		setHeaderRaw(req.Header, "content-type", "application/json")
	}
	if getHeaderRaw(req.Header, "anthropic-version") == "" {
		setHeaderRaw(req.Header, "anthropic-version", "2023-06-01")
	}
	if tokenType == "oauth" {
		applyClaudeOAuthHeaderDefaults(req)
	}

	// Build effective drop set: merge static defaults with dynamic beta policy filter rules
	policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account, modelID)
	effectiveDropSet := mergeDropSets(policyFilterSet)

	// Handle the anthropic-beta header (OAuth accounts must include the oauth beta).
	if tokenType == "oauth" {
		if mimicClaudeCode {
			// Non-Claude-Code client: follow opencode's strategy:
			// - force Claude Code fingerprint headers (especially user-agent/x-stainless/x-app)
			// - keep incoming betas while ensuring the OAuth-required betas exist
			applyClaudeCodeMimicHeaders(req, reqStream)

			incomingBeta := getHeaderRaw(req.Header, "anthropic-beta")
			// Claude Code OAuth credentials are scoped to Claude Code.
			// Non-haiku models MUST include claude-code beta for Anthropic to recognize
			// this as a legitimate Claude Code request; without it, the request is
			// rejected as third-party ("out of extra usage").
			// Haiku models are exempt from third-party detection and don't need it.
			requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
			if !strings.Contains(strings.ToLower(modelID), "haiku") {
				requiredBetas = claude.FullClaudeCodeMimicryBetas()
			}
			setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropSet))
		} else {
			// Claude Code client: pass the original header through as much as possible,
			// only topping up the oauth beta.
			clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta")
			setHeaderRaw(req.Header, "anthropic-beta", stripBetaTokensWithSet(s.getBetaHeader(modelID, clientBetaHeader), effectiveDropSet))
		}
	} else {
		// API-key accounts: apply beta policy filter to strip controlled tokens
		if existingBeta := getHeaderRaw(req.Header, "anthropic-beta"); existingBeta != "" {
			setHeaderRaw(req.Header, "anthropic-beta", stripBetaTokensWithSet(existingBeta, effectiveDropSet))
		} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey {
			// API key: only inject on demand when the request explicitly uses beta
			// features and the client did not supply a header (off by default).
			if requestNeedsBetaFeatures(body) {
				if beta := defaultAPIKeyBetaHeader(body); beta != "" {
					setHeaderRaw(req.Header, "anthropic-beta", beta)
				}
			}
		}
	}

	// Sync the X-Claude-Code-Session-Id header: override it with the session_id taken
	// from the already-processed metadata.user_id in the body.
	if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" {
		if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
			if parsed := ParseMetadataUserID(uid); parsed != nil {
				setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
			}
		}
	}

	// === DEBUG: dump the upstream forwarded request (headers + body digest) for comparison with CLIENT_ORIGINAL ===
	s.debugLogGatewaySnapshot("UPSTREAM_FORWARD", req.Header, body, map[string]string{
		"url":                 req.URL.String(),
		"token_type":          tokenType,
		"mimic_claude_code":   strconv.FormatBool(mimicClaudeCode),
		"fingerprint_applied": strconv.FormatBool(fingerprint != nil),
		"enable_fp":           strconv.FormatBool(enableFP),
		"enable_mpt":          strconv.FormatBool(enableMPT),
	})

	// Always capture a compact fingerprint line for later error diagnostics.
	// We only print it when needed (or when the explicit debug flag is enabled).
	if c != nil && tokenType == "oauth" {
		c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode))
	}
	if s.debugClaudeMimicEnabled() {
		logClaudeMimicDebug(req, body, account, tokenType, mimicClaudeCode)
	}

	return req, nil
}

// buildUpstreamRequestAnthropicVertex builds the upstream request for Anthropic models
// served via Google Vertex AI (service-account accounts).
func (s *GatewayService) buildUpstreamRequestAnthropicVertex(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	body []byte,
	token string,
	modelID string,
	reqStream bool,
) (*http.Request, error) {
	vertexBody, err := buildVertexAnthropicRequestBody(body)
	if err != nil {
		return nil, err
	}
	setOpsUpstreamRequestBody(c, vertexBody)
	fullURL, err := buildVertexAnthropicURL(account.VertexProjectID(), account.VertexLocation(modelID), modelID, reqStream)
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(vertexBody))
	if err != nil {
		return nil, err
	}

	if c != nil && c.Request != nil {
		for key, values := range c.Request.Header {
			lowerKey := strings.ToLower(strings.TrimSpace(key))
			if !allowedHeaders[lowerKey] || lowerKey == "anthropic-version" {
				continue
			}
			wireKey := resolveWireCasing(key)
			for _, v := range values {
				addHeaderRaw(req.Header, wireKey, v)
			}
		}
	}

	req.Header.Del("authorization")
	req.Header.Del("x-api-key")
	req.Header.Del("x-goog-api-key")
	req.Header.Del("cookie")
	req.Header.Del("anthropic-version")
	setHeaderRaw(req.Header, "authorization", "Bearer "+token)
	setHeaderRaw(req.Header, "content-type", "application/json")

	s.debugLogGatewaySnapshot("UPSTREAM_FORWARD_VERTEX_ANTHROPIC", req.Header, vertexBody, map[string]string{
		"url":        req.URL.String(),
		"token_type": "service_account",
		"model":      modelID,
		"stream":     strconv.FormatBool(reqStream),
	})

	return req, nil
}

// getBetaHeader processes the anthropic-beta header.
// For OAuth accounts it must ensure oauth-2025-04-20 is present.
func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) string {
	// The client supplied an anthropic-beta header.
	if clientBetaHeader != "" {
		// Already contains the oauth beta: return as-is.
		if strings.Contains(clientBetaHeader, claude.BetaOAuth) {
			return clientBetaHeader
		}

		// The oauth beta needs to be added.
		parts := strings.Split(clientBetaHeader, ",")
		for i, p := range parts {
			parts[i] = strings.TrimSpace(p)
		}

		// Insert the oauth beta right after claude-code-20250219.
		claudeCodeIdx := -1
		for i, p := range parts {
			if p == claude.BetaClaudeCode {
				claudeCodeIdx = i
				break
			}
		}

		if claudeCodeIdx >= 0 {
			// Insert after claude-code.
			newParts := make([]string, 0, len(parts)+1)
			newParts = append(newParts, parts[:claudeCodeIdx+1]...)
			newParts = append(newParts, claude.BetaOAuth)
			newParts = append(newParts, parts[claudeCodeIdx+1:]...)
			return strings.Join(newParts, ",")
		}

		// No claude-code entry: put the oauth beta first.
		return claude.BetaOAuth + "," + clientBetaHeader
	}

	// The client did not send one: generate it based on the model.
	// Haiku models do not need the claude-code beta.
	if strings.Contains(strings.ToLower(modelID), "haiku") {
		return claude.HaikuBetaHeader
	}

	return claude.DefaultBetaHeader
}

// requestNeedsBetaFeatures reports whether the request body uses features
// (tools, extended thinking) that call for a beta header.
func requestNeedsBetaFeatures(body []byte) bool {
	tools := gjson.GetBytes(body, "tools")
	if tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 {
		return true
	}
	thinkingType := gjson.GetBytes(body, "thinking.type").String()
	if strings.EqualFold(thinkingType, "enabled") || strings.EqualFold(thinkingType, "adaptive") {
		return true
	}
	return false
}

// defaultAPIKeyBetaHeader returns the default beta header for API-key accounts,
// chosen by the model named in the request body.
func defaultAPIKeyBetaHeader(body []byte) string {
	modelID := gjson.GetBytes(body, "model").String()
	if strings.Contains(strings.ToLower(modelID), "haiku") {
		return claude.APIKeyHaikuBetaHeader
	}
	return claude.APIKeyBetaHeader
}

// applyClaudeOAuthHeaderDefaults fills in default Claude OAuth headers
// without overriding any value the request already carries.
func applyClaudeOAuthHeaderDefaults(req *http.Request) {
	if req == nil {
		return
	}
	if getHeaderRaw(req.Header, "Accept") == "" {
		setHeaderRaw(req.Header, "Accept", "application/json")
	}
	for key, value := range claude.DefaultHeaders {
		if value == "" {
			continue
		}
		if getHeaderRaw(req.Header, key) == "" {
			setHeaderRaw(req.Header, resolveWireCasing(key), value)
		}
	}
}

// mergeAnthropicBeta merges required beta tokens with an incoming comma-separated
// header value, de-duplicating while preserving order (required tokens first).
func mergeAnthropicBeta(required []string, incoming string) string {
	seen := make(map[string]struct{}, len(required)+8)
	out := make([]string, 0, len(required)+8)

	add := func(v string) {
		v = strings.TrimSpace(v)
		if v == "" {
			return
		}
		if _, ok := seen[v]; ok {
			return
		}
		seen[v] = struct{}{}
		out = append(out, v)
	}

	for _, r := range required {
		add(r)
	}
	for _, p := range strings.Split(incoming, ",") {
		add(p)
	}
	return strings.Join(out, ",")
}

// mergeAnthropicBetaDropping merges required + incoming beta tokens, then removes
// any token present in the drop set.
func mergeAnthropicBetaDropping(required []string, incoming string, drop map[string]struct{}) string {
	merged := mergeAnthropicBeta(required, incoming)
	if merged == "" || len(drop) == 0 {
		return merged
	}
	out := make([]string, 0, 8)
	for _, p := range strings.Split(merged, ",") {
		p = strings.TrimSpace(p)
		if p == "" {
			continue
		}
		if _, ok := drop[p]; ok {
			continue
		}
		out = append(out, p)
	}
	return strings.Join(out, ",")
}

// stripBetaTokens removes the given beta tokens from a comma-separated header value.
func stripBetaTokens(header string, tokens []string) string {
	if header == "" || len(tokens) == 0 {
		return header
	}
	return stripBetaTokensWithSet(header, buildBetaTokenSet(tokens))
}

// stripBetaTokensWithSet removes every token in drop from a comma-separated header value.
func stripBetaTokensWithSet(header string, drop map[string]struct{}) string {
	if header == "" || len(drop) == 0 {
		return header
	}
	parts := strings.Split(header, ",")
	out := make([]string, 0, len(parts))
	for _, p := range parts {
		p = strings.TrimSpace(p)
		if p == "" {
			continue
		}
		if _, ok := drop[p]; ok {
			continue
		}
		out = append(out, p)
	}
	if len(out) == len(parts) {
		return header // no change, avoid allocation
	}
	return strings.Join(out, ",")
}

// BetaBlockedError indicates a request was blocked by a beta policy rule.
type BetaBlockedError struct {
	Message string
}

func (e *BetaBlockedError) Error() string { return e.Message }

// betaPolicyResult holds the evaluated result of beta policy rules for a single request.
type betaPolicyResult struct {
	blockErr  *BetaBlockedError   // non-nil if a block rule matched
	filterSet map[string]struct{} // tokens to filter (may be nil)
}

// evaluateBetaPolicy loads settings once and evaluates all rules against the given request.
// Settings fetch failures are treated as "no policy" (fail-open, empty result).
func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account, model string) betaPolicyResult {
	if s.settingService == nil {
		return betaPolicyResult{}
	}
	settings, err := s.settingService.GetBetaPolicySettings(ctx)
	if err != nil || settings == nil {
		return betaPolicyResult{}
	}
	isOAuth := account.IsOAuth()
	isBedrock := account.IsBedrock()
	var result betaPolicyResult
	for _, rule := range settings.Rules {
		if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
			continue
		}
		effectiveAction, effectiveErrMsg := resolveRuleAction(rule, model)
		switch effectiveAction {
		case BetaPolicyActionBlock:
			// Only the first matching block rule wins; later ones are ignored.
			if result.blockErr == nil && betaHeader != "" && containsBetaToken(betaHeader, rule.BetaToken) {
				msg := effectiveErrMsg
				if msg == "" {
					msg = "beta feature " + rule.BetaToken + " is not allowed"
				}
				result.blockErr = &BetaBlockedError{Message: msg}
			}
		case BetaPolicyActionFilter:
			if result.filterSet == nil {
				result.filterSet = make(map[string]struct{})
			}
			result.filterSet[rule.BetaToken] = struct{}{}
		}
	}
	return result
}

// mergeDropSets merges the static defaultDroppedBetasSet with dynamic policy filter tokens.
// Returns defaultDroppedBetasSet directly when policySet is empty (zero allocation).
func mergeDropSets(policySet map[string]struct{}, extra ...string) map[string]struct{} {
	if len(policySet) == 0 && len(extra) == 0 {
		return defaultDroppedBetasSet
	}
	m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(policySet)+len(extra))
	for t := range defaultDroppedBetasSet {
		m[t] = struct{}{}
	}
	for t := range policySet {
		m[t] = struct{}{}
	}
	for _, t := range extra {
		m[t] = struct{}{}
	}
	return m
}

// betaPolicyFilterSetKey is the gin.Context key for caching the policy filter set within a request.
const betaPolicyFilterSetKey = "betaPolicyFilterSet"

// getBetaPolicyFilterSet returns the beta policy filter set, using the gin context cache if available.
// In the /v1/messages path, Forward() evaluates the policy first and caches the result;
// buildUpstreamRequest reuses it (zero extra DB calls). In the count_tokens path, this
// evaluates on demand (one DB call).
func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account, model string) map[string]struct{} {
	if c != nil {
		if v, ok := c.Get(betaPolicyFilterSetKey); ok {
			if fs, ok := v.(map[string]struct{}); ok {
				return fs
			}
		}
	}
	return s.evaluateBetaPolicy(ctx, "", account, model).filterSet
}

// betaPolicyScopeMatches checks whether a rule's scope matches the current account type.
func betaPolicyScopeMatches(scope string, isOAuth bool, isBedrock bool) bool {
	switch scope {
	case BetaPolicyScopeAll:
		return true
	case BetaPolicyScopeOAuth:
		return isOAuth
	case BetaPolicyScopeAPIKey:
		return !isOAuth && !isBedrock
	case BetaPolicyScopeBedrock:
		return isBedrock
	default:
		return true // unknown scope → match all (fail-open)
	}
}

// matchModelWhitelist checks if a model matches any pattern in the whitelist.
// Reuses matchModelPattern from group.go which supports exact and wildcard prefix matching.
-func matchModelWhitelist(model string, whitelist []string) bool { - for _, pattern := range whitelist { - if matchModelPattern(pattern, model) { - return true - } - } - return false -} - -// resolveRuleAction determines the effective action and error message for a rule given the request model. -// When ModelWhitelist is empty, the rule's primary Action/ErrorMessage applies unconditionally. -// When non-empty, Action applies to matching models; FallbackAction/FallbackErrorMessage applies to others. -func resolveRuleAction(rule BetaPolicyRule, model string) (action, errorMessage string) { - if len(rule.ModelWhitelist) == 0 { - return rule.Action, rule.ErrorMessage - } - if matchModelWhitelist(model, rule.ModelWhitelist) { - return rule.Action, rule.ErrorMessage - } - if rule.FallbackAction != "" { - return rule.FallbackAction, rule.FallbackErrorMessage - } - return BetaPolicyActionPass, "" // default fallback: pass (fail-open) -} - -// droppedBetaSet returns claude.DroppedBetas as a set, with optional extra tokens. -func droppedBetaSet(extra ...string) map[string]struct{} { - m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(extra)) - for t := range defaultDroppedBetasSet { - m[t] = struct{}{} - } - for _, t := range extra { - m[t] = struct{}{} - } - return m -} - -// containsBetaToken checks if a comma-separated header value contains the given token. 
func containsBetaToken(header, token string) bool {
	if header == "" || token == "" {
		return false
	}
	for _, p := range strings.Split(header, ",") {
		if strings.TrimSpace(p) == token {
			return true
		}
	}
	return false
}

// filterBetaTokens drops every token present in filterSet, preserving order.
// Returns the input slice unchanged when there is nothing to filter.
func filterBetaTokens(tokens []string, filterSet map[string]struct{}) []string {
	if len(tokens) == 0 || len(filterSet) == 0 {
		return tokens
	}
	kept := make([]string, 0, len(tokens))
	for _, token := range tokens {
		if _, filtered := filterSet[token]; !filtered {
			kept = append(kept, token)
		}
	}
	return kept
}

// resolveBedrockBetaTokensForRequest computes the final beta token list for a Bedrock
// request, enforcing admin block rules both before and after body-driven auto-injection.
func (s *GatewayService) resolveBedrockBetaTokensForRequest(
	ctx context.Context,
	account *Account,
	betaHeader string,
	body []byte,
	modelID string,
) ([]string, error) {
	// 1. Run the block check on the beta tokens of the original header (fail fast).
	policy := s.evaluateBetaPolicy(ctx, betaHeader, account, modelID)
	if policy.blockErr != nil {
		return nil, policy.blockErr
	}

	// 2. Parse the header + auto-inject from the body + Bedrock conversion/filtering.
	betaTokens := ResolveBedrockBetaTokens(betaHeader, body, modelID)

	// 3. Re-run the block check on the final token list to catch tokens that bypass the
	// header block via body auto-injection.
	// Example: the admin blocks interleaved-thinking and the client omits the token from
	// the header, but the request body contains a thinking field → autoInjectBedrockBetaTokens
	// adds it back → without this check the block rule would be bypassed.
	if blockErr := s.checkBetaPolicyBlockForTokens(ctx, betaTokens, account, modelID); blockErr != nil {
		return nil, blockErr
	}

	return filterBetaTokens(betaTokens, policy.filterSet), nil
}

// checkBetaPolicyBlockForTokens checks whether any token in the list is hit by an admin
// block rule. It complements evaluateBetaPolicy's header check by covering tokens that
// were auto-injected from the body.
func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, tokens []string, account *Account, model string) *BetaBlockedError {
	if s.settingService == nil || len(tokens) == 0 {
		return nil
	}
	settings, err := s.settingService.GetBetaPolicySettings(ctx)
	if err != nil || settings == nil {
		return nil
	}
	isOAuth := account.IsOAuth()
	isBedrock := account.IsBedrock()
	tokenSet := buildBetaTokenSet(tokens)
	for _, rule := range settings.Rules {
		effectiveAction, effectiveErrMsg := resolveRuleAction(rule, model)
		if effectiveAction != BetaPolicyActionBlock {
			continue
		}
		if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
			continue
		}
		if _, present := tokenSet[rule.BetaToken]; present {
			msg := effectiveErrMsg
			if msg == "" {
				msg = "beta feature " + rule.BetaToken + " is not allowed"
			}
			return &BetaBlockedError{Message: msg}
		}
	}
	return nil
}

// buildBetaTokenSet converts a token slice into a set, skipping empty entries.
func buildBetaTokenSet(tokens []string) map[string]struct{} {
	m := make(map[string]struct{}, len(tokens))
	for _, t := range tokens {
		if t == "" {
			continue
		}
		m[t] = struct{}{}
	}
	return m
}

var defaultDroppedBetasSet = buildBetaTokenSet(claude.DroppedBetas)

// applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers.
// This mirrors opencode-anthropic-auth behavior: do not trust downstream
// headers when using Claude Code-scoped OAuth credentials.
-func applyClaudeCodeMimicHeaders(req *http.Request, isStream bool) { - if req == nil { - return - } - // Start with the standard defaults (fill missing). - applyClaudeOAuthHeaderDefaults(req) - // Then force key headers to match Claude Code fingerprint regardless of what the client sent. - // 使用 resolveWireCasing 确保 key 与真实 wire format 一致(如 "x-app" 而非 "X-App") - for key, value := range claude.DefaultHeaders { - if value == "" { - continue - } - setHeaderRaw(req.Header, resolveWireCasing(key), value) - } - // Real Claude CLI uses Accept: application/json (even for streaming). - setHeaderRaw(req.Header, "Accept", "application/json") - if isStream { - setHeaderRaw(req.Header, "x-stainless-helper-method", "stream") - } - // Real Claude CLI 每个请求都会生成一个新的 UUID 放在 x-client-request-id。 - // 上游会以此作为会话/请求指纹的一部分,缺失或重复都可能触发第三方判定。 - if getHeaderRaw(req.Header, "x-client-request-id") == "" { - setHeaderRaw(req.Header, "x-client-request-id", uuid.NewString()) - } -} - -func truncateForLog(b []byte, maxBytes int) string { - if maxBytes <= 0 { - maxBytes = 2048 - } - if len(b) > maxBytes { - b = b[:maxBytes] - } - s := string(b) - // 保持一行,避免污染日志格式 - s = strings.ReplaceAll(s, "\n", "\\n") - s = strings.ReplaceAll(s, "\r", "\\r") - return s -} - -// shouldRectifySignatureError 统一判断是否应触发签名整流(strip thinking blocks 并重试)。 -// 根据账号类型检查对应的开关和匹配模式。 -func (s *GatewayService) shouldRectifySignatureError(ctx context.Context, account *Account, respBody []byte) bool { - if account.Type == AccountTypeAPIKey { - // API Key 账号:独立开关,一次读取配置 - settings, err := s.settingService.GetRectifierSettings(ctx) - if err != nil || !settings.Enabled || !settings.APIKeySignatureEnabled { - return false - } - // 先检查内置模式(同 OAuth),再检查自定义关键词 - if s.isThinkingBlockSignatureError(respBody) { - return true - } - return matchSignaturePatterns(respBody, settings.APIKeySignaturePatterns) - } - // OAuth/SetupToken/Upstream/Bedrock 等:保持原有行为(内置模式 + 原开关) - return s.isThinkingBlockSignatureError(respBody) && 
s.settingService.IsSignatureRectifierEnabled(ctx) -} - -// isSignatureErrorPattern 仅做模式匹配,不检查开关。 -// 用于已进入重试流程后的二阶段检测(此时开关已在首次调用时验证过)。 -func (s *GatewayService) isSignatureErrorPattern(ctx context.Context, account *Account, respBody []byte) bool { - if s.isThinkingBlockSignatureError(respBody) { - return true - } - if account.Type == AccountTypeAPIKey { - settings, err := s.settingService.GetRectifierSettings(ctx) - if err != nil { - return false - } - return matchSignaturePatterns(respBody, settings.APIKeySignaturePatterns) - } - return false -} - -// matchSignaturePatterns 检查响应体是否匹配自定义关键词列表(不区分大小写)。 -func matchSignaturePatterns(respBody []byte, patterns []string) bool { - if len(patterns) == 0 { - return false - } - bodyLower := strings.ToLower(string(respBody)) - for _, p := range patterns { - p = strings.TrimSpace(p) - if p == "" { - continue - } - if strings.Contains(bodyLower, strings.ToLower(p)) { - return true - } - } - return false -} - -// isThinkingBlockSignatureError 检测是否是thinking block相关错误 -// 这类错误可以通过过滤thinking blocks并重试来解决 -func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool { - msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) - if msg == "" { - return false - } - - // 检测signature相关的错误(更宽松的匹配) - // 例如: "Invalid `signature` in `thinking` block", "***.signature" 等 - if strings.Contains(msg, "signature") { - return true - } - - // 检测 thinking block 顺序/类型错误 - // 例如: "Expected `thinking` or `redacted_thinking`, but found `text`" - if strings.Contains(msg, "expected") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) { - logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected thinking block type error") - return true - } - - // 检测 thinking block 被修改的错误 - // 例如: "thinking or redacted_thinking blocks in the latest assistant message cannot be modified" - if strings.Contains(msg, "cannot be modified") && (strings.Contains(msg, "thinking") || 
strings.Contains(msg, "redacted_thinking")) { - logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected thinking block modification error") - return true - } - - // 检测空消息内容错误(可能是过滤 thinking blocks 后导致的,或客户端发送了空 text block) - // 例如: "all messages must have non-empty content" - // "messages: text content blocks must be non-empty" - if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") || - strings.Contains(msg, "content blocks must be non-empty") { - logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected empty content error") - return true - } - - return false -} - -func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool { - // 只对"可能是兼容性差异导致"的 400 允许切换,避免无意义重试。 - // 默认保守:无法识别则不切换。 - msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) - if msg == "" { - return false - } - - // 缺少/错误的 beta header:换账号/链路可能成功(尤其是混合调度时)。 - // 更精确匹配 beta 相关的兼容性问题,避免误触发切换。 - if strings.Contains(msg, "anthropic-beta") || - strings.Contains(msg, "beta feature") || - strings.Contains(msg, "requires beta") { - return true - } - - // thinking/tool streaming 等兼容性约束(常见于中间转换链路) - if strings.Contains(msg, "thinking") || strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") { - return true - } - if strings.Contains(msg, "tool_use") || strings.Contains(msg, "tool_result") || strings.Contains(msg, "tools") { - return true - } - - return false -} - -// sanitizeStreamError 返回不含网络地址的客户端可见错误描述。 -// 默认 (*net.OpError).Error() 会拼接 Source/Addr 字段,泄露内部 IP/端口与上游 -// 服务器地址(例如 "read tcp 10.0.0.1:54321->52.1.2.3:443: read: connection -// reset by peer")。该函数只保留可识别的错误类别,原始 err 仍在调用点写入日志。 -func sanitizeStreamError(err error) string { - if err == nil { - return "" - } - switch { - case errors.Is(err, io.ErrUnexpectedEOF): - return "unexpected EOF" - case errors.Is(err, io.EOF): - return "EOF" - case errors.Is(err, context.Canceled): - return "canceled" - case errors.Is(err, 
context.DeadlineExceeded): - return "deadline exceeded" - case errors.Is(err, syscall.ECONNRESET): - return "connection reset by peer" - case errors.Is(err, syscall.ECONNABORTED): - return "connection aborted" - case errors.Is(err, syscall.ETIMEDOUT): - return "connection timed out" - case errors.Is(err, syscall.EPIPE): - return "broken pipe" - case errors.Is(err, syscall.ECONNREFUSED): - return "connection refused" - } - var netErr *net.OpError - if errors.As(err, &netErr) { - if netErr.Timeout() { - if netErr.Op != "" { - return netErr.Op + " timeout" - } - return "i/o timeout" - } - if netErr.Op != "" { - return netErr.Op + " network error" - } - } - return "upstream connection error" -} - -// ExtractUpstreamErrorMessage 从上游响应体中提取错误消息 -// 支持 Claude 风格的错误格式:{"type":"error","error":{"type":"...","message":"..."}} -func ExtractUpstreamErrorMessage(body []byte) string { - return extractUpstreamErrorMessage(body) -} - -func extractUpstreamErrorMessage(body []byte) string { - // Claude 风格:{"type":"error","error":{"type":"...","message":"..."}} - if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" { - inner := strings.TrimSpace(m) - // 有些上游会把完整 JSON 作为字符串塞进 message - if strings.HasPrefix(inner, "{") { - if innerMsg := gjson.Get(inner, "error.message").String(); strings.TrimSpace(innerMsg) != "" { - return innerMsg - } - } - return m - } - - // ChatGPT 内部 API 风格:{"detail":"..."} - if d := gjson.GetBytes(body, "detail").String(); strings.TrimSpace(d) != "" { - return d - } - - // 兜底:尝试顶层 message - return gjson.GetBytes(body, "message").String() -} - -func extractUpstreamErrorCode(body []byte) string { - if code := strings.TrimSpace(gjson.GetBytes(body, "error.code").String()); code != "" { - return code - } - - inner := strings.TrimSpace(gjson.GetBytes(body, "error.message").String()) - if !strings.HasPrefix(inner, "{") { - return "" - } - - if code := strings.TrimSpace(gjson.Get(inner, "error.code").String()); code != "" { - return code - 
} - - if lastBrace := strings.LastIndex(inner, "}"); lastBrace >= 0 { - if code := strings.TrimSpace(gjson.Get(inner[:lastBrace+1], "error.code").String()); code != "" { - return code - } - } - - return "" -} - -func isCountTokensUnsupported404(statusCode int, body []byte) bool { - if statusCode != http.StatusNotFound { - return false - } - msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(body))) - if msg == "" { - return false - } - if strings.Contains(msg, "/v1/messages/count_tokens") { - return true - } - return strings.Contains(msg, "count_tokens") && strings.Contains(msg, "not found") -} - -func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - - // 调试日志:打印上游错误响应 - logger.LegacyPrintf("service.gateway", "[Forward] Upstream error (non-retryable): Account=%d(%s) Status=%d RequestID=%s Body=%s", - account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(body), 1000)) - - upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - - // Print a compact upstream request fingerprint when we hit the Claude Code OAuth - // credential scope error. This avoids requiring env-var tweaks in a fixed deploy. - if isClaudeCodeCredentialScopeError(upstreamMsg) && c != nil { - if v, ok := c.Get(claudeMimicDebugInfoKey); ok { - if line, ok := v.(string); ok && strings.TrimSpace(line) != "" { - logger.LegacyPrintf("service.gateway", "[ClaudeMimicDebugOnError] status=%d request_id=%s %s", - resp.StatusCode, - resp.Header.Get("x-request-id"), - line, - ) - } - } - } - - // Enrich Ops error logs with upstream status + message, and optionally a truncated body snippet. 
- upstreamDetail := "" - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - if maxBytes <= 0 { - maxBytes = 2048 - } - upstreamDetail = truncateString(string(body), maxBytes) - } - setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: resp.Header.Get("x-request-id"), - Kind: "http_error", - Message: upstreamMsg, - Detail: upstreamDetail, - }) - - // 处理上游错误,标记账号状态 - shouldDisable := false - if s.rateLimitService != nil { - shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) - } - if shouldDisable { - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body} - } - - // 记录上游错误响应体摘要便于排障(可选:由配置控制;不回显到客户端) - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - logger.LegacyPrintf("service.gateway", - "Upstream error %d (account=%d platform=%s type=%s): %s", - resp.StatusCode, - account.ID, - account.Platform, - account.Type, - truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), - ) - } - - // 非 failover 错误也支持错误透传规则匹配。 - if status, errType, errMsg, matched := applyErrorPassthroughRule( - c, - account.Platform, - resp.StatusCode, - body, - http.StatusBadGateway, - "upstream_error", - "Upstream request failed", - ); matched { - c.JSON(status, gin.H{ - "type": "error", - "error": gin.H{ - "type": errType, - "message": errMsg, - }, - }) - - summary := upstreamMsg - if summary == "" { - summary = errMsg - } - if summary == "" { - return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode) - } - return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, summary) - } - - // 根据状态码返回适当的自定义错误响应(不透传上游详细信息) - var errType, errMsg string - var statusCode int - - switch 
resp.StatusCode { - case 400: - c.Data(http.StatusBadRequest, "application/json", body) - summary := upstreamMsg - if summary == "" { - summary = truncateForLog(body, 512) - } - if summary == "" { - return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) - } - return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, summary) - case 401: - statusCode = http.StatusBadGateway - errType = "upstream_error" - errMsg = "Upstream authentication failed, please contact administrator" - case 403: - statusCode = http.StatusBadGateway - errType = "upstream_error" - errMsg = "Upstream access forbidden, please contact administrator" - case 429: - statusCode = http.StatusTooManyRequests - errType = "rate_limit_error" - errMsg = "Upstream rate limit exceeded, please retry later" - case 529: - statusCode = http.StatusServiceUnavailable - errType = "overloaded_error" - errMsg = "Upstream service overloaded, please retry later" - case 500, 502, 503, 504: - statusCode = http.StatusBadGateway - errType = "upstream_error" - errMsg = "Upstream service temporarily unavailable" - default: - statusCode = http.StatusBadGateway - errType = "upstream_error" - errMsg = "Upstream request failed" - } - - // 返回自定义错误响应 - c.JSON(statusCode, gin.H{ - "type": "error", - "error": gin.H{ - "type": errType, - "message": errMsg, - }, - }) - - if upstreamMsg == "" { - return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) - } - return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) -} - -func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, resp *http.Response, account *Account) { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - statusCode := resp.StatusCode - - // OAuth/Setup Token 账号的 403:标记账号异常 - if account.IsOAuth() && statusCode == 403 { - s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, resp.Header, body) - logger.LegacyPrintf("service.gateway", "Account %d: marked as error after %d retries 
for status %d", account.ID, maxRetryAttempts, statusCode) - } else { - // API Key 未配置错误码:不标记账号状态 - logger.LegacyPrintf("service.gateway", "Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetryAttempts) - } -} - -func (s *GatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) -} - -// handleRetryExhaustedError 处理重试耗尽后的错误 -// OAuth 403:标记账号异常 -// API Key 未配置错误码:仅返回错误,不标记账号 -func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { - // Capture upstream error body before side-effects consume the stream. - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - resp.Body = io.NopCloser(bytes.NewReader(respBody)) - - s.handleRetryExhaustedSideEffects(ctx, resp, account) - - upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - - if isClaudeCodeCredentialScopeError(upstreamMsg) && c != nil { - if v, ok := c.Get(claudeMimicDebugInfoKey); ok { - if line, ok := v.(string); ok && strings.TrimSpace(line) != "" { - logger.LegacyPrintf("service.gateway", "[ClaudeMimicDebugOnError] status=%d request_id=%s %s", - resp.StatusCode, - resp.Header.Get("x-request-id"), - line, - ) - } - } - } - - upstreamDetail := "" - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - if maxBytes <= 0 { - maxBytes = 2048 - } - upstreamDetail = truncateString(string(respBody), maxBytes) - } - setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - UpstreamStatusCode: 
resp.StatusCode, - UpstreamRequestID: resp.Header.Get("x-request-id"), - Kind: "retry_exhausted", - Message: upstreamMsg, - Detail: upstreamDetail, - }) - - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - logger.LegacyPrintf("service.gateway", - "Upstream error %d retries_exhausted (account=%d platform=%s type=%s): %s", - resp.StatusCode, - account.ID, - account.Platform, - account.Type, - truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), - ) - } - - if status, errType, errMsg, matched := applyErrorPassthroughRule( - c, - account.Platform, - resp.StatusCode, - respBody, - http.StatusBadGateway, - "upstream_error", - "Upstream request failed after retries", - ); matched { - c.JSON(status, gin.H{ - "type": "error", - "error": gin.H{ - "type": errType, - "message": errMsg, - }, - }) - - summary := upstreamMsg - if summary == "" { - summary = errMsg - } - if summary == "" { - return nil, fmt.Errorf("upstream error: %d (retries exhausted, passthrough rule matched)", resp.StatusCode) - } - return nil, fmt.Errorf("upstream error: %d (retries exhausted, passthrough rule matched) message=%s", resp.StatusCode, summary) - } - - // 返回统一的重试耗尽错误响应 - c.JSON(http.StatusBadGateway, gin.H{ - "type": "error", - "error": gin.H{ - "type": "upstream_error", - "message": "Upstream request failed after retries", - }, - }) - - if upstreamMsg == "" { - return nil, fmt.Errorf("upstream error: %d (retries exhausted)", resp.StatusCode) - } - return nil, fmt.Errorf("upstream error: %d (retries exhausted) message=%s", resp.StatusCode, upstreamMsg) -} - -// streamingResult 流式响应结果 -type streamingResult struct { - usage *ClaudeUsage - firstTokenMs *int - clientDisconnect bool // 客户端是否在流式传输过程中断开 -} - -func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, mimicClaudeCode bool) (*streamingResult, error) { - // 更新5h窗口状态 - 
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) - - if s.responseHeaderFilter != nil { - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) - } - - // 设置SSE响应头 - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("X-Accel-Buffering", "no") - - // 透传其他响应头 - if v := resp.Header.Get("x-request-id"); v != "" { - c.Header("x-request-id", v) - } - - w := c.Writer - flusher, ok := w.(http.Flusher) - if !ok { - return nil, errors.New("streaming not supported") - } - - usage := &ClaudeUsage{} - var firstTokenMs *int - scanner := bufio.NewScanner(resp.Body) - // 设置更大的buffer以处理长行 - maxLineSize := defaultMaxLineSize - if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { - maxLineSize = s.cfg.Gateway.MaxLineSize - } - scanBuf := getSSEScannerBuf64K() - scanner.Buffer(scanBuf[:0], maxLineSize) - - type scanEvent struct { - line string - err error - } - // 独立 goroutine 读取上游,避免读取阻塞导致超时/keepalive无法处理 - events := make(chan scanEvent, 16) - done := make(chan struct{}) - sendEvent := func(ev scanEvent) bool { - select { - case events <- ev: - return true - case <-done: - return false - } - } - var lastReadAt int64 - atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func(scanBuf *sseScannerBuf64K) { - defer putSSEScannerBuf64K(scanBuf) - defer close(events) - for scanner.Scan() { - atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - if !sendEvent(scanEvent{line: scanner.Text()}) { - return - } - } - if err := scanner.Err(); err != nil { - _ = sendEvent(scanEvent{err: err}) - } - }(scanBuf) - defer close(done) - - streamInterval := time.Duration(0) - if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { - streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second - } - // 仅监控上游数据间隔超时,避免下游写入阻塞导致误判 - var intervalTicker *time.Ticker - if streamInterval > 0 { - intervalTicker = 
time.NewTicker(streamInterval) - defer intervalTicker.Stop() - } - var intervalCh <-chan time.Time - if intervalTicker != nil { - intervalCh = intervalTicker.C - } - - // 下游 keepalive:防止代理/Cloudflare Tunnel 因连接空闲而断开 - keepaliveInterval := time.Duration(0) - if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 { - keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second - } - var keepaliveTicker *time.Ticker - if keepaliveInterval > 0 { - keepaliveTicker = time.NewTicker(keepaliveInterval) - defer keepaliveTicker.Stop() - } - var keepaliveCh <-chan time.Time - if keepaliveTicker != nil { - keepaliveCh = keepaliveTicker.C - } - lastDataAt := time.Now() - - // 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)。 - // 事件格式遵循 Anthropic SSE 标准:{"type":"error","error":{"type":,"message":}} - // 这样 Anthropic SDK / Claude Code 等客户端能按标准 error 类型解析,UI 能显示具体错误文案, - // 服务端 ExtractUpstreamErrorMessage 也能从透传的 body 中提取 message。 - errorEventSent := false - sendErrorEvent := func(reason, message string) { - if errorEventSent { - return - } - errorEventSent = true - if message == "" { - message = reason - } - body, err := json.Marshal(map[string]any{ - "type": "error", - "error": map[string]string{ - "type": reason, - "message": message, - }, - }) - if err != nil { - // json.Marshal 不可能在已知 string-only 输入上失败,保守 fallback - body = []byte(fmt.Sprintf(`{"type":"error","error":{"type":%q,"message":%q}}`, reason, message)) - } - _, _ = fmt.Fprintf(w, "event: error\ndata: %s\n\n", body) - flusher.Flush() +// injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词 +// 处理 null、字符串、数组三种格式 +func injectClaudeCodePrompt(body []byte, system any) []byte { + system = normalizeSystemParam(system) + claudeCodeBlock, err := marshalAnthropicSystemTextBlock(claudeCodeSystemPrompt, true) + if err != nil { + logger.LegacyPrintf("service.gateway", "Warning: failed to build Claude Code prompt block: %v", err) + return body } + // Opencode plugin applies an extra safeguard: it not only 
prepends the Claude Code + // banner, it also prefixes the next system instruction with the same banner plus + // a blank line. This helps when upstream concatenates system instructions. + claudeCodePrefix := strings.TrimSpace(claudeCodeSystemPrompt) - needModelReplace := originalModel != mappedModel - clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage - sawTerminalEvent := false - - pendingEventLines := make([]string, 0, 4) - - processSSEEvent := func(lines []string) ([]string, string, *sseUsagePatch, error) { - if len(lines) == 0 { - return nil, "", nil, nil - } - - eventName := "" - dataLine := "" - for _, line := range lines { - trimmed := strings.TrimSpace(line) - if strings.HasPrefix(trimmed, "event:") { - eventName = strings.TrimSpace(strings.TrimPrefix(trimmed, "event:")) - continue - } - if dataLine == "" && sseDataRe.MatchString(trimmed) { - dataLine = sseDataRe.ReplaceAllString(trimmed, "") - } - } - - if eventName == "error" { - return nil, dataLine, nil, errors.New("have error in stream") - } - - if dataLine == "" { - return []string{strings.Join(lines, "\n") + "\n\n"}, "", nil, nil - } - - if dataLine == "[DONE]" { - sawTerminalEvent = true - block := "" - if eventName != "" { - block = "event: " + eventName + "\n" - } - block += "data: " + dataLine + "\n\n" - return []string{block}, dataLine, nil, nil - } - - var event map[string]any - if err := json.Unmarshal([]byte(dataLine), &event); err != nil { - // JSON 解析失败,直接透传原始数据 - block := "" - if eventName != "" { - block = "event: " + eventName + "\n" - } - block += "data: " + dataLine + "\n\n" - return []string{block}, dataLine, nil, nil - } - - eventType, _ := event["type"].(string) - if eventName == "" { - eventName = eventType - } - eventChanged := false - - // 兼容 Kimi cached_tokens → cache_read_input_tokens - if eventType == "message_start" { - if msg, ok := event["message"].(map[string]any); ok { - if u, ok := msg["usage"].(map[string]any); ok { - eventChanged = reconcileCachedTokens(u) || 
eventChanged - } - } - } - if eventType == "message_delta" { - if u, ok := event["usage"].(map[string]any); ok { - eventChanged = reconcileCachedTokens(u) || eventChanged - } - } - - // Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类。 - // 账号级设置优先;全局 1h 请求注入开启时,默认把 usage 计费归回 5m。 - if overrideTarget, ok := s.resolveCacheTTLUsageOverrideTarget(ctx, account); ok { - if eventType == "message_start" { - if msg, ok := event["message"].(map[string]any); ok { - if u, ok := msg["usage"].(map[string]any); ok { - eventChanged = rewriteCacheCreationJSON(u, overrideTarget) || eventChanged - } - } - } - if eventType == "message_delta" { - if u, ok := event["usage"].(map[string]any); ok { - eventChanged = rewriteCacheCreationJSON(u, overrideTarget) || eventChanged - } - } - } - - if needModelReplace { - if msg, ok := event["message"].(map[string]any); ok { - if model, ok := msg["model"].(string); ok && model == mappedModel { - msg["model"] = originalModel - eventChanged = true - } - } - } + var items [][]byte - usagePatch := s.extractSSEUsagePatch(event) - if anthropicStreamEventIsTerminal(eventName, dataLine) { - sawTerminalEvent = true - } - if !eventChanged { - block := "" - if eventName != "" { - block = "event: " + eventName + "\n" + switch v := system.(type) { + case nil: + items = [][]byte{claudeCodeBlock} + case string: + // Be tolerant of older/newer clients that may differ only by trailing whitespace/newlines. + if strings.TrimSpace(v) == "" || strings.TrimSpace(v) == strings.TrimSpace(claudeCodeSystemPrompt) { + items = [][]byte{claudeCodeBlock} + } else { + // Mirror opencode behavior: keep the banner as a separate system entry, + // but also prefix the next system text with the banner. 
+ merged := v + if !strings.HasPrefix(v, claudeCodePrefix) { + merged = claudeCodePrefix + "\n\n" + v } - block += "data: " + dataLine + "\n\n" - return []string{block}, dataLine, usagePatch, nil - } - - newData, err := json.Marshal(event) - if err != nil { - // 序列化失败,直接透传原始数据 - block := "" - if eventName != "" { - block = "event: " + eventName + "\n" + nextBlock, buildErr := marshalAnthropicSystemTextBlock(merged, false) + if buildErr != nil { + logger.LegacyPrintf("service.gateway", "Warning: failed to build prefixed Claude Code system block: %v", buildErr) + return body } - block += "data: " + dataLine + "\n\n" - return []string{block}, dataLine, usagePatch, nil - } - - block := "" - if eventName != "" { - block = "event: " + eventName + "\n" + items = [][]byte{claudeCodeBlock, nextBlock} } - block += "data: " + string(newData) + "\n\n" - return []string{block}, string(newData), usagePatch, nil - } - - for { - select { - case ev, ok := <-events: - if !ok { - // 上游完成,返回结果 - if !sawTerminalEvent { - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, fmt.Errorf("stream usage incomplete: missing terminal event") - } - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil - } - if ev.err != nil { - if sawTerminalEvent { - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil - } - // 检测 context 取消(客户端断开会导致 context 取消,进而影响上游读取) - if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete: %w", ev.err) - } - // 客户端已通过写入失败检测到断开,上游也出错了,返回已收集的 usage - if clientDisconnected { - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after disconnect: %w", ev.err) - } - // 客户端未断开,正常的错误处理 
- if errors.Is(ev.err, bufio.ErrTooLong) { - logger.LegacyPrintf("service.gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) - sendErrorEvent("response_too_large", fmt.Sprintf("upstream SSE line exceeded %d bytes", maxLineSize)) - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err + case []any: + items = make([][]byte, 0, len(v)+1) + items = append(items, claudeCodeBlock) + prefixedNext := false + systemResult := gjson.GetBytes(body, "system") + if systemResult.IsArray() { + systemResult.ForEach(func(_, item gjson.Result) bool { + textResult := item.Get("text") + if textResult.Exists() && textResult.Type == gjson.String && + strings.TrimSpace(textResult.String()) == strings.TrimSpace(claudeCodeSystemPrompt) { + return true } - // 上游中途读错误(unexpected EOF / connection reset 等,常见于 HTTP/2 GOAWAY): - // 若尚未向客户端写过任何字节,包成 UpstreamFailoverError 让 handler 层走 failover/重试。 - // 已经开始写流时 SSE 协议无 resume,只能透传错误事件给客户端。 - // 注意:面向客户端的 disconnectMsg 必须用 sanitizeStreamError 剥离地址, - // 默认 *net.OpError 的 Error() 会泄露内部 IP/端口和上游地址。完整 ev.err - // 仅在下方 LegacyPrintf 内部日志中保留供运维诊断。 - disconnectMsg := "upstream stream disconnected: " + sanitizeStreamError(ev.err) - if !c.Writer.Written() { - logger.LegacyPrintf("service.gateway", "Upstream stream read error before any client output (account=%d), failing over: %v", account.ID, ev.err) - body, _ := json.Marshal(map[string]any{ - "type": "error", - "error": map[string]string{ - "type": "upstream_disconnected", - "message": disconnectMsg, - }, - }) - return nil, &UpstreamFailoverError{ - StatusCode: http.StatusBadGateway, - ResponseBody: body, - RetryableOnSameAccount: true, + + raw := []byte(item.Raw) + // Prefix the first subsequent text system block once. 
+ if !prefixedNext && item.Get("type").String() == "text" && textResult.Exists() && textResult.Type == gjson.String { + text := textResult.String() + if strings.TrimSpace(text) != "" && !strings.HasPrefix(text, claudeCodePrefix) { + next, setErr := sjson.SetBytes(raw, "text", claudeCodePrefix+"\n\n"+text) + if setErr == nil { + raw = next + prefixedNext = true + } } } - sendErrorEvent("stream_read_error", disconnectMsg) - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err) - } - line := ev.line - trimmed := strings.TrimSpace(line) - - if trimmed == "" { - if len(pendingEventLines) == 0 { + items = append(items, raw) + return true + }) + } else { + for _, item := range v { + m, ok := item.(map[string]any) + if !ok { + raw, marshalErr := json.Marshal(item) + if marshalErr == nil { + items = append(items, raw) + } continue } - - outputBlocks, data, usagePatch, err := processSSEEvent(pendingEventLines) - pendingEventLines = pendingEventLines[:0] - if err != nil { - if clientDisconnected { - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil - } - return nil, err + if text, ok := m["text"].(string); ok && strings.TrimSpace(text) == strings.TrimSpace(claudeCodeSystemPrompt) { + continue } - - for _, block := range outputBlocks { - if !clientDisconnected { - restored := reverseToolNamesIfPresent(c, []byte(block)) - if _, werr := fmt.Fprint(w, string(restored)); werr != nil { - clientDisconnected = true - logger.LegacyPrintf("service.gateway", "Client disconnected during streaming, continuing to drain upstream for billing") - break - } - flusher.Flush() - lastDataAt = time.Now() - } - if data != "" { - if firstTokenMs == nil && data != "[DONE]" { - ms := int(time.Since(startTime).Milliseconds()) - firstTokenMs = &ms - } - if usagePatch != nil { - mergeSSEUsagePatch(usage, usagePatch) + if !prefixedNext { + if blockType, _ := m["type"].(string); blockType == "text" { + if 
text, ok := m["text"].(string); ok && strings.TrimSpace(text) != "" && !strings.HasPrefix(text, claudeCodePrefix) { + m["text"] = claudeCodePrefix + "\n\n" + text + prefixedNext = true } } } - continue - } - - pendingEventLines = append(pendingEventLines, line) - - case <-intervalCh: - lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) - if time.Since(lastRead) < streamInterval { - continue - } - if clientDisconnected { - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after timeout") - } - logger.LegacyPrintf("service.gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) - // 处理流超时,可能标记账户为临时不可调度或错误状态 - if s.rateLimitService != nil { - s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel) - } - sendErrorEvent("stream_timeout", fmt.Sprintf("upstream stream idle for %s", streamInterval)) - return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") - - case <-keepaliveCh: - if clientDisconnected { - continue - } - if time.Since(lastDataAt) < keepaliveInterval { - continue - } - // SSE ping 事件:Anthropic 原生格式,客户端会正确处理, - // 同时保持连接活跃防止 Cloudflare Tunnel 等代理断开 - if _, werr := fmt.Fprint(w, "event: ping\ndata: {\"type\": \"ping\"}\n\n"); werr != nil { - clientDisconnected = true - logger.LegacyPrintf("service.gateway", "Client disconnected during keepalive ping, continuing to drain upstream for billing") - continue - } - flusher.Flush() - } - } - -} - -func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { - if usage == nil { - return - } - - var event map[string]any - if err := json.Unmarshal([]byte(data), &event); err != nil { - return - } - - if patch := s.extractSSEUsagePatch(event); patch != nil { - mergeSSEUsagePatch(usage, patch) - } -} - -type sseUsagePatch struct { - inputTokens int - hasInputTokens bool - outputTokens int - 
hasOutputTokens bool - cacheCreationInputTokens int - hasCacheCreationInput bool - cacheReadInputTokens int - hasCacheReadInput bool - cacheCreation5mTokens int - hasCacheCreation5m bool - cacheCreation1hTokens int - hasCacheCreation1h bool -} - -func (s *GatewayService) extractSSEUsagePatch(event map[string]any) *sseUsagePatch { - if len(event) == 0 { - return nil - } - - eventType, _ := event["type"].(string) - switch eventType { - case "message_start": - msg, _ := event["message"].(map[string]any) - usageObj, _ := msg["usage"].(map[string]any) - if len(usageObj) == 0 { - return nil - } - - patch := &sseUsagePatch{} - patch.hasInputTokens = true - if v, ok := parseSSEUsageInt(usageObj["input_tokens"]); ok { - patch.inputTokens = v - } - patch.hasCacheCreationInput = true - if v, ok := parseSSEUsageInt(usageObj["cache_creation_input_tokens"]); ok { - patch.cacheCreationInputTokens = v - } - patch.hasCacheReadInput = true - if v, ok := parseSSEUsageInt(usageObj["cache_read_input_tokens"]); ok { - patch.cacheReadInputTokens = v - } - if cc, ok := usageObj["cache_creation"].(map[string]any); ok { - if v, exists := parseSSEUsageInt(cc["ephemeral_5m_input_tokens"]); exists { - patch.cacheCreation5mTokens = v - patch.hasCacheCreation5m = true - } - if v, exists := parseSSEUsageInt(cc["ephemeral_1h_input_tokens"]); exists { - patch.cacheCreation1hTokens = v - patch.hasCacheCreation1h = true - } - } - return patch - - case "message_delta": - usageObj, _ := event["usage"].(map[string]any) - if len(usageObj) == 0 { - return nil - } - - patch := &sseUsagePatch{} - if v, ok := parseSSEUsageInt(usageObj["input_tokens"]); ok && v > 0 { - patch.inputTokens = v - patch.hasInputTokens = true - } - if v, ok := parseSSEUsageInt(usageObj["output_tokens"]); ok && v > 0 { - patch.outputTokens = v - patch.hasOutputTokens = true - } - if v, ok := parseSSEUsageInt(usageObj["cache_creation_input_tokens"]); ok && v > 0 { - patch.cacheCreationInputTokens = v - patch.hasCacheCreationInput = 
true - } - if v, ok := parseSSEUsageInt(usageObj["cache_read_input_tokens"]); ok && v > 0 { - patch.cacheReadInputTokens = v - patch.hasCacheReadInput = true - } - if cc, ok := usageObj["cache_creation"].(map[string]any); ok { - if v, exists := parseSSEUsageInt(cc["ephemeral_5m_input_tokens"]); exists && v > 0 { - patch.cacheCreation5mTokens = v - patch.hasCacheCreation5m = true - } - if v, exists := parseSSEUsageInt(cc["ephemeral_1h_input_tokens"]); exists && v > 0 { - patch.cacheCreation1hTokens = v - patch.hasCacheCreation1h = true - } - } - return patch - } - - return nil -} - -func mergeSSEUsagePatch(usage *ClaudeUsage, patch *sseUsagePatch) { - if usage == nil || patch == nil { - return - } - - if patch.hasInputTokens { - usage.InputTokens = patch.inputTokens - } - if patch.hasCacheCreationInput { - usage.CacheCreationInputTokens = patch.cacheCreationInputTokens - } - if patch.hasCacheReadInput { - usage.CacheReadInputTokens = patch.cacheReadInputTokens - } - if patch.hasOutputTokens { - usage.OutputTokens = patch.outputTokens - } - if patch.hasCacheCreation5m { - usage.CacheCreation5mTokens = patch.cacheCreation5mTokens - } - if patch.hasCacheCreation1h { - usage.CacheCreation1hTokens = patch.cacheCreation1hTokens - } -} - -func parseSSEUsageInt(value any) (int, bool) { - switch v := value.(type) { - case float64: - return int(v), true - case float32: - return int(v), true - case int: - return v, true - case int64: - return int(v), true - case int32: - return int(v), true - case json.Number: - if i, err := v.Int64(); err == nil { - return int(i), true - } - if f, err := v.Float64(); err == nil { - return int(f), true - } - case string: - if parsed, err := strconv.Atoi(strings.TrimSpace(v)); err == nil { - return parsed, true - } - } - return 0, false -} - -// applyCacheTTLOverride 将所有 cache creation tokens 归入指定的 TTL 类型。 -// target 为 "5m" 或 "1h"。返回 true 表示发生了变更。 -func applyCacheTTLOverride(usage *ClaudeUsage, target string) bool { - // Fallback: 如果只有聚合字段但无 
5m/1h 明细,将聚合字段归入 5m 默认类别 - if usage.CacheCreation5mTokens == 0 && usage.CacheCreation1hTokens == 0 && usage.CacheCreationInputTokens > 0 { - usage.CacheCreation5mTokens = usage.CacheCreationInputTokens - } - - total := usage.CacheCreation5mTokens + usage.CacheCreation1hTokens - if total == 0 { - return false - } - switch target { - case "1h": - if usage.CacheCreation1hTokens == total { - return false // 已经全是 1h - } - usage.CacheCreation1hTokens = total - usage.CacheCreation5mTokens = 0 - default: // "5m" - if usage.CacheCreation5mTokens == total { - return false // 已经全是 5m - } - usage.CacheCreation5mTokens = total - usage.CacheCreation1hTokens = 0 - } - return true -} - -// rewriteCacheCreationJSON 在 JSON usage 对象中重写 cache_creation 嵌套对象的 TTL 分类。 -// usageObj 是 usage JSON 对象(map[string]any)。 -func rewriteCacheCreationJSON(usageObj map[string]any, target string) bool { - ccObj, ok := usageObj["cache_creation"].(map[string]any) - if !ok { - return false - } - v5m, _ := parseSSEUsageInt(ccObj["ephemeral_5m_input_tokens"]) - v1h, _ := parseSSEUsageInt(ccObj["ephemeral_1h_input_tokens"]) - total := v5m + v1h - if total == 0 { - return false - } - switch target { - case "1h": - if v1h == total { - return false - } - ccObj["ephemeral_1h_input_tokens"] = float64(total) - ccObj["ephemeral_5m_input_tokens"] = float64(0) - default: // "5m" - if v5m == total { - return false - } - ccObj["ephemeral_5m_input_tokens"] = float64(total) - ccObj["ephemeral_1h_input_tokens"] = float64(0) - } - return true -} - -func (s *GatewayService) resolveCacheTTLUsageOverrideTarget(ctx context.Context, account *Account) (string, bool) { - if account == nil { - return "", false - } - if account.IsCacheTTLOverrideEnabled() { - return account.GetCacheTTLOverrideTarget(), true - } - if account.IsAnthropicOAuthOrSetupToken() && s != nil && s.settingService != nil && s.settingService.IsAnthropicCacheTTL1hInjectionEnabled(ctx) { - return cacheTTLTarget5m, true - } - return "", false -} - -func (s 
*GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) { - // 更新5h窗口状态 - s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) - - body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, anthropicTooLargeError) - if err != nil { - return nil, err - } - - // 解析usage - var response struct { - Usage ClaudeUsage `json:"usage"` - } - if err := json.Unmarshal(body, &response); err != nil { - return nil, fmt.Errorf("parse response: %w", err) - } - - // 解析嵌套的 cache_creation 对象中的 5m/1h 明细 - cc5m := gjson.GetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens") - cc1h := gjson.GetBytes(body, "usage.cache_creation.ephemeral_1h_input_tokens") - if cc5m.Exists() || cc1h.Exists() { - response.Usage.CacheCreation5mTokens = int(cc5m.Int()) - response.Usage.CacheCreation1hTokens = int(cc1h.Int()) - } - - // 兼容 Kimi cached_tokens → cache_read_input_tokens - if response.Usage.CacheReadInputTokens == 0 { - cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int() - if cachedTokens > 0 { - response.Usage.CacheReadInputTokens = int(cachedTokens) - if newBody, err := sjson.SetBytes(body, "usage.cache_read_input_tokens", cachedTokens); err == nil { - body = newBody - } - } - } - - // Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类。 - // 账号级设置优先;全局 1h 请求注入开启时,默认把 usage 计费归回 5m。 - if overrideTarget, ok := s.resolveCacheTTLUsageOverrideTarget(ctx, account); ok { - if applyCacheTTLOverride(&response.Usage, overrideTarget) { - // 同步更新 body JSON 中的嵌套 cache_creation 对象 - if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens", response.Usage.CacheCreation5mTokens); err == nil { - body = newBody - } - if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_1h_input_tokens", response.Usage.CacheCreation1hTokens); err == nil { - body = newBody - } - } - } - - // 如果有模型映射,替换响应中的model字段 - if 
originalModel != mappedModel { - body = s.replaceModelInResponseBody(body, mappedModel, originalModel) - } - - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) - - contentType := "application/json" - if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled { - if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" { - contentType = upstreamType - } - } - - body = reverseToolNamesIfPresent(c, body) - - // 写入响应 - c.Data(resp.StatusCode, contentType, body) - - return &response.Usage, nil -} - -// replaceModelInResponseBody 替换响应体中的model字段 -// 使用 gjson/sjson 精确替换,避免全量 JSON 反序列化 -func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte { - if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel { - newBody, err := sjson.SetBytes(body, "model", toModel) - if err != nil { - return body - } - return newBody - } - return body -} - -func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 { - if s == nil { - return groupDefaultMultiplier - } - resolver := s.userGroupRateResolver - if resolver == nil { - resolver = newUserGroupRateResolver( - s.userGroupRateRepo, - s.userGroupRateCache, - resolveUserGroupRateCacheTTL(s.cfg), - &s.userGroupRateSF, - "service.gateway", - ) - } - return resolver.Resolve(ctx, userID, groupID, groupDefaultMultiplier) -} - -// RecordUsageInput 记录使用量的输入参数 -type RecordUsageInput struct { - Result *ForwardResult - ParsedRequest *ParsedRequest - APIKey *APIKey - User *User - Account *Account - Subscription *UserSubscription // 可选:订阅信息 - InboundEndpoint string // 入站端点(客户端请求路径) - UpstreamEndpoint string // 上游端点(标准化后的上游路径) - UserAgent string // 请求的 User-Agent - IPAddress string // 请求的客户端 IP 地址 - RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险 - ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) - APIKeyService 
APIKeyQuotaUpdater // 可选:用于更新API Key配额 - - ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析) -} - -// APIKeyQuotaUpdater defines the interface for updating API Key quota and rate limit usage -type APIKeyQuotaUpdater interface { - UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error - UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error -} - -type apiKeyAuthCacheInvalidator interface { - InvalidateAuthCacheByKey(ctx context.Context, key string) -} - -type usageLogBestEffortWriter interface { - CreateBestEffort(ctx context.Context, log *UsageLog) error -} - -// postUsageBillingParams 统一扣费所需的参数 -type postUsageBillingParams struct { - Cost *CostBreakdown - User *User - APIKey *APIKey - Account *Account - Subscription *UserSubscription - RequestPayloadHash string - IsSubscriptionBill bool - AccountRateMultiplier float64 - APIKeyService APIKeyQuotaUpdater -} - -func (p *postUsageBillingParams) shouldDeductAPIKeyQuota() bool { - return p.Cost.ActualCost > 0 && p.APIKey.Quota > 0 && p.APIKeyService != nil -} - -func (p *postUsageBillingParams) shouldUpdateRateLimits() bool { - return p.Cost.ActualCost > 0 && p.APIKey.HasRateLimits() && p.APIKeyService != nil -} - -func (p *postUsageBillingParams) shouldUpdateAccountQuota() bool { - return p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit() -} - -// postUsageBilling is the legacy fallback billing path used when the unified -// billing repo is unavailable (nil). Production uses applyUsageBilling → repo.Apply -// for atomic billing. This path only runs in tests or degraded mode. -func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) { - billingCtx, cancel := detachedBillingContext(ctx) - defer cancel() - - cost := p.Cost - - if p.IsSubscriptionBill { - // Subscription usage tracked by ActualCost so group rate multiplier - // consumes the quota at the expected speed. 
- if cost.ActualCost > 0 { - if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.ActualCost); err != nil { - slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err) - } - } - } else { - if cost.ActualCost > 0 { - if err := deps.userRepo.DeductBalance(billingCtx, p.User.ID, cost.ActualCost); err != nil { - slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err) - } - } - } - - if p.shouldDeductAPIKeyQuota() { - if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { - slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err) - } - } - - if p.shouldUpdateRateLimits() { - if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil { - slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err) - } - } - - if p.shouldUpdateAccountQuota() { - accountCost := cost.TotalCost * p.AccountRateMultiplier - if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil { - slog.Error("increment account quota used failed", "account_id", p.Account.ID, "cost", accountCost, "error", err) - } - } - - // NOTE: finalizePostUsageBilling is NOT called here to avoid double-queuing - // cache updates. The legacy path does DB writes directly; the finalize path - // does cache queue + notifications. Notifications are dispatched separately - // by the caller after recording the usage log. 
-} - -func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) string { - if ctx != nil { - if clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" { - return "client:" + strings.TrimSpace(clientRequestID) - } - if requestID, _ := ctx.Value(ctxkey.RequestID).(string); strings.TrimSpace(requestID) != "" { - return "local:" + strings.TrimSpace(requestID) - } - } - if requestID := strings.TrimSpace(upstreamRequestID); requestID != "" { - return requestID - } - return "generated:" + generateRequestID() -} - -func resolveUsageBillingPayloadFingerprint(ctx context.Context, requestPayloadHash string) string { - if payloadHash := strings.TrimSpace(requestPayloadHash); payloadHash != "" { - return payloadHash - } - if ctx != nil { - if clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string); strings.TrimSpace(clientRequestID) != "" { - return "client:" + strings.TrimSpace(clientRequestID) - } - if requestID, _ := ctx.Value(ctxkey.RequestID).(string); strings.TrimSpace(requestID) != "" { - return "local:" + strings.TrimSpace(requestID) - } - } - return "" -} - -func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsageBillingParams) *UsageBillingCommand { - if p == nil || p.Cost == nil || p.APIKey == nil || p.User == nil || p.Account == nil { - return nil - } - - cmd := &UsageBillingCommand{ - RequestID: requestID, - APIKeyID: p.APIKey.ID, - UserID: p.User.ID, - AccountID: p.Account.ID, - AccountType: p.Account.Type, - RequestPayloadHash: strings.TrimSpace(p.RequestPayloadHash), - } - if usageLog != nil { - cmd.Model = usageLog.Model - cmd.BillingType = usageLog.BillingType - cmd.InputTokens = usageLog.InputTokens - cmd.OutputTokens = usageLog.OutputTokens - cmd.CacheCreationTokens = usageLog.CacheCreationTokens - cmd.CacheReadTokens = usageLog.CacheReadTokens - cmd.ImageCount = usageLog.ImageCount - if usageLog.ServiceTier != nil { - cmd.ServiceTier = 
*usageLog.ServiceTier - } - if usageLog.ReasoningEffort != nil { - cmd.ReasoningEffort = *usageLog.ReasoningEffort - } - if usageLog.SubscriptionID != nil { - cmd.SubscriptionID = usageLog.SubscriptionID + raw, marshalErr := json.Marshal(m) + if marshalErr == nil { + items = append(items, raw) + } + } } + default: + items = [][]byte{claudeCodeBlock} } - // Record subscription / balance cost using ActualCost so the group (and any - // user-specific) rate multiplier consumes subscription quota at the expected - // speed. TotalCost remains the raw (pre-multiplier) value; downstream guards - // on "> 0" still correctly skip free subscriptions (RateMultiplier == 0). - if p.IsSubscriptionBill && p.Subscription != nil && p.Cost.TotalCost > 0 { - cmd.SubscriptionID = &p.Subscription.ID - cmd.SubscriptionCost = p.Cost.ActualCost - } else if p.Cost.ActualCost > 0 { - cmd.BalanceCost = p.Cost.ActualCost - } - - if p.shouldDeductAPIKeyQuota() { - cmd.APIKeyQuotaCost = p.Cost.ActualCost - } - if p.shouldUpdateRateLimits() { - cmd.APIKeyRateLimitCost = p.Cost.ActualCost - } - if p.shouldUpdateAccountQuota() { - cmd.AccountQuotaCost = p.Cost.TotalCost * p.AccountRateMultiplier + result, ok := setJSONRawBytes(body, "system", buildJSONArrayRaw(items)) + if !ok { + logger.LegacyPrintf("service.gateway", "Warning: failed to inject Claude Code prompt") + return body } - - cmd.Normalize() - return cmd + return result } -func applyUsageBilling(ctx context.Context, requestID string, usageLog *UsageLog, p *postUsageBillingParams, deps *billingDeps, repo UsageBillingRepository) (bool, error) { - if p == nil || deps == nil { - return false, nil - } +// rewriteSystemForNonClaudeCode 将非 Claude Code 客户端的 system prompt 迁移至 messages, +// system 字段仅保留 Claude Code 标识提示词。 +// Anthropic 基于 system 参数内容检测第三方应用,仅前置追加 Claude Code 提示词 +// 无法通过检测,因为后续内容仍为非 Claude Code 格式。 +// 策略:将原始 system prompt 提取并注入为 user/assistant 消息对,system 仅保留 Claude Code 标识。 +func rewriteSystemForNonClaudeCode(body []byte, system 
any) []byte { + system = normalizeSystemParam(system) - cmd := buildUsageBillingCommand(requestID, usageLog, p) - if cmd == nil || cmd.RequestID == "" || repo == nil { - postUsageBilling(ctx, p, deps) - return true, nil + // 1. 提取原始 system prompt 文本 + var originalSystemText string + switch v := system.(type) { + case string: + originalSystemText = strings.TrimSpace(v) + case []any: + var parts []string + for _, item := range v { + if m, ok := item.(map[string]any); ok { + if text, ok := m["text"].(string); ok && strings.TrimSpace(text) != "" { + parts = append(parts, text) + } + } + } + originalSystemText = strings.Join(parts, "\n\n") } - billingCtx, cancel := detachedBillingContext(ctx) - defer cancel() - - result, err := repo.Apply(billingCtx, cmd) - if err != nil { - return false, err + // 2. 构造 system 数组,对齐真实 Claude Code CLI 的 2-block 形态: + // [0] billing attribution block(cc_version={cliVer}.{fp}; cc_entrypoint=cli; cch=00000;) + // [1] "You are Claude Code..." prompt block(带 cache_control 作为稳定缓存断点) + // + // billing block 的 cch=00000 是占位符,会被 buildUpstreamRequest 里的 + // signBillingHeaderCCH 替换成 xxhash64 签名。缺失 billing block 的系统 payload + // 是 Anthropic 判定第三方的关键信号之一(真实 CLI 每个请求都带)。 + billingBlock, billingErr := buildBillingAttributionBlockJSON(body, claude.CLICurrentVersion) + ccPromptBlock, ccErr := marshalAnthropicSystemTextBlock(claudeCodeSystemPrompt, true) + if billingErr != nil || ccErr != nil { + logger.LegacyPrintf("service.gateway", "Warning: failed to build system blocks (billing=%v, cc=%v)", billingErr, ccErr) + return body } - - if result == nil || !result.Applied { - deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID) - return false, nil + out, ok := setJSONRawBytes(body, "system", buildJSONArrayRaw([][]byte{billingBlock, ccPromptBlock})) + if !ok { + logger.LegacyPrintf("service.gateway", "Warning: failed to set Claude Code system prompt") + return body } - if result.APIKeyQuotaExhausted { - if invalidator, ok := 
p.APIKeyService.(apiKeyAuthCacheInvalidator); ok && p.APIKey != nil && p.APIKey.Key != "" { - invalidator.InvalidateAuthCacheByKey(billingCtx, p.APIKey.Key) + // 3. 将原始 system prompt 作为 user/assistant 消息对注入到 messages 开头 + // 模型仍通过 messages 接收完整指令,保留客户端功能 + ccPromptTrimmed := strings.TrimSpace(claudeCodeSystemPrompt) + if originalSystemText != "" && originalSystemText != ccPromptTrimmed && !hasClaudeCodePrefix(originalSystemText) { + instrMsg, err1 := json.Marshal(map[string]any{ + "role": "user", + "content": []map[string]any{ + {"type": "text", "text": "[System Instructions]\n" + originalSystemText}, + }, + }) + ackMsg, err2 := json.Marshal(map[string]any{ + "role": "assistant", + "content": []map[string]any{ + {"type": "text", "text": "Understood. I will follow these instructions."}, + }, + }) + if err1 != nil || err2 != nil { + logger.LegacyPrintf("service.gateway", "Warning: failed to marshal system-to-messages injection") + return out } - } - - finalizePostUsageBilling(p, deps, result) - return true, nil -} - -func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) { - if p == nil || p.Cost == nil || deps == nil { - return - } - if p.IsSubscriptionBill { - if p.Cost.ActualCost > 0 && p.User != nil && p.APIKey != nil && p.APIKey.GroupID != nil { - deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, p.Cost.ActualCost) + // 重建 messages 数组:[instruction, ack, ...originalMessages] + items := [][]byte{instrMsg, ackMsg} + messagesResult := gjson.GetBytes(out, "messages") + if messagesResult.IsArray() { + messagesResult.ForEach(func(_, msg gjson.Result) bool { + items = append(items, []byte(msg.Raw)) + return true + }) } - } else if p.Cost.ActualCost > 0 && p.User != nil { - deps.billingCacheService.QueueDeductBalance(p.User.ID, p.Cost.ActualCost) - } - - if p.Cost.ActualCost > 0 && p.APIKey != nil && p.APIKey.HasRateLimits() { - 
deps.billingCacheService.QueueUpdateAPIKeyRateLimitUsage(p.APIKey.ID, p.Cost.ActualCost) - } - - deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID) - // Notification checks run async — all parameters are already captured, - // no dependency on the request context or upstream connection. - go notifyBalanceLow(p, deps, result) - go notifyAccountQuota(p, deps, result) -} - -// notifyBalanceLow sends balance low notification after deduction. -// When result.NewBalance is available (from DB transaction RETURNING), it is used directly -// to reconstruct oldBalance, avoiding stale Redis reads and concurrent-deduction races. -func notifyBalanceLow(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) { - defer func() { - if r := recover(); r != nil { - slog.Error("panic in notifyBalanceLow", "recover", r) + if next, setOk := setJSONRawBytes(out, "messages", buildJSONArrayRaw(items)); setOk { + out = next } - }() - if p.IsSubscriptionBill || p.Cost.ActualCost <= 0 || p.User == nil || deps.balanceNotifyService == nil { - slog.Debug("notifyBalanceLow: skipped", - "is_subscription", p.IsSubscriptionBill, - "actual_cost", p.Cost.ActualCost, - "user_nil", p.User == nil, - "service_nil", deps.balanceNotifyService == nil, - ) - return } - oldBalance := resolveOldBalance(p, result) - slog.Debug("notifyBalanceLow: calling CheckBalanceAfterDeduction", - "user_id", p.User.ID, - "old_balance", oldBalance, - "cost", p.Cost.ActualCost, - "notify_enabled", p.User.BalanceNotifyEnabled, - "threshold", p.User.BalanceNotifyThreshold, - "result_has_new_balance", result != nil && result.NewBalance != nil, - ) - deps.balanceNotifyService.CheckBalanceAfterDeduction(context.Background(), p.User, oldBalance, p.Cost.ActualCost) + return out } -// resolveOldBalance returns the pre-deduction balance. -// Prefers the DB transaction result (newBalance + cost) over snapshot. 
-func resolveOldBalance(p *postUsageBillingParams, result *UsageBillingApplyResult) float64 { - if result != nil && result.NewBalance != nil { - return *result.NewBalance + p.Cost.ActualCost - } - // Legacy fallback: snapshot balance from request context - return p.User.Balance +type cacheControlPath struct { + path string + log string } -// notifyAccountQuota sends account quota threshold notification after increment. -// When result.QuotaState is available (from DB transaction RETURNING), it is passed directly -// to avoid a separate DB read that may see stale or concurrently-modified data. -func notifyAccountQuota(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) { - defer func() { - if r := recover(); r != nil { - slog.Error("panic in notifyAccountQuota", "recover", r) - } - }() - if p.Cost.TotalCost <= 0 || p.Account == nil || !p.Account.IsAPIKeyOrBedrock() || deps.balanceNotifyService == nil { - slog.Debug("notifyAccountQuota: skipped", - "total_cost", p.Cost.TotalCost, - "account_nil", p.Account == nil, - "is_apikey_or_bedrock", p.Account != nil && p.Account.IsAPIKeyOrBedrock(), - "service_nil", deps.balanceNotifyService == nil, - ) - return - } - accountCost := p.Cost.TotalCost * p.AccountRateMultiplier - var quotaState *AccountQuotaState - if result != nil { - quotaState = result.QuotaState +func collectCacheControlPaths(body []byte) (invalidThinking []cacheControlPath, messagePaths []string, toolPaths []string, systemPaths []string) { + system := gjson.GetBytes(body, "system") + if system.IsArray() { + sysIndex := 0 + system.ForEach(func(_, item gjson.Result) bool { + if item.Get("cache_control").Exists() { + path := fmt.Sprintf("system.%d.cache_control", sysIndex) + if item.Get("type").String() == "thinking" { + invalidThinking = append(invalidThinking, cacheControlPath{ + path: path, + log: "[Warning] Removed illegal cache_control from thinking block in system", + }) + } else { + systemPaths = append(systemPaths, path) + } + 
} + sysIndex++ + return true + }) } - slog.Debug("notifyAccountQuota: calling CheckAccountQuotaAfterIncrement", - "account_id", p.Account.ID, - "account_cost", accountCost, - "has_quota_state", quotaState != nil, - ) - deps.balanceNotifyService.CheckAccountQuotaAfterIncrement(context.Background(), p.Account, accountCost, quotaState) -} -func detachedBillingContext(ctx context.Context) (context.Context, context.CancelFunc) { - base := context.Background() - if ctx != nil { - base = context.WithoutCancel(ctx) + messages := gjson.GetBytes(body, "messages") + if messages.IsArray() { + msgIndex := 0 + messages.ForEach(func(_, msg gjson.Result) bool { + content := msg.Get("content") + if content.IsArray() { + contentIndex := 0 + content.ForEach(func(_, item gjson.Result) bool { + if item.Get("cache_control").Exists() { + path := fmt.Sprintf("messages.%d.content.%d.cache_control", msgIndex, contentIndex) + if item.Get("type").String() == "thinking" { + invalidThinking = append(invalidThinking, cacheControlPath{ + path: path, + log: fmt.Sprintf("[Warning] Removed illegal cache_control from thinking block in messages[%d].content[%d]", msgIndex, contentIndex), + }) + } else { + messagePaths = append(messagePaths, path) + } + } + contentIndex++ + return true + }) + } + msgIndex++ + return true + }) } - return context.WithTimeout(base, postUsageBillingTimeout) -} -func detachStreamUpstreamContext(ctx context.Context, stream bool) (context.Context, context.CancelFunc) { - if ctx == nil { - return context.Background(), func() {} - } - if !stream { - return ctx, func() {} + tools := gjson.GetBytes(body, "tools") + if tools.IsArray() { + toolIndex := 0 + tools.ForEach(func(_, tool gjson.Result) bool { + if tool.Get("cache_control").Exists() { + toolPaths = append(toolPaths, fmt.Sprintf("tools.%d.cache_control", toolIndex)) + } + toolIndex++ + return true + }) } - return context.WithoutCancel(ctx), func() {} + + return invalidThinking, messagePaths, toolPaths, systemPaths } -func 
detachUpstreamContext(ctx context.Context) (context.Context, context.CancelFunc) { - if ctx == nil { - return context.Background(), func() {} +// enforceCacheControlLimit 强制执行 cache_control 块数量限制(最多 4 个) +// 超限时优先移除工具断点,再移除 messages 断点,最后才移除 system 断点。 +func enforceCacheControlLimit(body []byte) []byte { + if len(body) == 0 { + return body } - return context.WithoutCancel(ctx), func() {} -} -// billingDeps 扣费逻辑依赖的服务(由各 gateway service 提供) -type billingDeps struct { - accountRepo AccountRepository - userRepo UserRepository - userSubRepo UserSubscriptionRepository - billingCacheService *BillingCacheService - deferredService *DeferredService - balanceNotifyService *BalanceNotifyService -} + invalidThinking, messagePaths, toolPaths, systemPaths := collectCacheControlPaths(body) + out := body + modified := false -func (s *GatewayService) billingDeps() *billingDeps { - return &billingDeps{ - accountRepo: s.accountRepo, - userRepo: s.userRepo, - userSubRepo: s.userSubRepo, - billingCacheService: s.billingCacheService, - deferredService: s.deferredService, - balanceNotifyService: s.balanceNotifyService, + // 先清理 thinking 块中的非法 cache_control(thinking 块不支持该字段) + for _, item := range invalidThinking { + if !gjson.GetBytes(out, item.path).Exists() { + continue + } + next, ok := deleteJSONPathBytes(out, item.path) + if !ok { + continue + } + out = next + modified = true + logger.LegacyPrintf("service.gateway", "%s", item.log) } -} -func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usageLog *UsageLog, logKey string) { - if repo == nil || usageLog == nil { - return + count := len(messagePaths) + len(toolPaths) + len(systemPaths) + if count <= maxCacheControlBlocks { + if modified { + return out + } + return body } - usageCtx, cancel := detachedBillingContext(ctx) - defer cancel() - if writer, ok := repo.(usageLogBestEffortWriter); ok { - if err := writer.CreateBestEffort(usageCtx, usageLog); err != nil { - logger.LegacyPrintf(logKey, "Create usage log 
failed: %v", err) - if IsUsageLogCreateDropped(err) { - return - } - if _, syncErr := repo.Create(usageCtx, usageLog); syncErr != nil { - logger.LegacyPrintf(logKey, "Create usage log sync fallback failed: %v", syncErr) - } + // 超限:优先从 tools 中移除,再从 messages 中移除,最后才从 system 中移除。 + remaining := count - maxCacheControlBlocks + for i := len(toolPaths) - 1; i >= 0 && remaining > 0; i-- { + path := toolPaths[i] + if !gjson.GetBytes(out, path).Exists() { + continue } - return + next, ok := deleteJSONPathBytes(out, path) + if !ok { + continue + } + out = next + modified = true + remaining-- } - if _, err := repo.Create(usageCtx, usageLog); err != nil { - logger.LegacyPrintf(logKey, "Create usage log failed: %v", err) + for _, path := range messagePaths { + if remaining <= 0 { + break + } + if !gjson.GetBytes(out, path).Exists() { + continue + } + next, ok := deleteJSONPathBytes(out, path) + if !ok { + continue + } + out = next + modified = true + remaining-- } -} - -// recordUsageOpts 内部选项,参数化 RecordUsage 与 RecordUsageWithLongContext 的差异点。 -type recordUsageOpts struct { - // Claude Max 策略所需的 ParsedRequest(可选,仅 Claude 路径传入) - ParsedRequest *ParsedRequest - - // EnableClaudePath 启用 Claude 路径特有逻辑: - // - Claude Max 缓存计费策略 - EnableClaudePath bool - - // 长上下文计费(仅 Gemini 路径需要) - LongContextThreshold int - LongContextMultiplier float64 -} - -// RecordUsage 记录使用量并扣费(或更新订阅用量) -func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error { - return s.recordUsageCore(ctx, &recordUsageCoreInput{ - Result: input.Result, - APIKey: input.APIKey, - User: input.User, - Account: input.Account, - Subscription: input.Subscription, - InboundEndpoint: input.InboundEndpoint, - UpstreamEndpoint: input.UpstreamEndpoint, - UserAgent: input.UserAgent, - IPAddress: input.IPAddress, - RequestPayloadHash: input.RequestPayloadHash, - ForceCacheBilling: input.ForceCacheBilling, - APIKeyService: input.APIKeyService, - ChannelUsageFields: input.ChannelUsageFields, - }, 
&recordUsageOpts{ - EnableClaudePath: true, - }) -} -// RecordUsageLongContextInput 记录使用量的输入参数(支持长上下文双倍计费) -type RecordUsageLongContextInput struct { - Result *ForwardResult - APIKey *APIKey - User *User - Account *Account - Subscription *UserSubscription // 可选:订阅信息 - InboundEndpoint string // 入站端点(客户端请求路径) - UpstreamEndpoint string // 上游端点(标准化后的上游路径) - UserAgent string // 请求的 User-Agent - IPAddress string // 请求的客户端 IP 地址 - RequestPayloadHash string // 请求体语义哈希,用于降低 request_id 误复用时的静默误去重风险 - LongContextThreshold int // 长上下文阈值(如 200000) - LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0) - ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) - APIKeyService APIKeyQuotaUpdater // API Key 配额服务(可选) - - ChannelUsageFields // 渠道映射信息(由 handler 在 Forward 前解析) -} + for i := len(systemPaths) - 1; i >= 0 && remaining > 0; i-- { + path := systemPaths[i] + if !gjson.GetBytes(out, path).Exists() { + continue + } + next, ok := deleteJSONPathBytes(out, path) + if !ok { + continue + } + out = next + modified = true + remaining-- + } -// RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini) -func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *RecordUsageLongContextInput) error { - return s.recordUsageCore(ctx, &recordUsageCoreInput{ - Result: input.Result, - APIKey: input.APIKey, - User: input.User, - Account: input.Account, - Subscription: input.Subscription, - InboundEndpoint: input.InboundEndpoint, - UpstreamEndpoint: input.UpstreamEndpoint, - UserAgent: input.UserAgent, - IPAddress: input.IPAddress, - RequestPayloadHash: input.RequestPayloadHash, - ForceCacheBilling: input.ForceCacheBilling, - APIKeyService: input.APIKeyService, - ChannelUsageFields: input.ChannelUsageFields, - }, &recordUsageOpts{ - LongContextThreshold: input.LongContextThreshold, - LongContextMultiplier: input.LongContextMultiplier, - }) + if modified { + return out + } + return body } -// recordUsageCoreInput 是 recordUsageCore 的公共输入字段,从两种输入结构体中提取。 -type 
recordUsageCoreInput struct { - Result *ForwardResult - APIKey *APIKey - User *User - Account *Account - Subscription *UserSubscription - InboundEndpoint string - UpstreamEndpoint string - UserAgent string - IPAddress string - RequestPayloadHash string - ForceCacheBilling bool - APIKeyService APIKeyQuotaUpdater - ChannelUsageFields +// injectAnthropicCacheControlTTL1h 将已有 ephemeral cache_control 块的 ttl 强制写为 1h。 +// 仅修改已经存在的 cache_control,不新增缓存断点。 +func injectAnthropicCacheControlTTL1h(body []byte) []byte { + return forceEphemeralCacheControlTTL(body, cacheTTLTarget1h) } -// recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。 -// opts 中的字段控制两者之间的差异行为: -// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略 -// - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext -func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error { - result := input.Result - apiKey := input.APIKey - user := input.User - account := input.Account - subscription := input.Subscription - - // 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens - // 用于粘性会话切换时的特殊计费处理 - if input.ForceCacheBilling && result.Usage.InputTokens > 0 { - logger.LegacyPrintf("service.gateway", "force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)", - result.Usage.InputTokens, account.ID) - result.Usage.CacheReadInputTokens += result.Usage.InputTokens - result.Usage.InputTokens = 0 - } - - // Cache TTL Override: 确保计费时 token 分类与账号设置一致。 - // 账号级设置优先;全局 1h 请求注入开启时,默认把 usage 计费归回 5m。 - cacheTTLOverridden := false - if overrideTarget, ok := s.resolveCacheTTLUsageOverrideTarget(ctx, account); ok { - applyCacheTTLOverride(&result.Usage, overrideTarget) - cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0 - } - - // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) - multiplier := 1.0 - if s.cfg != nil { - multiplier = s.cfg.Default.RateMultiplier +func forceEphemeralCacheControlTTL(body []byte, ttl 
string) []byte { + if len(body) == 0 || ttl == "" { + return body } - if apiKey.GroupID != nil && apiKey.Group != nil { - groupDefault := apiKey.Group.RateMultiplier - multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault) + out := body + var paths []string + addPath := func(path string, value gjson.Result) { + cc := value.Get("cache_control") + if !cc.Exists() || cc.Get("type").String() != "ephemeral" { + return + } + if cc.Get("ttl").String() == ttl { + return + } + paths = append(paths, path+".cache_control.ttl") } - imageMultiplier := resolveImageRateMultiplier(apiKey, multiplier) - // 确定计费模型 - billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) - if input.BillingModelSource == BillingModelSourceChannelMapped && input.ChannelMappedModel != "" { - billingModel = input.ChannelMappedModel - } - if input.BillingModelSource == BillingModelSourceRequested && input.OriginalModel != "" { - billingModel = input.OriginalModel + if topCC := gjson.GetBytes(body, "cache_control"); topCC.Exists() && topCC.Get("type").String() == "ephemeral" && topCC.Get("ttl").String() != ttl { + paths = append(paths, "cache_control.ttl") } - // 确定 RequestedModel(渠道映射前的原始模型) - requestedModel := result.Model - if input.OriginalModel != "" { - requestedModel = input.OriginalModel + system := gjson.GetBytes(body, "system") + if system.IsArray() { + idx := -1 + system.ForEach(func(_, block gjson.Result) bool { + idx++ + addPath(fmt.Sprintf("system.%d", idx), block) + return true + }) } - // 计算费用 - cost := s.calculateRecordUsageCost(ctx, result, apiKey, billingModel, multiplier, imageMultiplier, opts) - - // 判断计费方式:订阅模式 vs 余额模式 - isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType() - billingType := BillingTypeBalance - if isSubscriptionBilling { - billingType = BillingTypeSubscription + messages := gjson.GetBytes(body, "messages") + if messages.IsArray() { + msgIdx := -1 + 
messages.ForEach(func(_, msg gjson.Result) bool { + msgIdx++ + content := msg.Get("content") + if !content.IsArray() { + return true + } + contentIdx := -1 + content.ForEach(func(_, block gjson.Result) bool { + contentIdx++ + addPath(fmt.Sprintf("messages.%d.content.%d", msgIdx, contentIdx), block) + return true + }) + return true + }) } - // 创建使用日志 - accountRateMultiplier := account.BillingRateMultiplier() - usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription, - requestedModel, multiplier, imageMultiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts) - - // 计算账号统计定价费用(使用最终上游模型匹配自定义规则) - if apiKey.GroupID != nil { - applyAccountStatsCost(ctx, usageLog, s.channelService, s.billingService, - account.ID, *apiKey.GroupID, result.UpstreamModel, result.Model, - // Anthropic's input_tokens excludes cache_read and cache_creation (billed separately); - // OpenAI gateway uses actualInputTokens which also excludes cache_read for the same reason. 
- UsageTokens{ - InputTokens: result.Usage.InputTokens, - OutputTokens: result.Usage.OutputTokens, - CacheCreationTokens: result.Usage.CacheCreationInputTokens, - CacheReadTokens: result.Usage.CacheReadInputTokens, - ImageOutputTokens: result.Usage.ImageOutputTokens, - }, - cost.TotalCost, - ) + tools := gjson.GetBytes(body, "tools") + if tools.IsArray() { + idx := -1 + tools.ForEach(func(_, tool gjson.Result) bool { + idx++ + addPath(fmt.Sprintf("tools.%d", idx), tool) + return true + }) } - if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { - writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") - logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) - s.deferredService.ScheduleLastUsedUpdate(account.ID) - return nil + for _, path := range paths { + if next, err := sjson.SetBytes(out, path, ttl); err == nil { + out = next + } } + return out +} - requestID := usageLog.RequestID - _, billingErr := applyUsageBilling(ctx, requestID, usageLog, &postUsageBillingParams{ - Cost: cost, - User: user, - APIKey: apiKey, - Account: account, - Subscription: subscription, - RequestPayloadHash: resolveUsageBillingPayloadFingerprint(ctx, input.RequestPayloadHash), - IsSubscriptionBill: isSubscriptionBilling, - AccountRateMultiplier: accountRateMultiplier, - APIKeyService: input.APIKeyService, - }, s.billingDeps(), s.usageBillingRepo) - - if billingErr != nil { - return billingErr +func (s *GatewayService) shouldInjectAnthropicCacheTTL1h(ctx context.Context, account *Account) bool { + if account == nil || !account.IsAnthropicOAuthOrSetupToken() || s == nil || s.settingService == nil { + return false } - writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") - - return nil + return s.settingService.IsAnthropicCacheTTL1hInjectionEnabled(ctx) } -// calculateRecordUsageCost 根据请求类型和选项计算费用。 -func (s *GatewayService) calculateRecordUsageCost( - ctx 
context.Context, - result *ForwardResult, - apiKey *APIKey, - billingModel string, - multiplier float64, - imageMultiplier float64, - opts *recordUsageOpts, -) *CostBreakdown { - // 图片生成计费 - if result.ImageCount > 0 { - return s.calculateImageCost(ctx, result, apiKey, billingModel, imageMultiplier) - } - - // Token 计费 - return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts) -} +// Forward 转发请求到Claude API +func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) { + startTime := time.Now() + if parsed == nil { + return nil, fmt.Errorf("parse request: empty request") + } -// resolveChannelPricing 检查指定模型是否存在渠道级别定价。 -// 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。 -func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing { - if s.resolver == nil || apiKey.Group == nil { - return nil + // 阶段1: 快速退出路径(web search、passthrough、bedrock) + if result, err := s.checkForwardEarlyRoutes(ctx, c, account, parsed, startTime); result != nil || err != nil { + return result, err } - gid := apiKey.Group.ID - resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid}) - if resolved.Source == PricingSourceChannel { - return resolved + + // 阶段2: 请求体转换 + 凭证获取(Beta策略、Claude Code伪装、模型映射、cache控制) + fc, err := s.prepareForwardBody(ctx, c, account, parsed, startTime) + if err != nil { + return nil, err } - return nil -} -// calculateImageCost 计算图片生成费用:渠道级别定价优先,否则走按次计费。 -func (s *GatewayService) calculateImageCost( - ctx context.Context, - result *ForwardResult, - apiKey *APIKey, - billingModel string, - multiplier float64, -) *CostBreakdown { - if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil { - tokens := UsageTokens{ - InputTokens: result.Usage.InputTokens, - OutputTokens: result.Usage.OutputTokens, - ImageOutputTokens: result.Usage.ImageOutputTokens, - } - gid := 
apiKey.Group.ID - cost, err := s.billingService.CalculateCostUnified(CostInput{ - Ctx: ctx, - Model: billingModel, - GroupID: &gid, - Tokens: tokens, - RequestCount: result.ImageCount, - SizeTier: result.ImageSize, - RateMultiplier: multiplier, - Resolver: s.resolver, - Resolved: resolved, - }) + body := fc.body + shouldMimicClaudeCode := fc.mimicClaudeCode + reqModel := fc.reqModel + reqStream := fc.reqStream + token := fc.token + tokenType := fc.tokenType + proxyURL := fc.proxyURL + tlsProfile := fc.tlsProfile + + // 重试循环 + var resp *http.Response + retryStart := time.Now() + for attempt := 1; attempt <= maxRetryAttempts; attempt++ { + // 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取) + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, reqStream) + upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + releaseUpstreamCtx() if err != nil { - logger.LegacyPrintf("service.gateway", "Calculate image token cost failed: %v", err) - return &CostBreakdown{ActualCost: 0} + return nil, err } - return cost - } - var groupConfig *ImagePriceConfig - if apiKey.Group != nil { - groupConfig = &ImagePriceConfig{ - Price1K: apiKey.Group.ImagePrice1K, - Price2K: apiKey.Group.ImagePrice2K, - Price4K: apiKey.Group.ImagePrice4K, + // 发送请求 + resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, tlsProfile) + if err != nil { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + // Ensure the client receives an error response (handlers assume Forward writes on non-failover errors). 
+ safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), + Kind: "request_error", + Message: safeErr, + }) + c.JSON(http.StatusBadGateway, gin.H{ + "type": "error", + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream request failed", + }, + }) + return nil, fmt.Errorf("upstream request failed: %s", safeErr) } - } - return s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier) -} -// calculateTokenCost 计算 Token 计费:根据 opts 决定走普通/长上下文/渠道统一计费。 -func (s *GatewayService) calculateTokenCost( - ctx context.Context, - result *ForwardResult, - apiKey *APIKey, - billingModel string, - multiplier float64, - opts *recordUsageOpts, -) *CostBreakdown { - tokens := UsageTokens{ - InputTokens: result.Usage.InputTokens, - OutputTokens: result.Usage.OutputTokens, - CacheCreationTokens: result.Usage.CacheCreationInputTokens, - CacheReadTokens: result.Usage.CacheReadInputTokens, - CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, - CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, - ImageOutputTokens: result.Usage.ImageOutputTokens, - } - - var cost *CostBreakdown - var err error + // 优先检测thinking block签名错误(400)并重试一次 + if resp.StatusCode == 400 { + respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + if readErr == nil { + _ = resp.Body.Close() - // 优先尝试渠道定价 → CalculateCostUnified - if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil { - gid := apiKey.Group.ID - cost, err = s.billingService.CalculateCostUnified(CostInput{ - Ctx: ctx, - Model: billingModel, - GroupID: &gid, - Tokens: tokens, - RequestCount: 1, - RateMultiplier: multiplier, - Resolver: s.resolver, - Resolved: resolved, - }) - } else 
if opts.LongContextThreshold > 0 { - // 长上下文双倍计费(如 Gemini 200K 阈值) - cost, err = s.billingService.CalculateCostWithLongContext( - billingModel, tokens, multiplier, - opts.LongContextThreshold, opts.LongContextMultiplier, - ) - } else { - cost, err = s.billingService.CalculateCost(billingModel, tokens, multiplier) - } - if err != nil { - logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) - return &CostBreakdown{ActualCost: 0} - } - return cost -} + if s.shouldRectifySignatureError(ctx, account, respBody) { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), + Kind: "signature_error", + Message: extractUpstreamErrorMessage(respBody), + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), + }) -// buildRecordUsageLog 构建使用日志并设置计费模式。 -func (s *GatewayService) buildRecordUsageLog( - ctx context.Context, - input *recordUsageCoreInput, - result *ForwardResult, - apiKey *APIKey, - user *User, - account *Account, - subscription *UserSubscription, - requestedModel string, - multiplier float64, - imageMultiplier float64, - accountRateMultiplier float64, - billingType int8, - cacheTTLOverridden bool, - cost *CostBreakdown, - opts *recordUsageOpts, -) *UsageLog { - durationMs := int(result.Duration.Milliseconds()) - requestID := resolveUsageBillingRequestID(ctx, result.RequestID) - usageLog := &UsageLog{ - UserID: user.ID, - APIKeyID: apiKey.ID, - AccountID: account.ID, - RequestID: requestID, - Model: result.Model, - RequestedModel: requestedModel, - UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model), - ReasoningEffort: result.ReasoningEffort, - InboundEndpoint: 
optionalTrimmedStringPtr(input.InboundEndpoint), - UpstreamEndpoint: optionalTrimmedStringPtr(input.UpstreamEndpoint), - InputTokens: result.Usage.InputTokens, - OutputTokens: result.Usage.OutputTokens, - CacheCreationTokens: result.Usage.CacheCreationInputTokens, - CacheReadTokens: result.Usage.CacheReadInputTokens, - CacheCreation5mTokens: result.Usage.CacheCreation5mTokens, - CacheCreation1hTokens: result.Usage.CacheCreation1hTokens, - ImageOutputTokens: result.Usage.ImageOutputTokens, - RateMultiplier: multiplier, - AccountRateMultiplier: &accountRateMultiplier, - BillingType: billingType, - BillingMode: resolveBillingMode(result, cost), - Stream: result.Stream, - DurationMs: &durationMs, - FirstTokenMs: result.FirstTokenMs, - ImageCount: result.ImageCount, - ImageSize: optionalTrimmedStringPtr(result.ImageSize), - CacheTTLOverridden: cacheTTLOverridden, - ChannelID: optionalInt64Ptr(input.ChannelID), - ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain), - UserAgent: optionalTrimmedStringPtr(input.UserAgent), - IPAddress: optionalTrimmedStringPtr(input.IPAddress), - GroupID: apiKey.GroupID, - SubscriptionID: optionalSubscriptionID(subscription), - CreatedAt: time.Now(), - } - if result.ImageCount > 0 { - usageLog.RateMultiplier = imageMultiplier - } - if cost != nil { - usageLog.InputCost = cost.InputCost - usageLog.OutputCost = cost.OutputCost - usageLog.ImageOutputCost = cost.ImageOutputCost - usageLog.CacheCreationCost = cost.CacheCreationCost - usageLog.CacheReadCost = cost.CacheReadCost - usageLog.TotalCost = cost.TotalCost - usageLog.ActualCost = cost.ActualCost - } - - return usageLog -} + looksLikeToolSignatureError := func(msg string) bool { + m := strings.ToLower(msg) + return strings.Contains(m, "tool_use") || + strings.Contains(m, "tool_result") || + strings.Contains(m, "functioncall") || + strings.Contains(m, "function_call") || + strings.Contains(m, "functionresponse") || + strings.Contains(m, "function_response") + } -// 
resolveBillingMode 根据计费结果和请求类型确定计费模式。 -func resolveBillingMode(result *ForwardResult, cost *CostBreakdown) *string { - var mode string - switch { - case cost != nil && cost.BillingMode != "": - mode = cost.BillingMode - case result.ImageCount > 0: - mode = string(BillingModeImage) - default: - mode = string(BillingModeToken) - } - return &mode -} + // 避免在重试预算已耗尽时再发起额外请求 + if time.Since(retryStart) >= maxRetryElapsed { + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + break + } + logger.LegacyPrintf("service.gateway", "[warn] Account %d: thinking blocks have invalid signature, retrying with filtered blocks", account.ID) -func optionalSubscriptionID(subscription *UserSubscription) *int64 { - if subscription != nil { - return &subscription.ID - } - return nil -} + // Conservative two-stage fallback: + // 1) Disable thinking + thinking->text (preserve content) + // 2) Only if upstream still errors AND error message points to tool/function signature issues: + // also downgrade tool_use/tool_result blocks to text. 
-// ResolveChannelMapping 委托渠道服务解析模型映射 -func (s *GatewayService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult { - if s.channelService == nil { - return ChannelMappingResult{MappedModel: model} - } - return s.channelService.ResolveChannelMapping(ctx, groupID, model) -} + filteredBody := FilterThinkingBlocksForRetry(body) + retryCtx, releaseRetryCtx := detachStreamUpstreamContext(ctx, reqStream) + retryReq, buildErr := s.buildUpstreamRequest(retryCtx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + releaseRetryCtx() + if buildErr == nil { + retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, tlsProfile) + if retryErr == nil { + if retryResp.StatusCode < 400 { + logger.LegacyPrintf("service.gateway", "Account %d: thinking block retry succeeded (blocks downgraded)", account.ID) + resp = retryResp + break + } + + retryRespBody, retryReadErr := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20)) + _ = retryResp.Body.Close() + if retryReadErr == nil && retryResp.StatusCode == 400 && s.isSignatureErrorPattern(ctx, account, retryRespBody) { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: retryResp.StatusCode, + UpstreamRequestID: retryResp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(retryReq.URL.String()), + Kind: "signature_retry_thinking", + Message: extractUpstreamErrorMessage(retryRespBody), + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(retryRespBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), + }) + msg2 := extractUpstreamErrorMessage(retryRespBody) + if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed { + logger.LegacyPrintf("service.gateway", "Account %d: signature retry still failing and 
looks tool-related, retrying with tool blocks downgraded", account.ID) + filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body) + retryCtx2, releaseRetryCtx2 := detachStreamUpstreamContext(ctx, reqStream) + retryReq2, buildErr2 := s.buildUpstreamRequest(retryCtx2, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + releaseRetryCtx2() + if buildErr2 == nil { + retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, tlsProfile) + if retryErr2 == nil { + resp = retryResp2 + break + } + if retryResp2 != nil && retryResp2.Body != nil { + _ = retryResp2.Body.Close() + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + UpstreamURL: safeUpstreamURL(retryReq2.URL.String()), + Kind: "signature_retry_tools_request_error", + Message: sanitizeUpstreamErrorMessage(retryErr2.Error()), + }) + logger.LegacyPrintf("service.gateway", "Account %d: tool-downgrade signature retry failed: %v", account.ID, retryErr2) + } else { + logger.LegacyPrintf("service.gateway", "Account %d: tool-downgrade signature retry build failed: %v", account.ID, buildErr2) + } + } + } -// ReplaceModelInBody 替换请求体中的模型名(导出供 handler 使用) -func (s *GatewayService) ReplaceModelInBody(body []byte, newModel string) []byte { - return ReplaceModelInBody(body, newModel) -} + // Fall back to the original retry response context. 
+ resp = &http.Response{ + StatusCode: retryResp.StatusCode, + Header: retryResp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(retryRespBody)), + } + break + } + if retryResp != nil && retryResp.Body != nil { + _ = retryResp.Body.Close() + } + logger.LegacyPrintf("service.gateway", "Account %d: signature error retry failed: %v", account.ID, retryErr) + } else { + logger.LegacyPrintf("service.gateway", "Account %d: signature error retry build request failed: %v", account.ID, buildErr) + } -// IsModelRestricted 检查模型是否被渠道限制 -func (s *GatewayService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool { - if s.channelService == nil { - return false - } - return s.channelService.IsModelRestricted(ctx, groupID, model) -} + // Retry failed: restore original response body and continue handling. + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + break + } + // 不是签名错误(或整流器已关闭),继续检查 budget 约束 + errMsg := extractUpstreamErrorMessage(respBody) + if isThinkingBudgetConstraintError(errMsg) && s.settingService.IsBudgetRectifierEnabled(ctx) { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), + Kind: "budget_constraint_error", + Message: errMsg, + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), + }) -// ResolveChannelMappingAndRestrict 解析渠道映射。 -// 模型限制检查已移至调度阶段(checkChannelPricingRestriction),restricted 始终返回 false。 -func (s *GatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) { - if s.channelService == nil { - return ChannelMappingResult{MappedModel: model}, false - } - return 
s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model) -} + rectifiedBody, applied := RectifyThinkingBudget(body) + if applied && time.Since(retryStart) < maxRetryElapsed { + logger.LegacyPrintf("service.gateway", "Account %d: detected budget_tokens constraint error, retrying with rectified budget (budget_tokens=%d, max_tokens=%d)", account.ID, BudgetRectifyBudgetTokens, BudgetRectifyMaxTokens) + budgetRetryCtx, releaseBudgetRetryCtx := detachStreamUpstreamContext(ctx, reqStream) + budgetRetryReq, buildErr := s.buildUpstreamRequest(budgetRetryCtx, c, account, rectifiedBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) + releaseBudgetRetryCtx() + if buildErr == nil { + budgetRetryResp, retryErr := s.httpUpstream.DoWithTLS(budgetRetryReq, proxyURL, account.ID, account.Concurrency, tlsProfile) + if retryErr == nil { + resp = budgetRetryResp + break + } + if budgetRetryResp != nil && budgetRetryResp.Body != nil { + _ = budgetRetryResp.Body.Close() + } + logger.LegacyPrintf("service.gateway", "Account %d: budget rectifier retry failed: %v", account.ID, retryErr) + } else { + logger.LegacyPrintf("service.gateway", "Account %d: budget rectifier retry build failed: %v", account.ID, buildErr) + } + } + } -// checkChannelPricingRestriction 根据渠道计费基准检查模型是否受定价列表限制。 -// 供调度阶段预检查(requested / channel_mapped)。 -// upstream 需逐账号检查,此处返回 false。 -func (s *GatewayService) checkChannelPricingRestriction(ctx context.Context, groupID *int64, requestedModel string) bool { - if groupID == nil || s.channelService == nil || requestedModel == "" { - return false - } - mapping := s.channelService.ResolveChannelMapping(ctx, *groupID, requestedModel) - billingModel := billingModelForRestriction(mapping.BillingModelSource, requestedModel, mapping.MappedModel) - if billingModel == "" { - return false - } - return s.channelService.IsModelRestricted(ctx, *groupID, billingModel) -} + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + } + } -// 
billingModelForRestriction 根据计费基准确定限制检查使用的模型。 -// upstream 返回空(需逐账号检查)。 -func billingModelForRestriction(source, requestedModel, channelMappedModel string) string { - switch source { - case BillingModelSourceRequested: - return requestedModel - case BillingModelSourceUpstream: - return "" - case BillingModelSourceChannelMapped: - return channelMappedModel - default: - return channelMappedModel - } -} + // 检查是否需要通用重试(排除400,因为400已经在上面特殊处理过了) + if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { + if attempt < maxRetryAttempts { + elapsed := time.Since(retryStart) + if elapsed >= maxRetryElapsed { + break + } -// isUpstreamModelRestrictedByChannel 检查账号映射后的上游模型是否受渠道定价限制。 -// 仅在 BillingModelSource="upstream" 且 RestrictModels=true 时由调度循环调用。 -func (s *GatewayService) isUpstreamModelRestrictedByChannel(ctx context.Context, groupID int64, account *Account, requestedModel string) bool { - if s.channelService == nil { - return false - } - upstreamModel := resolveAccountUpstreamModel(account, requestedModel) - if upstreamModel == "" { - return false - } - return s.channelService.IsModelRestricted(ctx, groupID, upstreamModel) -} + delay := retryBackoffDelay(attempt) + remaining := maxRetryElapsed - elapsed + if delay > remaining { + delay = remaining + } + if delay <= 0 { + break + } -// resolveAccountUpstreamModel 确定账号将请求模型映射为什么上游模型。 -func resolveAccountUpstreamModel(account *Account, requestedModel string) string { - if account.Platform == PlatformAntigravity { - return mapAntigravityModel(account, requestedModel) - } - return account.GetMappedModel(requestedModel) -} + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + UpstreamURL: 
safeUpstreamURL(upstreamReq.URL.String()), + Kind: "retry", + Message: extractUpstreamErrorMessage(respBody), + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), + }) + logger.LegacyPrintf("service.gateway", "Account %d: upstream error %d, retry %d/%d after %v (elapsed=%v/%v)", + account.ID, resp.StatusCode, attempt, maxRetryAttempts, delay, elapsed, maxRetryElapsed) + if err := sleepWithContext(ctx, delay); err != nil { + return nil, err + } + continue + } + // 最后一次尝试也失败,跳出循环处理重试耗尽 + break + } -// needsUpstreamChannelRestrictionCheck 判断是否需要在调度循环中逐账号检查上游模型的渠道限制。 -func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Context, groupID *int64) bool { - if groupID == nil || s.channelService == nil { - return false - } - ch, err := s.channelService.GetChannelForGroup(ctx, *groupID) - if err != nil { - slog.Warn("failed to check channel upstream restriction", "group_id", *groupID, "error", err) - return false + // 不需要重试(成功或不可重试的错误),跳出循环 + // DEBUG: 输出响应 headers(用于检测 rate limit 信息) + if account.Platform == PlatformGemini && resp.StatusCode < 400 && s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders { + logger.LegacyPrintf("service.gateway", "[DEBUG] Gemini API Response Headers for account %d:", account.ID) + for k, v := range resp.Header { + logger.LegacyPrintf("service.gateway", "[DEBUG] %s: %v", k, v) + } + } + break } - if ch == nil || !ch.RestrictModels { - return false + if resp == nil || resp.Body == nil { + return nil, errors.New("upstream request failed: empty response") } - return ch.BillingModelSource == BillingModelSourceUpstream -} + defer func() { _ = resp.Body.Close() }() -// isStickyAccountUpstreamRestricted 检查粘性会话命中的账号是否受 upstream 渠道限制。 -// 合并 needsUpstreamChannelRestrictionCheck + isUpstreamModelRestrictedByChannel 两步调用, -// 供 sticky session 条件链使用,避免内联多个函数调用导致行过长。 -func (s *GatewayService) 
isStickyAccountUpstreamRestricted(ctx context.Context, groupID *int64, account *Account, requestedModel string) bool { - if groupID == nil { - return false - } - if !s.needsUpstreamChannelRestrictionCheck(ctx, groupID) { - return false - } - return s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel) -} + // 处理重试耗尽的情况 + if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { + if s.shouldFailoverUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) -// ForwardCountTokens 转发 count_tokens 请求到上游 API -// 特点:不记录使用量、仅支持非流式响应 -func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error { - if parsed == nil { - s.countTokensError(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") - return fmt.Errorf("parse request: empty request") - } + // 调试日志:打印重试耗尽后的错误响应 + logger.LegacyPrintf("service.gateway", "[Forward] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", + account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) - if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() { - passthroughBody := parsed.Body - if reqModel := parsed.Model; reqModel != "" { - if mappedModel := account.GetMappedModel(reqModel); mappedModel != reqModel { - passthroughBody = s.replaceModelInBody(passthroughBody, mappedModel) - logger.LegacyPrintf("service.gateway", "CountTokens passthrough model mapping: %s -> %s (account: %s)", reqModel, mappedModel, account.Name) + s.handleRetryExhaustedSideEffects(ctx, resp, account) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: 
resp.Header.Get("x-request-id"), + Kind: "retry_exhausted_failover", + Message: extractUpstreamErrorMessage(respBody), + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), + }) + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), } } - return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, passthroughBody) - } - - // Bedrock 不支持 count_tokens 端点 - if account != nil && account.IsBedrock() { - s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported for Bedrock") - return nil + return s.handleRetryExhaustedError(ctx, resp, c, account) } - body := parsed.Body - reqModel := parsed.Model + // 处理可切换账号的错误 + if resp.StatusCode >= 400 && s.shouldFailoverUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) - // Pre-filter: strip empty text blocks to prevent upstream 400. 
- body = StripEmptyTextBlocks(body) + // 调试日志:打印上游错误响应 + logger.LegacyPrintf("service.gateway", "[Forward] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", + account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) - isClaudeCodeCT := IsClaudeCodeClient(ctx) || isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) - shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCodeCT + s.handleFailoverSideEffects(ctx, resp, account) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "failover", + Message: extractUpstreamErrorMessage(respBody), + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), + }) + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } + } + if resp.StatusCode >= 400 { + // 可选:对部分 400 触发 failover(默认关闭以保持语义) + if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 { + respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + if readErr != nil { + // ReadAll failed, fall back to normal error handling without consuming the stream + return s.handleErrorResponse(ctx, resp, c, account) + } + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) - if shouldMimicClaudeCode { - normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true} - body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts) + if s.shouldFailoverOn400(respBody) { + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = 
sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "failover_on_400", + Message: upstreamMsg, + Detail: upstreamDetail, + }) - body = s.rewriteMessageCacheControlIfEnabled(ctx, body) - if rw := buildToolNameRewriteFromBody(body); rw != nil { - body = applyToolNameRewriteToBody(body, rw) - } else { - body = applyToolsLastCacheBreakpoint(body) + if s.cfg.Gateway.LogUpstreamErrorBody { + logger.LegacyPrintf("service.gateway", + "Account %d: 400 error, attempting failover: %s", + account.ID, + truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), + ) + } else { + logger.LegacyPrintf("service.gateway", "Account %d: 400 error, attempting failover", account.ID) + } + s.handleFailoverSideEffects(ctx, resp, account) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + } } + return s.handleErrorResponse(ctx, resp, c, account) } - // Antigravity 账户不支持 count_tokens,返回 404 让客户端 fallback 到本地估算。 - // 返回 nil 避免 handler 层记录为错误,也不设置 ops 上游错误上下文。 - if account.Platform == PlatformAntigravity { - s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported for this platform") - return nil - } + // 阶段4: 处理正常响应(流式/非流式 + 结果构建) + return s.processForwardResponse(ctx, c, account, resp, fc, parsed) +} - // 应用模型映射: - // - APIKey 账号:使用账号级别的显式映射(如果配置),否则透传原始模型名 - // - OAuth/SetupToken 账号:使用 Anthropic 标准映射(短ID → 长ID) - if reqModel != "" { - mappedModel := reqModel - mappingSource := "" - if account.Type == AccountTypeAPIKey { - 
mappedModel = account.GetMappedModel(reqModel) - if mappedModel != reqModel { - mappingSource = "account" - } - } - if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { - normalized := claude.NormalizeModelID(reqModel) - if normalized != reqModel { - mappedModel = normalized - mappingSource = "prefix" - } - } - if mappedModel != reqModel { - body = s.replaceModelInBody(body, mappedModel) - reqModel = mappedModel - logger.LegacyPrintf("service.gateway", "CountTokens model mapping applied: %s -> %s (account: %s, source=%s)", parsed.Model, mappedModel, account.Name, mappingSource) - } - } +type anthropicPassthroughForwardInput struct { + Body []byte + RequestModel string + OriginalModel string + RequestStream bool + StartTime time.Time +} + +func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + reqModel string, + originalModel string, + reqStream bool, + startTime time.Time, +) (*ForwardResult, error) { + return s.forwardAnthropicAPIKeyPassthroughWithInput(ctx, c, account, anthropicPassthroughForwardInput{ + Body: body, + RequestModel: reqModel, + OriginalModel: originalModel, + RequestStream: reqStream, + StartTime: startTime, + }) +} - // 获取凭证 +func (s *GatewayService) forwardAnthropicAPIKeyPassthroughWithInput( + ctx context.Context, + c *gin.Context, + account *Account, + input anthropicPassthroughForwardInput, +) (*ForwardResult, error) { token, tokenType, err := s.GetAccessToken(ctx, account) if err != nil { - s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to get access token") - return err + return nil, err } - - // 构建上游请求 - upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel, shouldMimicClaudeCode) - if err != nil { - s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request") - return err + if tokenType != "apikey" { + return 
nil, fmt.Errorf("anthropic api key passthrough requires apikey token, got: %s", tokenType) } - // 获取代理URL(自定义 base URL 模式下,proxy 通过 buildCustomRelayURL 作为查询参数传递) proxyURL := "" if account.ProxyID != nil && account.Proxy != nil { - if !account.IsCustomBaseURLEnabled() || account.GetCustomBaseURL() == "" { - proxyURL = account.Proxy.URL() - } + proxyURL = account.Proxy.URL() } - // 发送请求 - resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account)) - if err != nil { - setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(err.Error()), "") - s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed") - return fmt.Errorf("upstream request failed: %w", err) - } + logger.LegacyPrintf("service.gateway", "[Anthropic 自动透传] 命中 API Key 透传分支: account=%d name=%s model=%s stream=%v", + account.ID, account.Name, input.RequestModel, input.RequestStream) - // 读取响应体 - countTokensTooLarge := func(c *gin.Context) { - s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large") - } - respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, countTokensTooLarge) - _ = resp.Body.Close() - if err != nil { - if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) { - s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") - } - return err + if c != nil { + c.Set("anthropic_passthrough", true) } + // Pre-filter: strip empty text blocks (including nested in tool_result) to prevent upstream 400. 
+ input.Body = StripEmptyTextBlocks(input.Body) - // 检测 thinking block 签名错误(400)并重试一次(过滤 thinking blocks) - if resp.StatusCode == 400 && s.shouldRectifySignatureError(ctx, account, respBody) { - logger.LegacyPrintf("service.gateway", "Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID) + // 重试间复用同一请求体,避免每次 string(body) 产生额外分配。 + setOpsUpstreamRequestBody(c, input.Body) - filteredBody := FilterThinkingBlocksForRetry(body) - retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, shouldMimicClaudeCode) - if buildErr == nil { - retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account)) - if retryErr == nil { - resp = retryResp - respBody, err = ReadUpstreamResponseBody(resp.Body, s.cfg, c, countTokensTooLarge) - _ = resp.Body.Close() - if err != nil { - if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) { - s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") - } - return err - } - } + var resp *http.Response + retryStart := time.Now() + for attempt := 1; attempt <= maxRetryAttempts; attempt++ { + upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, input.RequestStream) + upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(upstreamCtx, c, account, input.Body, token) + releaseUpstreamCtx() + if err != nil { + return nil, err } - } - // 处理错误响应 - if resp.StatusCode >= 400 { - // 标记账号状态(429/529等) - s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) - - upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - upstreamDetail := "" - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - if maxBytes <= 0 { - maxBytes = 2048 
+ resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account)) + if err != nil { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() } - upstreamDetail = truncateString(string(respBody), maxBytes) - } - setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) - - // 记录上游错误摘要便于排障(不回显请求内容) - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - logger.LegacyPrintf("service.gateway", - "count_tokens upstream error %d (account=%d platform=%s type=%s): %s", - resp.StatusCode, - account.ID, - account.Platform, - account.Type, - truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), - ) + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), + Passthrough: true, + Kind: "request_error", + Message: safeErr, + }) + c.JSON(http.StatusBadGateway, gin.H{ + "type": "error", + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream request failed", + }, + }) + return nil, fmt.Errorf("upstream request failed: %s", safeErr) } - // 返回简化的错误响应 - errMsg := "Upstream request failed" - switch resp.StatusCode { - case 429: - errMsg = "Rate limit exceeded" - case 529: - errMsg = "Service overloaded" - } - s.countTokensError(c, resp.StatusCode, "upstream_error", errMsg) - if upstreamMsg == "" { - return fmt.Errorf("upstream error: %d", resp.StatusCode) - } - return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) - } + // 透传分支禁止 400 请求体降级重试(该重试会改写请求体) + if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { + if attempt < maxRetryAttempts { + elapsed := time.Since(retryStart) + if elapsed >= maxRetryElapsed { + break + } - // 
透传成功响应 - c.Data(resp.StatusCode, "application/json", respBody) - return nil -} + delay := retryBackoffDelay(attempt) + remaining := maxRetryElapsed - elapsed + if delay > remaining { + delay = remaining + } + if delay <= 0 { + break + } -func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx context.Context, c *gin.Context, account *Account, body []byte) error { - token, tokenType, err := s.GetAccessToken(ctx, account) - if err != nil { - s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to get access token") - return err - } - if tokenType != "apikey" { - s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Invalid account token type") - return fmt.Errorf("anthropic api key passthrough requires apikey token, got: %s", tokenType) - } + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), + Passthrough: true, + Kind: "retry", + Message: extractUpstreamErrorMessage(respBody), + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), + }) + logger.LegacyPrintf("service.gateway", "Anthropic passthrough account %d: upstream error %d, retry %d/%d after %v (elapsed=%v/%v)", + account.ID, resp.StatusCode, attempt, maxRetryAttempts, delay, elapsed, maxRetryElapsed) + if err := sleepWithContext(ctx, delay); err != nil { + return nil, err + } + continue + } + break + } - upstreamReq, err := s.buildCountTokensRequestAnthropicAPIKeyPassthrough(ctx, c, account, body, token) - if err != nil { - s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request") 
- return err + break } - - proxyURL := "" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() + if resp == nil || resp.Body == nil { + return nil, errors.New("upstream request failed: empty response") } + defer func() { _ = resp.Body.Close() }() - resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account)) - if err != nil { - setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(err.Error()), "") - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: 0, - UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), - Passthrough: true, - Kind: "request_error", - Message: sanitizeUpstreamErrorMessage(err.Error()), - }) - s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed") - return fmt.Errorf("upstream request failed: %w", err) - } + if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { + if s.shouldFailoverUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) - countTokensTooLarge := func(c *gin.Context) { - s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large") - } - respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, countTokensTooLarge) - _ = resp.Body.Close() - if err != nil { - if !errors.Is(err, ErrUpstreamResponseBodyTooLarge) { - s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") + logger.LegacyPrintf("service.gateway", "[Anthropic Passthrough] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", + account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) + + 
s.handleRetryExhaustedSideEffects(ctx, resp, account) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Passthrough: true, + Kind: "retry_exhausted_failover", + Message: extractUpstreamErrorMessage(respBody), + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), + }) + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } } - return err + return s.handleRetryExhaustedError(ctx, resp, c, account) } - if resp.StatusCode >= 400 { - if s.rateLimitService != nil { - s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) - } + if resp.StatusCode >= 400 && s.shouldFailoverUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) - upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - - // 中转站不支持 count_tokens 端点时(404),返回 404 让客户端 fallback 到本地估算。 - // 仅在错误消息明确指向 count_tokens endpoint 不存在时生效,避免误吞其他 404(如错误 base_url)。 - // 返回 nil 避免 handler 层记录为错误,也不设置 ops 上游错误上下文。 - if isCountTokensUnsupported404(resp.StatusCode, respBody) { - logger.LegacyPrintf("service.gateway", - "[count_tokens] Upstream does not support count_tokens (404), returning 404: account=%d name=%s msg=%s", - account.ID, account.Name, truncateString(upstreamMsg, 512)) - s.countTokensError(c, http.StatusNotFound, "not_found_error", "count_tokens endpoint is not supported by upstream") - return nil - } + 
logger.LegacyPrintf("service.gateway", "[Anthropic Passthrough] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", + account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) - upstreamDetail := "" - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - if maxBytes <= 0 { - maxBytes = 2048 - } - upstreamDetail = truncateString(string(respBody), maxBytes) - } - setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + s.handleFailoverSideEffects(ctx, resp, account) appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), - UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), Passthrough: true, - Kind: "http_error", - Message: upstreamMsg, - Detail: upstreamDetail, + Kind: "failover", + Message: extractUpstreamErrorMessage(respBody), + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), }) + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } + } + + if resp.StatusCode >= 400 { + return s.handleErrorResponse(ctx, resp, c, account) + } - errMsg := "Upstream request failed" - switch resp.StatusCode { - case 429: - errMsg = "Rate limit exceeded" - case 529: - errMsg = "Service overloaded" + var usage *ClaudeUsage + var firstTokenMs *int + var clientDisconnect bool + if input.RequestStream { + streamResult, err := s.handleStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account, input.StartTime, input.RequestModel) + if err != nil { + return nil, err } - s.countTokensError(c, 
resp.StatusCode, "upstream_error", errMsg) - if upstreamMsg == "" { - return fmt.Errorf("upstream error: %d", resp.StatusCode) + usage = streamResult.usage + firstTokenMs = streamResult.firstTokenMs + clientDisconnect = streamResult.clientDisconnect + } else { + usage, err = s.handleNonStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account) + if err != nil { + return nil, err } - return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) + } + if usage == nil { + usage = &ClaudeUsage{} } - writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) - contentType := strings.TrimSpace(resp.Header.Get("Content-Type")) - if contentType == "" { - contentType = "application/json" - } - c.Data(resp.StatusCode, contentType, respBody) - return nil + return &ForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: *usage, + Model: input.OriginalModel, + UpstreamModel: input.RequestModel, + Stream: input.RequestStream, + Duration: time.Since(input.StartTime), + FirstTokenMs: firstTokenMs, + ClientDisconnect: clientDisconnect, + }, nil } -func (s *GatewayService) buildCountTokensRequestAnthropicAPIKeyPassthrough( +func (s *GatewayService) buildUpstreamRequestAnthropicAPIKeyPassthrough( ctx context.Context, c *gin.Context, account *Account, body []byte, token string, ) (*http.Request, error) { - targetURL := claudeAPICountTokensURL + targetURL := claudeAPIURL baseURL := account.GetBaseURL() if baseURL != "" { validatedURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, err } - targetURL = validatedURL + "/v1/messages/count_tokens?beta=true" + targetURL = validatedURL + "/v1/messages?beta=true" } req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) @@ -9147,433 +2808,493 @@ func (s *GatewayService) buildCountTokensRequestAnthropicAPIKeyPassthrough( } } + // 覆盖入站鉴权残留,并注入上游认证 req.Header.Del("authorization") req.Header.Del("x-api-key") 
req.Header.Del("x-goog-api-key") req.Header.Del("cookie") - req.Header.Set("x-api-key", token) + setHeaderRaw(req.Header, "x-api-key", token) - if req.Header.Get("content-type") == "" { - req.Header.Set("content-type", "application/json") + if getHeaderRaw(req.Header, "content-type") == "" { + setHeaderRaw(req.Header, "content-type", "application/json") } - if req.Header.Get("anthropic-version") == "" { - req.Header.Set("anthropic-version", "2023-06-01") + if getHeaderRaw(req.Header, "anthropic-version") == "" { + setHeaderRaw(req.Header, "anthropic-version", "2023-06-01") } return req, nil } -// buildCountTokensRequest 构建 count_tokens 上游请求 -func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, mimicClaudeCode bool) (*http.Request, error) { - // 确定目标 URL - targetURL := claudeAPICountTokensURL - if account.Type == AccountTypeAPIKey { - baseURL := account.GetBaseURL() - if baseURL != "" { - validatedURL, err := s.validateUpstreamBaseURL(baseURL) - if err != nil { - return nil, err - } - targetURL = validatedURL + "/v1/messages/count_tokens?beta=true" - } - } else if account.IsCustomBaseURLEnabled() { - customURL := account.GetCustomBaseURL() - if customURL == "" { - return nil, fmt.Errorf("custom_base_url is enabled but not configured for account %d", account.ID) - } - validatedURL, err := s.validateUpstreamBaseURL(customURL) - if err != nil { - return nil, err - } - targetURL = s.buildCustomRelayURL(validatedURL, "/v1/messages/count_tokens", account) +func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, + startTime time.Time, + model string, +) (*streamingResult, error) { + if s.rateLimitService != nil { + s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) } - clientHeaders := http.Header{} - if c != nil && c.Request != nil { - clientHeaders = 
c.Request.Header - } + writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) - // OAuth 账号:应用统一指纹和重写 userID(受设置开关控制) - // 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值 - ctEnableFP, ctEnableMPT, ctEnableCCH := true, false, false - if s.settingService != nil { - ctEnableFP, ctEnableMPT, ctEnableCCH = s.settingService.GetGatewayForwardingSettings(ctx) + contentType := strings.TrimSpace(resp.Header.Get("Content-Type")) + if contentType == "" { + contentType = "text/event-stream" } - var ctFingerprint *Fingerprint - if account.IsOAuth() && s.identityService != nil { - fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders) - if err == nil { - ctFingerprint = fp - if !ctEnableMPT { - accountUUID := account.GetExtraString("account_uuid") - if accountUUID != "" && fp.ClientID != "" { - if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID, fp.UserAgent); err == nil && len(newBody) > 0 { - body = newBody - } - } - } - } + c.Header("Content-Type", contentType) + if c.Writer.Header().Get("Cache-Control") == "" { + c.Header("Cache-Control", "no-cache") } - - // 同步 billing header cc_version 与实际发送的 User-Agent 版本 - if ctFingerprint != nil && ctEnableFP { - body = syncBillingHeaderVersion(body, ctFingerprint.UserAgent) + if c.Writer.Header().Get("Connection") == "" { + c.Header("Connection", "keep-alive") } - if ctEnableCCH { - body = signBillingHeaderCCH(body) + c.Header("X-Accel-Buffering", "no") + if v := resp.Header.Get("x-request-id"); v != "" { + c.Header("x-request-id", v) } - req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) - if err != nil { - return nil, err + w := c.Writer + flusher, ok := w.(http.Flusher) + if !ok { + return nil, errors.New("streaming not supported") } - // 设置认证头(保持原始大小写) - if tokenType == "oauth" { - setHeaderRaw(req.Header, "authorization", "Bearer "+token) - } else { - setHeaderRaw(req.Header, 
"x-api-key", token) + usage := &ClaudeUsage{} + var firstTokenMs *int + clientDisconnected := false + sawTerminalEvent := false + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize } + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) - // 白名单透传 headers(恢复真实 wire casing) - for key, values := range clientHeaders { - lowerKey := strings.ToLower(key) - if allowedHeaders[lowerKey] { - wireKey := resolveWireCasing(key) - for _, v := range values { - addHeaderRaw(req.Header, wireKey, v) - } + type scanEvent struct { + line string + err error + } + events := make(chan scanEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev scanEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false } } + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) + defer close(events) + for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + if !sendEvent(scanEvent{line: scanner.Text()}) { + return + } + } + if err := scanner.Err(); err != nil { + _ = sendEvent(scanEvent{err: err}) + } + }(scanBuf) + defer close(done) - // OAuth 账号:应用指纹到请求头(受设置开关控制) - if ctEnableFP && ctFingerprint != nil { - s.identityService.ApplyFingerprint(req, ctFingerprint) + streamInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C } - // 确保必要的 headers 存在(保持原始大小写) - if getHeaderRaw(req.Header, "content-type") == "" { - setHeaderRaw(req.Header, 
"content-type", "application/json") + for { + select { + case ev, ok := <-events: + if !ok { + if !clientDisconnected { + // 兜底补刷,确保最后一个未以空行结尾的事件也能及时送达客户端。 + flusher.Flush() + } + if !sawTerminalEvent { + if clientDisconnected && streamInterval > 0 { + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) >= streamInterval { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after timeout") + } + } + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, fmt.Errorf("stream usage incomplete: missing terminal event") + } + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil + } + if ev.err != nil { + if sawTerminalEvent { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil + } + if clientDisconnected { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after disconnect: %w", ev.err) + } + if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete: %w", ev.err) + } + if errors.Is(ev.err, bufio.ErrTooLong) { + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err + } + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err) + } + + line := ev.line + if data, ok := extractAnthropicSSEDataLine(line); ok { + trimmed := strings.TrimSpace(data) + if anthropicStreamEventIsTerminal("", trimmed) { + sawTerminalEvent = true + } + if firstTokenMs == nil && 
trimmed != "" && trimmed != "[DONE]" { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + s.parseSSEUsagePassthrough(data, usage) + } else { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, "event:") && anthropicStreamEventIsTerminal(strings.TrimSpace(strings.TrimPrefix(trimmed, "event:")), "") { + sawTerminalEvent = true + } + } + + if !clientDisconnected { + restored := string(reverseToolNamesIfPresent(c, []byte(line))) + if _, err := io.WriteString(w, restored); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID) + } else if _, err := io.WriteString(w, "\n"); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID) + } else if line == "" { + // 按 SSE 事件边界刷出,减少每行 flush 带来的 syscall 开销。 + flusher.Flush() + } + } + + case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue + } + if clientDisconnected { + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after timeout") + } + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval) + if s.rateLimitService != nil { + s.rateLimitService.HandleStreamTimeout(ctx, account, model) + } + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") + } } - if getHeaderRaw(req.Header, "anthropic-version") == "" { - setHeaderRaw(req.Header, "anthropic-version", "2023-06-01") +} + +func extractAnthropicSSEDataLine(line string) (string, bool) { + if !strings.HasPrefix(line, 
"data:") { + return "", false } - if tokenType == "oauth" { - applyClaudeOAuthHeaderDefaults(req) + start := len("data:") + for start < len(line) { + if line[start] != ' ' && line[start] != '\t' { + break + } + start++ } + return line[start:], true +} - // Build effective drop set for count_tokens: merge static defaults with dynamic beta policy filter rules - ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account, modelID)) +func (s *GatewayService) parseSSEUsagePassthrough(data string, usage *ClaudeUsage) { + if usage == nil || data == "" || data == "[DONE]" { + return + } - // OAuth 账号:处理 anthropic-beta header - if tokenType == "oauth" { - if mimicClaudeCode { - applyClaudeCodeMimicHeaders(req, false) + parsed := gjson.Parse(data) + switch parsed.Get("type").String() { + case "message_start": + msgUsage := parsed.Get("message.usage") + if msgUsage.Exists() { + usage.InputTokens = int(msgUsage.Get("input_tokens").Int()) + usage.CacheCreationInputTokens = int(msgUsage.Get("cache_creation_input_tokens").Int()) + usage.CacheReadInputTokens = int(msgUsage.Get("cache_read_input_tokens").Int()) - incomingBeta := getHeaderRaw(req.Header, "anthropic-beta") - requiredBetas := append(claude.FullClaudeCodeMimicryBetas(), claude.BetaTokenCounting) - setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, ctEffectiveDropSet)) - } else { - clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta") - if clientBetaHeader == "" { - setHeaderRaw(req.Header, "anthropic-beta", claude.CountTokensBetaHeader) - } else { - beta := s.getBetaHeader(modelID, clientBetaHeader) - if !strings.Contains(beta, claude.BetaTokenCounting) { - beta = beta + "," + claude.BetaTokenCounting - } - setHeaderRaw(req.Header, "anthropic-beta", stripBetaTokensWithSet(beta, ctEffectiveDropSet)) + // 保持与通用解析一致:message_start 允许覆盖 5m/1h 明细(包括 0)。 + cc5m := msgUsage.Get("cache_creation.ephemeral_5m_input_tokens") + cc1h := 
msgUsage.Get("cache_creation.ephemeral_1h_input_tokens") + if cc5m.Exists() || cc1h.Exists() { + usage.CacheCreation5mTokens = int(cc5m.Int()) + usage.CacheCreation1hTokens = int(cc1h.Int()) + } + } + case "message_delta": + deltaUsage := parsed.Get("usage") + if deltaUsage.Exists() { + if v := deltaUsage.Get("input_tokens").Int(); v > 0 { + usage.InputTokens = int(v) + } + if v := deltaUsage.Get("output_tokens").Int(); v > 0 { + usage.OutputTokens = int(v) } - } - } else { - // API-key accounts: apply beta policy filter to strip controlled tokens - if existingBeta := getHeaderRaw(req.Header, "anthropic-beta"); existingBeta != "" { - setHeaderRaw(req.Header, "anthropic-beta", stripBetaTokensWithSet(existingBeta, ctEffectiveDropSet)) - } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey { - // API-key:与 messages 同步的按需 beta 注入(默认关闭) - if requestNeedsBetaFeatures(body) { - if beta := defaultAPIKeyBetaHeader(body); beta != "" { - setHeaderRaw(req.Header, "anthropic-beta", beta) - } + if v := deltaUsage.Get("cache_creation_input_tokens").Int(); v > 0 { + usage.CacheCreationInputTokens = int(v) + } + if v := deltaUsage.Get("cache_read_input_tokens").Int(); v > 0 { + usage.CacheReadInputTokens = int(v) } - } - } - // 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖 - if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" { - if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" { - if parsed := ParseMetadataUserID(uid); parsed != nil { - setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID) + cc5m := deltaUsage.Get("cache_creation.ephemeral_5m_input_tokens") + cc1h := deltaUsage.Get("cache_creation.ephemeral_1h_input_tokens") + if cc5m.Exists() && cc5m.Int() > 0 { + usage.CacheCreation5mTokens = int(cc5m.Int()) + } + if cc1h.Exists() && cc1h.Int() > 0 { + usage.CacheCreation1hTokens = int(cc1h.Int()) } } } - if c != nil && tokenType == "oauth" { - 
c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode)) - } - if s.debugClaudeMimicEnabled() { - logClaudeMimicDebug(req, body, account, tokenType, mimicClaudeCode) + if usage.CacheReadInputTokens == 0 { + if cached := parsed.Get("message.usage.cached_tokens").Int(); cached > 0 { + usage.CacheReadInputTokens = int(cached) + } + if cached := parsed.Get("usage.cached_tokens").Int(); usage.CacheReadInputTokens == 0 && cached > 0 { + usage.CacheReadInputTokens = int(cached) + } } - - return req, nil -} - -// countTokensError 返回 count_tokens 错误响应 -func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, message string) { - c.JSON(status, gin.H{ - "type": "error", - "error": gin.H{ - "type": errType, - "message": message, - }, - }) -} - -// buildCustomRelayURL 构建自定义中继转发 URL -// 在 path 后附加 beta=true 和可选的 proxy 查询参数 -func (s *GatewayService) buildCustomRelayURL(baseURL, path string, account *Account) string { - u := strings.TrimRight(baseURL, "/") + path + "?beta=true" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL := account.Proxy.URL() - if proxyURL != "" { - u += "&proxy=" + url.QueryEscape(proxyURL) + if usage.CacheCreationInputTokens == 0 { + cc5m := parsed.Get("message.usage.cache_creation.ephemeral_5m_input_tokens").Int() + cc1h := parsed.Get("message.usage.cache_creation.ephemeral_1h_input_tokens").Int() + if cc5m == 0 && cc1h == 0 { + cc5m = parsed.Get("usage.cache_creation.ephemeral_5m_input_tokens").Int() + cc1h = parsed.Get("usage.cache_creation.ephemeral_1h_input_tokens").Int() + } + total := cc5m + cc1h + if total > 0 { + usage.CacheCreationInputTokens = int(total) } } - return u } -func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) { - if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled { - normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP) - if err != nil { - return "", fmt.Errorf("invalid 
base_url: %w", err) - } - return normalized, nil - } - normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{ - AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts, - RequireAllowlist: true, - AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts, - }) - if err != nil { - return "", fmt.Errorf("invalid base_url: %w", err) +func parseClaudeUsageFromResponseBody(body []byte) *ClaudeUsage { + usage := &ClaudeUsage{} + if len(body) == 0 { + return usage } - return normalized, nil -} -// GetAvailableModels returns the list of models available for a group -// It aggregates model_mapping keys from all schedulable accounts in the group -func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string { - cacheKey := modelsListCacheKey(groupID, platform) - if s.modelsListCache != nil { - if cached, found := s.modelsListCache.Get(cacheKey); found { - if models, ok := cached.([]string); ok { - modelsListCacheHitTotal.Add(1) - return cloneStringSlice(models) - } - } + parsed := gjson.ParseBytes(body) + usageNode := parsed.Get("usage") + if !usageNode.Exists() { + return usage } - modelsListCacheMissTotal.Add(1) - var accounts []Account - var err error + usage.InputTokens = int(usageNode.Get("input_tokens").Int()) + usage.OutputTokens = int(usageNode.Get("output_tokens").Int()) + usage.CacheCreationInputTokens = int(usageNode.Get("cache_creation_input_tokens").Int()) + usage.CacheReadInputTokens = int(usageNode.Get("cache_read_input_tokens").Int()) - if groupID != nil { - accounts, err = s.accountRepo.ListSchedulableByGroupID(ctx, *groupID) - } else { - accounts, err = s.accountRepo.ListSchedulable(ctx) + cc5m := usageNode.Get("cache_creation.ephemeral_5m_input_tokens").Int() + cc1h := usageNode.Get("cache_creation.ephemeral_1h_input_tokens").Int() + if cc5m > 0 || cc1h > 0 { + usage.CacheCreation5mTokens = int(cc5m) + usage.CacheCreation1hTokens = int(cc1h) } - - if err != nil || len(accounts) 
== 0 { - return nil + if usage.CacheCreationInputTokens == 0 && (cc5m > 0 || cc1h > 0) { + usage.CacheCreationInputTokens = int(cc5m + cc1h) } - - // Filter by platform if specified - if platform != "" { - filtered := make([]Account, 0) - for _, acc := range accounts { - if acc.Platform == platform { - filtered = append(filtered, acc) - } + if usage.CacheReadInputTokens == 0 { + if cached := usageNode.Get("cached_tokens").Int(); cached > 0 { + usage.CacheReadInputTokens = int(cached) } - accounts = filtered } + return usage +} - // Collect unique models from all accounts - modelSet := make(map[string]struct{}) - hasAnyMapping := false - - for _, acc := range accounts { - mapping := acc.GetModelMapping() - if len(mapping) > 0 { - hasAnyMapping = true - for model := range mapping { - modelSet[model] = struct{}{} - } - } +func (s *GatewayService) handleNonStreamingResponseAnthropicAPIKeyPassthrough( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, +) (*ClaudeUsage, error) { + if s.rateLimitService != nil { + s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) } - // If no account has model_mapping, return nil (use default) - if !hasAnyMapping { - if s.modelsListCache != nil { - s.modelsListCache.Set(cacheKey, []string(nil), s.modelsListCacheTTL) - modelsListCacheStoreTotal.Add(1) - } - return nil + body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, anthropicTooLargeError) + if err != nil { + return nil, err } - // Convert to slice - models := make([]string, 0, len(modelSet)) - for model := range modelSet { - models = append(models, model) - } - sort.Strings(models) + usage := parseClaudeUsageFromResponseBody(body) - if s.modelsListCache != nil { - s.modelsListCache.Set(cacheKey, cloneStringSlice(models), s.modelsListCacheTTL) - modelsListCacheStoreTotal.Add(1) + writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + contentType := 
strings.TrimSpace(resp.Header.Get("Content-Type")) + if contentType == "" { + contentType = "application/json" } - return cloneStringSlice(models) + body = reverseToolNamesIfPresent(c, body) + c.Data(resp.StatusCode, contentType, body) + return usage, nil } -func (s *GatewayService) InvalidateAvailableModelsCache(groupID *int64, platform string) { - if s == nil || s.modelsListCache == nil { +func writeAnthropicPassthroughResponseHeaders(dst http.Header, src http.Header, filter *responseheaders.CompiledHeaderFilter) { + if dst == nil || src == nil { return } - - normalizedPlatform := strings.TrimSpace(platform) - // 完整匹配时精准失效;否则按维度批量失效。 - if groupID != nil && normalizedPlatform != "" { - s.modelsListCache.Delete(modelsListCacheKey(groupID, normalizedPlatform)) + if filter != nil { + responseheaders.WriteFilteredHeaders(dst, src, filter) return } - - targetGroup := derefGroupID(groupID) - for key := range s.modelsListCache.Items() { - parts := strings.SplitN(key, "|", 2) - if len(parts) != 2 { - continue - } - groupPart, parseErr := strconv.ParseInt(parts[0], 10, 64) - if parseErr != nil { - continue - } - if groupID != nil && groupPart != targetGroup { - continue - } - if normalizedPlatform != "" && parts[1] != normalizedPlatform { - continue - } - s.modelsListCache.Delete(key) - } -} - -// reconcileCachedTokens 兼容 Kimi 等上游: -// 将 OpenAI 风格的 cached_tokens 映射到 Claude 标准的 cache_read_input_tokens -func reconcileCachedTokens(usage map[string]any) bool { - if usage == nil { - return false - } - cacheRead, _ := usage["cache_read_input_tokens"].(float64) - if cacheRead > 0 { - return false // 已有标准字段,无需处理 + if v := strings.TrimSpace(src.Get("Content-Type")); v != "" { + dst.Set("Content-Type", v) } - cached, _ := usage["cached_tokens"].(float64) - if cached <= 0 { - return false + if v := strings.TrimSpace(src.Get("x-request-id")); v != "" { + dst.Set("x-request-id", v) } - usage["cache_read_input_tokens"] = cached - return true } -const debugGatewayBodyDefaultFilename = 
// forwardBedrock forwards a parsed request to AWS Bedrock.
//
// Flow: resolve runtime region and model-ID mapping for the account → build
// the Bedrock request body (beta-token injection) → pick the auth scheme
// (static Bedrock API key vs SigV4 signer) → execute upstream (with retries)
// → normalize the AWS request-id header → delegate error / streaming /
// non-streaming handling → assemble the ForwardResult.
//
// Returns an error without writing to the client when the model is unmapped,
// body preparation fails, or signer construction fails; upstream HTTP errors
// (status >= 400) are delegated to handleBedrockUpstreamErrors for failover.
func (s *GatewayService) forwardBedrock(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	parsed *ParsedRequest,
	startTime time.Time,
) (*ForwardResult, error) {
	reqModel := parsed.Model
	reqStream := parsed.Stream
	body := parsed.Body

	region := bedrockRuntimeRegion(account)
	// Map the client-requested model name to the account's Bedrock model ID.
	mappedModel, ok := ResolveBedrockModelID(account, reqModel)
	if !ok {
		return nil, fmt.Errorf("unsupported bedrock model: %s", reqModel)
	}
	if mappedModel != reqModel {
		logger.LegacyPrintf("service.gateway", "[Bedrock] Model mapping: %s -> %s (account: %s)", reqModel, mappedModel, account.Name)
	}

	// Client-supplied beta features, if any (nil-safe when no request is attached).
	betaHeader := ""
	if c != nil && c.Request != nil {
		betaHeader = c.GetHeader("anthropic-beta")
	}

	// Prepare the request body (inject anthropic_version/anthropic_beta, strip
	// fields Bedrock does not support, clean up cache_control).
	betaTokens, err := s.resolveBedrockBetaTokensForRequest(ctx, account, betaHeader, body, mappedModel)
	if err != nil {
		return nil, err
	}
	bedrockBody, err := PrepareBedrockRequestBodyWithTokens(body, mappedModel, betaTokens)
	if err != nil {
		return nil, fmt.Errorf("prepare bedrock request body: %w", err)
	}

	proxyURL := ""
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}

	logger.LegacyPrintf("service.gateway", "[Bedrock] 命中 Bedrock 分支: account=%d name=%s model=%s->%s stream=%v",
		account.ID, account.Name, reqModel, mappedModel, reqStream)

	// Choose the auth scheme based on account type: a static Bedrock API key,
	// or an AWS SigV4 signer built from the account's credentials.
	var signer *BedrockSigner
	var bedrockAPIKey string
	if account.IsBedrockAPIKey() {
		bedrockAPIKey = account.GetCredential("api_key")
		if bedrockAPIKey == "" {
			return nil, fmt.Errorf("api_key not found in bedrock credentials")
		}
	} else {
		signer, err = NewBedrockSignerFromAccount(account)
		if err != nil {
			return nil, fmt.Errorf("create bedrock signer: %w", err)
		}
	}

	// Execute the upstream request (with retries).
	resp, err := s.executeBedrockUpstream(ctx, c, account, bedrockBody, mappedModel, region, reqStream, signer, bedrockAPIKey, proxyURL)
	if err != nil {
		return nil, err
	}
	defer func() { _ = resp.Body.Close() }()

	// Map Bedrock's x-amzn-requestid onto x-request-id so the shared error
	// handlers (handleErrorResponse, handleRetryExhaustedError) can extract
	// the AWS request ID.
	if awsReqID := resp.Header.Get("x-amzn-requestid"); awsReqID != "" && resp.Header.Get("x-request-id") == "" {
		resp.Header.Set("x-request-id", awsReqID)
	}

	// Error / failover handling.
	if resp.StatusCode >= 400 {
		return s.handleBedrockUpstreamErrors(ctx, resp, c, account)
	}

	// Response handling: streaming collects usage/first-token/disconnect info,
	// non-streaming returns usage directly.
	var usage *ClaudeUsage
	var firstTokenMs *int
	var clientDisconnect bool
	if reqStream {
		streamResult, err := s.handleBedrockStreamingResponse(ctx, resp, c, account, startTime, reqModel)
		if err != nil {
			return nil, err
		}
		usage = streamResult.usage
		firstTokenMs = streamResult.firstTokenMs
		clientDisconnect = streamResult.clientDisconnect
	} else {
		usage, err = s.handleBedrockNonStreamingResponse(ctx, resp, c, account)
		if err != nil {
			return nil, err
		}
	}
	if usage == nil {
		usage = &ClaudeUsage{}
	}

	return &ForwardResult{
		RequestID:        resp.Header.Get("x-amzn-requestid"),
		Usage:            *usage,
		Model:            reqModel,
		UpstreamModel:    mappedModel,
		Stream:           reqStream,
		Duration:         time.Since(startTime),
		FirstTokenMs:     firstTokenMs,
		ClientDisconnect: clientDisconnect,
	}, nil
}

// Error implements the error interface; Message is the client-visible text.
func (e *BetaBlockedError) Error() string { return e.Message }

// mergeDropSets merges the static defaultDroppedBetasSet with dynamic policy filter tokens.
// Returns defaultDroppedBetasSet directly when policySet is empty (zero allocation).
func mergeDropSets(policySet map[string]struct{}, extra ...string) map[string]struct{} {
	// Fast path: nothing to merge — hand back the shared static set unchanged.
	if len(policySet) == 0 && len(extra) == 0 {
		return defaultDroppedBetasSet
	}
	// Union into a fresh map so the shared default set is never mutated.
	m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(policySet)+len(extra))
	for t := range defaultDroppedBetasSet {
		m[t] = struct{}{}
	}
	for t := range policySet {
		m[t] = struct{}{}
	}
	for _, t := range extra {
		m[t] = struct{}{}
	}
	return m
}

// (new file: backend/internal/service/gateway_session.go)
package service

import (
	"context"
	"crypto/sha256"
	"fmt"
	"log/slog"
	"strconv"
	"strings"

	"github.com/cespare/xxhash/v2"
	"github.com/google/uuid"
)

// shortSessionHash returns at most the first 8 characters of sessionHash,
// for compact logging.
func shortSessionHash(sessionHash string) string {
	if sessionHash == "" {
		return ""
	}
	if len(sessionHash) <= 8 {
		return sessionHash
	}
	return sessionHash[:8]
}

// GenerateSessionHash computes the sticky-session hash for a pre-parsed request.
//
// Priority order:
//  1. session ID embedded in metadata.user_id (highest priority);
//  2. hash of content marked cache_control {type:"ephemeral"};
//  3. fallback: hash of session context + system + all message text.
//
// Returns "" when no source yields a hash. The concatenation order in the
// fallback is part of the hash contract — changing it would break existing
// sticky-session bindings.
func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
	if parsed == nil {
		return ""
	}

	// 1. Highest priority: extract session_xxx from metadata.user_id.
	if parsed.MetadataUserID != "" {
		uid := ParseMetadataUserID(parsed.MetadataUserID)
		if uid != nil && uid.SessionID != "" {
			slog.Info("sticky.hash_source",
				"source", "metadata_user_id",
				"session_id", uid.SessionID,
				"device_id", uid.DeviceID,
				"is_new_format", uid.IsNewFormat,
			)
			return uid.SessionID
		}
		slog.Info("sticky.hash_metadata_parse_failed",
			"metadata_user_id", parsed.MetadataUserID,
			"parsed_nil", uid == nil,
		)
	}

	// 2. Content carrying cache_control: {type: "ephemeral"}.
	cacheableContent := s.extractCacheableContent(parsed)
	if cacheableContent != "" {
		hash := s.hashContent(cacheableContent)
		slog.Info("sticky.hash_source",
			"source", "cacheable_content",
			"hash", hash,
		)
		return hash
	}

	// 3. Final fallback: session context + system + full digest of all messages.
	var combined strings.Builder
	// Mix in request-context discriminators so different users with identical
	// messages do not collide on the same hash.
	if parsed.SessionContext != nil {
		_, _ = combined.WriteString(parsed.SessionContext.ClientIP)
		_, _ = combined.WriteString(":")
		_, _ = combined.WriteString(NormalizeSessionUserAgent(parsed.SessionContext.UserAgent))
		_, _ = combined.WriteString(":")
		_, _ = combined.WriteString(strconv.FormatInt(parsed.SessionContext.APIKeyID, 10))
		_, _ = combined.WriteString("|")
	}
	if parsed.System != nil {
		systemText := s.extractTextFromSystem(parsed.System)
		if systemText != "" {
			_, _ = combined.WriteString(systemText)
		}
	}
	for _, msg := range parsed.Messages {
		if m, ok := msg.(map[string]any); ok {
			if content, exists := m["content"]; exists {
				// Anthropic: messages[].content
				if msgText := s.extractTextFromContent(content); msgText != "" {
					_, _ = combined.WriteString(msgText)
				}
			} else if parts, ok := m["parts"].([]any); ok {
				// Gemini: contents[].parts[].text
				for _, part := range parts {
					if partMap, ok := part.(map[string]any); ok {
						if text, ok := partMap["text"].(string); ok {
							_, _ = combined.WriteString(text)
						}
					}
				}
			}
		}
	}
	if combined.Len() > 0 {
		hash := s.hashContent(combined.String())
		slog.Info("sticky.hash_source",
			"source", "message_content_fallback",
			"hash", hash,
			"content_len", combined.Len(),
		)
		return hash
	}

	return ""
}

// BindStickySession sets session -> account binding with standard TTL.
// No-op (nil error) on empty hash, non-positive account ID, or missing cache.
func (s *GatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error {
	if sessionHash == "" || accountID <= 0 || s.cache == nil {
		return nil
	}
	return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, accountID, stickySessionTTL)
}

// GetCachedSessionAccountID retrieves the account ID bound to a sticky session.
// Returns 0 if no binding exists or on error.
+func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID *int64, sessionHash string) (int64, error) { + if sessionHash == "" || s.cache == nil { + return 0, nil + } + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) + if err != nil { + return 0, err + } + return accountID, nil +} + +// FindGeminiSession 查找 Gemini 会话(基于内容摘要链的 Fallback 匹配) +// 返回最长匹配的会话信息(uuid, accountID) +func (s *GatewayService) FindGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) { + if digestChain == "" || s.digestStore == nil { + return "", 0, "", false + } + return s.digestStore.Find(groupID, prefixHash, digestChain) +} + +// SaveGeminiSession 保存 Gemini 会话。oldDigestChain 为 Find 返回的 matchedChain,用于删旧 key。 +func (s *GatewayService) SaveGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error { + if digestChain == "" || s.digestStore == nil { + return nil + } + s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain) + return nil +} + +// FindAnthropicSession 查找 Anthropic 会话(基于内容摘要链的 Fallback 匹配) +func (s *GatewayService) FindAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) { + if digestChain == "" || s.digestStore == nil { + return "", 0, "", false + } + return s.digestStore.Find(groupID, prefixHash, digestChain) +} + +// SaveAnthropicSession 保存 Anthropic 会话 +func (s *GatewayService) SaveAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error { + if digestChain == "" || s.digestStore == nil { + return nil + } + s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain) + return nil +} + +func (s *GatewayService) 
extractCacheableContent(parsed *ParsedRequest) string { + if parsed == nil { + return "" + } + + var builder strings.Builder + + // 检查 system 中的 cacheable 内容 + if system, ok := parsed.System.([]any); ok { + for _, part := range system { + if partMap, ok := part.(map[string]any); ok { + if cc, ok := partMap["cache_control"].(map[string]any); ok { + if cc["type"] == "ephemeral" { + if text, ok := partMap["text"].(string); ok { + _, _ = builder.WriteString(text) + } + } + } + } + } + } + systemText := builder.String() + + // 检查 messages 中的 cacheable 内容 + for _, msg := range parsed.Messages { + if msgMap, ok := msg.(map[string]any); ok { + if msgContent, ok := msgMap["content"].([]any); ok { + for _, part := range msgContent { + if partMap, ok := part.(map[string]any); ok { + if cc, ok := partMap["cache_control"].(map[string]any); ok { + if cc["type"] == "ephemeral" { + return s.extractTextFromContent(msgMap["content"]) + } + } + } + } + } + } + } + + return systemText +} + +func (s *GatewayService) extractTextFromSystem(system any) string { + switch v := system.(type) { + case string: + return v + case []any: + var texts []string + for _, part := range v { + if partMap, ok := part.(map[string]any); ok { + if text, ok := partMap["text"].(string); ok { + texts = append(texts, text) + } + } + } + return strings.Join(texts, "") + } + return "" +} + +func (s *GatewayService) extractTextFromContent(content any) string { + switch v := content.(type) { + case string: + return v + case []any: + var texts []string + for _, part := range v { + if partMap, ok := part.(map[string]any); ok { + if partMap["type"] == "text" { + if text, ok := partMap["text"].(string); ok { + texts = append(texts, text) + } + } + } + } + return strings.Join(texts, "") + } + return "" +} + +func (s *GatewayService) hashContent(content string) string { + h := xxhash.Sum64String(content) + return strconv.FormatUint(h, 36) +} + +// hashBodyForSessionSeed 为 sessionID 提供一个稳定但仅对本次请求特征化的种子。 +// 复用 SHA-256 + 
// truncation, aligned with generateSessionUUID's input format.
func hashBodyForSessionSeed(body []byte) string {
	if len(body) == 0 {
		return ""
	}
	sum := sha256.Sum256(body)
	// Only the first 16 bytes (32 hex chars) are kept as the seed.
	return fmt.Sprintf("%x", sum[:16])
}

// GenerateSessionUUID creates a deterministic UUID4 from a seed string.
func GenerateSessionUUID(seed string) string {
	return generateSessionUUID(seed)
}

// generateSessionUUID maps a seed deterministically onto a UUIDv4-shaped
// string; an empty seed yields a random UUID.
func generateSessionUUID(seed string) string {
	if seed == "" {
		return uuid.NewString()
	}
	hash := sha256.Sum256([]byte(seed))
	bytes := hash[:16]
	// Force version (4) and variant (RFC 4122) bits.
	bytes[6] = (bytes[6] & 0x0f) | 0x40
	bytes[8] = (bytes[8] & 0x3f) | 0x80
	return fmt.Sprintf("%x-%x-%x-%x-%x",
		bytes[0:4], bytes[4:6], bytes[6:8], bytes[8:10], bytes[10:16])
}

// (new file: backend/internal/service/gateway_stream.go)
package service

import (
	"bufio"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"strconv"
	"strings"
	"sync/atomic"
	"syscall"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
	"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"

	"github.com/gin-gonic/gin"
)

// openAIStreamEventIsTerminal reports whether an OpenAI-style SSE data payload
// marks the end of the stream ("[DONE]" or a terminal response.* type).
func openAIStreamEventIsTerminal(data string) bool {
	trimmed := strings.TrimSpace(data)
	if trimmed == "" {
		return false
	}
	if trimmed == "[DONE]" {
		return true
	}
	switch gjson.Get(trimmed, "type").String() {
	case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
		return true
	default:
		return false
	}
}

// anthropicStreamEventIsTerminal reports whether an Anthropic SSE event ends
// the stream: event name "message_stop", data "[DONE]", or type "message_stop".
func anthropicStreamEventIsTerminal(eventName, data string) bool {
	if strings.EqualFold(strings.TrimSpace(eventName), "message_stop") {
		return true
	}
	trimmed := strings.TrimSpace(data)
	if trimmed == "" {
		return false
	}
	if trimmed == "[DONE]" {
		return true
	}
	return gjson.Get(trimmed, "type").String() == "message_stop"
}

// sanitizeStreamError returns a client-visible error description with no
// network addresses. The default (*net.OpError).Error() concatenates the
// Source/Addr fields, leaking internal IP/port and upstream server addresses
// (e.g. "read tcp 10.0.0.1:54321->52.1.2.3:443: read: connection reset by
// peer"). This function keeps only the recognizable error category; the raw
// err is still logged at the call sites.
func sanitizeStreamError(err error) string {
	if err == nil {
		return ""
	}
	switch {
	case errors.Is(err, io.ErrUnexpectedEOF):
		return "unexpected EOF"
	case errors.Is(err, io.EOF):
		return "EOF"
	case errors.Is(err, context.Canceled):
		return "canceled"
	case errors.Is(err, context.DeadlineExceeded):
		return "deadline exceeded"
	case errors.Is(err, syscall.ECONNRESET):
		return "connection reset by peer"
	case errors.Is(err, syscall.ECONNABORTED):
		return "connection aborted"
	case errors.Is(err, syscall.ETIMEDOUT):
		return "connection timed out"
	case errors.Is(err, syscall.EPIPE):
		return "broken pipe"
	case errors.Is(err, syscall.ECONNREFUSED):
		return "connection refused"
	}
	var netErr *net.OpError
	if errors.As(err, &netErr) {
		if netErr.Timeout() {
			if netErr.Op != "" {
				return netErr.Op + " timeout"
			}
			return "i/o timeout"
		}
		if netErr.Op != "" {
			return netErr.Op + " network error"
		}
	}
	return "upstream connection error"
}

// streamingResult is the outcome of relaying one streaming response.
type streamingResult struct {
	usage            *ClaudeUsage // accumulated token usage from SSE events
	firstTokenMs     *int         // time to first data event, nil if none arrived
	clientDisconnect bool         // whether the client dropped mid-stream
}

// handleStreamingResponse relays an upstream Anthropic SSE response to the
// client while accumulating usage for billing.
//
// Responsibilities: refresh the 5h session window from response headers, set
// SSE headers, read upstream on a dedicated goroutine (so idle-timeout and
// keepalive keep working), rewrite events (Kimi cached_tokens compatibility,
// cache-TTL override, model-name un-mapping), forward blocks to the client,
// and keep draining upstream after a client disconnect so usage is complete.
//
// NOTE(review): mimicClaudeCode is not referenced in this body — confirm it is
// consumed elsewhere or intentionally reserved.
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, mimicClaudeCode bool) (*streamingResult, error) {
	// Refresh the 5h window state from upstream headers.
	s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)

	if s.responseHeaderFilter != nil {
		responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
	}

	// SSE response headers.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("X-Accel-Buffering", "no")

	// Pass through selected upstream headers.
	if v := resp.Header.Get("x-request-id"); v != "" {
		c.Header("x-request-id", v)
	}

	w := c.Writer
	flusher, ok := w.(http.Flusher)
	if !ok {
		return nil, errors.New("streaming not supported")
	}

	usage := &ClaudeUsage{}
	var firstTokenMs *int
	scanner := bufio.NewScanner(resp.Body)
	// Use a larger buffer to handle long SSE lines.
	maxLineSize := defaultMaxLineSize
	if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
		maxLineSize = s.cfg.Gateway.MaxLineSize
	}
	scanBuf := getSSEScannerBuf64K()
	scanner.Buffer(scanBuf[:0], maxLineSize)

	type scanEvent struct {
		line string
		err  error
	}
	// Read upstream on its own goroutine so a blocking read cannot starve
	// timeout/keepalive handling.
	events := make(chan scanEvent, 16)
	done := make(chan struct{})
	sendEvent := func(ev scanEvent) bool {
		select {
		case events <- ev:
			return true
		case <-done:
			return false
		}
	}
	var lastReadAt int64
	atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
	go func(scanBuf *sseScannerBuf64K) {
		defer putSSEScannerBuf64K(scanBuf)
		defer close(events)
		for scanner.Scan() {
			atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
			if !sendEvent(scanEvent{line: scanner.Text()}) {
				return
			}
		}
		if err := scanner.Err(); err != nil {
			_ = sendEvent(scanEvent{err: err})
		}
	}(scanBuf)
	defer close(done)

	streamInterval := time.Duration(0)
	if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
		streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
	}
	// Monitor only the upstream data interval, so a blocked downstream write
	// cannot trigger a false timeout.
	var intervalTicker *time.Ticker
	if streamInterval > 0 {
		intervalTicker = time.NewTicker(streamInterval)
		defer intervalTicker.Stop()
	}
	var intervalCh <-chan time.Time
	if intervalTicker != nil {
		intervalCh = intervalTicker.C
	}

	// Downstream keepalive: prevents proxies / Cloudflare Tunnel from dropping
	// the connection as idle.
	keepaliveInterval := time.Duration(0)
	if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
		keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
	}
	var keepaliveTicker *time.Ticker
	if keepaliveInterval > 0 {
		keepaliveTicker = time.NewTicker(keepaliveInterval)
		defer keepaliveTicker.Stop()
	}
	var keepaliveCh <-chan time.Time
	if keepaliveTicker != nil {
		keepaliveCh = keepaliveTicker.C
	}
	lastDataAt := time.Now()

	// Send at most one error event to avoid corrupting the protocol with
	// multiple writes (best-effort notification if the write fails).
	// The event follows the Anthropic SSE standard:
	// {"type":"error","error":{"type":…,"message":…}}
	// so Anthropic SDK / Claude Code clients can parse the standard error type
	// and show the concrete message, and the server-side
	// ExtractUpstreamErrorMessage can extract message from the relayed body.
	errorEventSent := false
	sendErrorEvent := func(reason, message string) {
		if errorEventSent {
			return
		}
		errorEventSent = true
		if message == "" {
			message = reason
		}
		body, err := json.Marshal(map[string]any{
			"type": "error",
			"error": map[string]string{
				"type":    reason,
				"message": message,
			},
		})
		if err != nil {
			// json.Marshal cannot fail on known string-only input; conservative fallback.
			body = []byte(fmt.Sprintf(`{"type":"error","error":{"type":%q,"message":%q}}`, reason, message))
		}
		_, _ = fmt.Fprintf(w, "event: error\ndata: %s\n\n", body)
		flusher.Flush()
	}

	needModelReplace := originalModel != mappedModel
	// After a client disconnect, keep reading upstream to collect complete usage.
	clientDisconnected := false
	sawTerminalEvent := false

	pendingEventLines := make([]string, 0, 4)

	// processSSEEvent rewrites one buffered SSE event and returns the output
	// blocks, the (possibly rewritten) data payload, and any usage patch.
	processSSEEvent := func(lines []string) ([]string, string, *sseUsagePatch, error) {
		if len(lines) == 0 {
			return nil, "", nil, nil
		}

		eventName := ""
		dataLine := ""
		for _, line := range lines {
			trimmed := strings.TrimSpace(line)
			if strings.HasPrefix(trimmed, "event:") {
				eventName = strings.TrimSpace(strings.TrimPrefix(trimmed, "event:"))
				continue
			}
			if dataLine == "" && sseDataRe.MatchString(trimmed) {
				dataLine = sseDataRe.ReplaceAllString(trimmed, "")
			}
		}

		if eventName == "error" {
			return nil, dataLine, nil, errors.New("have error in stream")
		}

		if dataLine == "" {
			return []string{strings.Join(lines, "\n") + "\n\n"}, "", nil, nil
		}

		if dataLine == "[DONE]" {
			sawTerminalEvent = true
			block := ""
			if eventName != "" {
				block = "event: " + eventName + "\n"
			}
			block += "data: " + dataLine + "\n\n"
			return []string{block}, dataLine, nil, nil
		}

		var event map[string]any
		if err := json.Unmarshal([]byte(dataLine), &event); err != nil {
			// JSON parse failure: relay the raw data untouched.
			block := ""
			if eventName != "" {
				block = "event: " + eventName + "\n"
			}
			block += "data: " + dataLine + "\n\n"
			return []string{block}, dataLine, nil, nil
		}

		eventType, _ := event["type"].(string)
		if eventName == "" {
			eventName = eventType
		}
		eventChanged := false

		// Kimi compatibility: cached_tokens → cache_read_input_tokens.
		if eventType == "message_start" {
			if msg, ok := event["message"].(map[string]any); ok {
				if u, ok := msg["usage"].(map[string]any); ok {
					eventChanged = reconcileCachedTokens(u) || eventChanged
				}
			}
		}
		if eventType == "message_delta" {
			if u, ok := event["usage"].(map[string]any); ok {
				eventChanged = reconcileCachedTokens(u) || eventChanged
			}
		}

		// Cache TTL Override: rewrite the cache_creation split in SSE events.
		// Account-level settings win; when global 1h request injection is on,
		// usage billing falls back to the 5m bucket by default.
		if overrideTarget, ok := s.resolveCacheTTLUsageOverrideTarget(ctx, account); ok {
			if eventType == "message_start" {
				if msg, ok := event["message"].(map[string]any); ok {
					if u, ok := msg["usage"].(map[string]any); ok {
						eventChanged = rewriteCacheCreationJSON(u, overrideTarget) || eventChanged
					}
				}
			}
			if eventType == "message_delta" {
				if u, ok := event["usage"].(map[string]any); ok {
					eventChanged = rewriteCacheCreationJSON(u, overrideTarget) || eventChanged
				}
			}
		}

		if needModelReplace {
			if msg, ok := event["message"].(map[string]any); ok {
				if model, ok := msg["model"].(string); ok && model == mappedModel {
					msg["model"] = originalModel
					eventChanged = true
				}
			}
		}

		usagePatch := s.extractSSEUsagePatch(event)
		if anthropicStreamEventIsTerminal(eventName, dataLine) {
			sawTerminalEvent = true
		}
		if !eventChanged {
			block := ""
			if eventName != "" {
				block = "event: " + eventName + "\n"
			}
			block += "data: " + dataLine + "\n\n"
			return []string{block}, dataLine, usagePatch, nil
		}

		newData, err := json.Marshal(event)
		if err != nil {
			// Serialization failure: relay the raw data untouched.
			block := ""
			if eventName != "" {
				block = "event: " + eventName + "\n"
			}
			block += "data: " + dataLine + "\n\n"
			return []string{block}, dataLine, usagePatch, nil
		}

		block := ""
		if eventName != "" {
			block = "event: " + eventName + "\n"
		}
		block += "data: " + string(newData) + "\n\n"
		return []string{block}, string(newData), usagePatch, nil
	}

	for {
		select {
		case ev, ok := <-events:
			if !ok {
				// Upstream finished; return the result.
				if !sawTerminalEvent {
					return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, fmt.Errorf("stream usage incomplete: missing terminal event")
				}
				return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
			}
			if ev.err != nil {
				if sawTerminalEvent {
					return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
				}
				// Context cancellation (a client disconnect cancels the
				// context, which in turn breaks the upstream read).
				if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
					return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete: %w", ev.err)
				}
				// The client was already detected as gone via a write failure,
				// and upstream also errored: return the usage collected so far.
				if clientDisconnected {
					return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after disconnect: %w", ev.err)
				}
				// Client still connected: normal error handling.
				if errors.Is(ev.err, bufio.ErrTooLong) {
					logger.LegacyPrintf("service.gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
					sendErrorEvent("response_too_large", fmt.Sprintf("upstream SSE line exceeded %d bytes", maxLineSize))
					return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
				}
				// Mid-stream upstream read error (unexpected EOF / connection
				// reset, common on HTTP/2 GOAWAY): if no bytes were written to
				// the client yet, wrap it as UpstreamFailoverError so the
				// handler layer can fail over / retry. Once the stream has
				// started, SSE has no resume — only relay an error event.
				// Note: the client-facing disconnectMsg must go through
				// sanitizeStreamError to strip addresses; the default
				// *net.OpError Error() leaks internal IP/port and the upstream
				// address. The full ev.err is kept only in the internal
				// LegacyPrintf log below for ops diagnosis.
				disconnectMsg := "upstream stream disconnected: " + sanitizeStreamError(ev.err)
				if !c.Writer.Written() {
					logger.LegacyPrintf("service.gateway", "Upstream stream read error before any client output (account=%d), failing over: %v", account.ID, ev.err)
					body, _ := json.Marshal(map[string]any{
						"type": "error",
						"error": map[string]string{
							"type":    "upstream_disconnected",
							"message": disconnectMsg,
						},
					})
					return nil, &UpstreamFailoverError{
						StatusCode:             http.StatusBadGateway,
						ResponseBody:           body,
						RetryableOnSameAccount: true,
					}
				}
				sendErrorEvent("stream_read_error", disconnectMsg)
				return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
			}
			line := ev.line
			trimmed := strings.TrimSpace(line)

			if trimmed == "" {
				if len(pendingEventLines) == 0 {
					continue
				}

				outputBlocks, data, usagePatch, err := processSSEEvent(pendingEventLines)
				pendingEventLines = pendingEventLines[:0]
				if err != nil {
					if clientDisconnected {
						return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
					}
					return nil, err
				}

				for _, block := range outputBlocks {
					if !clientDisconnected {
						restored := reverseToolNamesIfPresent(c, []byte(block))
						if _, werr := fmt.Fprint(w, string(restored)); werr != nil {
							clientDisconnected = true
							logger.LegacyPrintf("service.gateway", "Client disconnected during streaming, continuing to drain upstream for billing")
							// NOTE(review): this break skips the data/usagePatch
							// handling below for the event whose write failed,
							// so that one event's usage is not merged — confirm
							// whether it should fall through instead.
							break
						}
						flusher.Flush()
						lastDataAt = time.Now()
					}
					if data != "" {
						if firstTokenMs == nil && data != "[DONE]" {
							ms := int(time.Since(startTime).Milliseconds())
							firstTokenMs = &ms
						}
						if usagePatch != nil {
							mergeSSEUsagePatch(usage, usagePatch)
						}
					}
				}
				continue
			}

			pendingEventLines = append(pendingEventLines, line)

		case <-intervalCh:
			lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
			if time.Since(lastRead) < streamInterval {
				continue
			}
			if clientDisconnected {
				return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, fmt.Errorf("stream usage incomplete after timeout")
			}
			logger.LegacyPrintf("service.gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
			// Handle the stream timeout; may mark the account temporarily
			// unschedulable or in an error state.
			if s.rateLimitService != nil {
				s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
			}
			sendErrorEvent("stream_timeout", fmt.Sprintf("upstream stream idle for %s", streamInterval))
			return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")

		case <-keepaliveCh:
			if clientDisconnected {
				continue
			}
			if time.Since(lastDataAt) < keepaliveInterval {
				continue
			}
			// SSE ping event: native Anthropic format, handled correctly by
			// clients, and keeps the connection alive so proxies such as
			// Cloudflare Tunnel do not drop it.
			if _, werr := fmt.Fprint(w, "event: ping\ndata: {\"type\": \"ping\"}\n\n"); werr != nil {
				clientDisconnected = true
				logger.LegacyPrintf("service.gateway", "Client disconnected during keepalive ping, continuing to drain upstream for billing")
				continue
			}
			flusher.Flush()
		}
	}

}

// parseSSEUsage parses one SSE data payload and folds any usage fields it
// carries into usage.
func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
	if usage == nil {
		return
	}

	var event map[string]any
	if err := json.Unmarshal([]byte(data), &event); err != nil {
		return
	}

	if patch := s.extractSSEUsagePatch(event); patch != nil {
		mergeSSEUsagePatch(usage,
			patch)
	}
}

// sseUsagePatch records usage fields observed in a single SSE event. Each
// value is paired with a has* flag so that absent fields (and, for
// message_delta, zero values) do not overwrite previously seen totals.
type sseUsagePatch struct {
	inputTokens              int
	hasInputTokens           bool
	outputTokens             int
	hasOutputTokens          bool
	cacheCreationInputTokens int
	hasCacheCreationInput    bool
	cacheReadInputTokens     int
	hasCacheReadInput        bool
	cacheCreation5mTokens    int
	hasCacheCreation5m       bool
	cacheCreation1hTokens    int
	hasCacheCreation1h       bool
}

// extractSSEUsagePatch pulls usage fields out of a decoded SSE event.
// message_start sets input/cache fields unconditionally (missing values count
// as 0); message_delta only records fields that are present and positive.
// Returns nil for events that carry no usage.
func (s *GatewayService) extractSSEUsagePatch(event map[string]any) *sseUsagePatch {
	if len(event) == 0 {
		return nil
	}

	eventType, _ := event["type"].(string)
	switch eventType {
	case "message_start":
		msg, _ := event["message"].(map[string]any)
		usageObj, _ := msg["usage"].(map[string]any)
		if len(usageObj) == 0 {
			return nil
		}

		patch := &sseUsagePatch{}
		patch.hasInputTokens = true
		if v, ok := parseSSEUsageInt(usageObj["input_tokens"]); ok {
			patch.inputTokens = v
		}
		patch.hasCacheCreationInput = true
		if v, ok := parseSSEUsageInt(usageObj["cache_creation_input_tokens"]); ok {
			patch.cacheCreationInputTokens = v
		}
		patch.hasCacheReadInput = true
		if v, ok := parseSSEUsageInt(usageObj["cache_read_input_tokens"]); ok {
			patch.cacheReadInputTokens = v
		}
		if cc, ok := usageObj["cache_creation"].(map[string]any); ok {
			if v, exists := parseSSEUsageInt(cc["ephemeral_5m_input_tokens"]); exists {
				patch.cacheCreation5mTokens = v
				patch.hasCacheCreation5m = true
			}
			if v, exists := parseSSEUsageInt(cc["ephemeral_1h_input_tokens"]); exists {
				patch.cacheCreation1hTokens = v
				patch.hasCacheCreation1h = true
			}
		}
		return patch

	case "message_delta":
		usageObj, _ := event["usage"].(map[string]any)
		if len(usageObj) == 0 {
			return nil
		}

		patch := &sseUsagePatch{}
		if v, ok := parseSSEUsageInt(usageObj["input_tokens"]); ok && v > 0 {
			patch.inputTokens = v
			patch.hasInputTokens = true
		}
		if v, ok := parseSSEUsageInt(usageObj["output_tokens"]); ok && v > 0 {
			patch.outputTokens = v
			patch.hasOutputTokens = true
		}
		if v, ok := parseSSEUsageInt(usageObj["cache_creation_input_tokens"]); ok && v > 0 {
			patch.cacheCreationInputTokens = v
			patch.hasCacheCreationInput = true
		}
		if v, ok := parseSSEUsageInt(usageObj["cache_read_input_tokens"]); ok && v > 0 {
			patch.cacheReadInputTokens = v
			patch.hasCacheReadInput = true
		}
		if cc, ok := usageObj["cache_creation"].(map[string]any); ok {
			if v, exists := parseSSEUsageInt(cc["ephemeral_5m_input_tokens"]); exists && v > 0 {
				patch.cacheCreation5mTokens = v
				patch.hasCacheCreation5m = true
			}
			if v, exists := parseSSEUsageInt(cc["ephemeral_1h_input_tokens"]); exists && v > 0 {
				patch.cacheCreation1hTokens = v
				patch.hasCacheCreation1h = true
			}
		}
		return patch
	}

	return nil
}

// mergeSSEUsagePatch applies a patch onto the accumulated usage: each field
// present in the patch replaces (not adds to) the accumulated value.
func mergeSSEUsagePatch(usage *ClaudeUsage, patch *sseUsagePatch) {
	if usage == nil || patch == nil {
		return
	}

	if patch.hasInputTokens {
		usage.InputTokens = patch.inputTokens
	}
	if patch.hasCacheCreationInput {
		usage.CacheCreationInputTokens = patch.cacheCreationInputTokens
	}
	if patch.hasCacheReadInput {
		usage.CacheReadInputTokens = patch.cacheReadInputTokens
	}
	if patch.hasOutputTokens {
		usage.OutputTokens = patch.outputTokens
	}
	if patch.hasCacheCreation5m {
		usage.CacheCreation5mTokens = patch.cacheCreation5mTokens
	}
	if patch.hasCacheCreation1h {
		usage.CacheCreation1hTokens = patch.cacheCreation1hTokens
	}
}

// parseSSEUsageInt coerces a JSON-decoded value (number, json.Number, or
// numeric string) into an int. Floats are truncated. Reports false for
// unsupported types or unparsable strings.
func parseSSEUsageInt(value any) (int, bool) {
	switch v := value.(type) {
	case float64:
		return int(v), true
	case float32:
		return int(v), true
	case int:
		return v, true
	case int64:
		return int(v), true
	case int32:
		return int(v), true
	case json.Number:
		if i, err := v.Int64(); err == nil {
			return int(i), true
		}
		if f, err := v.Float64(); err == nil {
			return int(f), true
		}
	case string:
		if parsed, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
			return parsed, true
		}
	}
	return 0, false
}

// applyCacheTTLOverride folds all cache-creation tokens into the specified TTL
// bucket. target is "5m" or "1h". Reports true when anything changed.
func applyCacheTTLOverride(usage *ClaudeUsage, target string) bool {
	// Fallback: when only the aggregate field is set with no 5m/1h detail,
	// attribute the aggregate to the default 5m bucket.
	if usage.CacheCreation5mTokens == 0 && usage.CacheCreation1hTokens == 0 && usage.CacheCreationInputTokens > 0 {
		usage.CacheCreation5mTokens = usage.CacheCreationInputTokens
	}

	total := usage.CacheCreation5mTokens + usage.CacheCreation1hTokens
	if total == 0 {
		return false
	}
	switch target {
	case "1h":
		if usage.CacheCreation1hTokens == total {
			return false // already all 1h
		}
		usage.CacheCreation1hTokens = total
		usage.CacheCreation5mTokens = 0
	default: // "5m"
		if usage.CacheCreation5mTokens == total {
			return false // already all 5m
		}
		usage.CacheCreation5mTokens = total
		usage.CacheCreation1hTokens = 0
	}
	return true
}

// rewriteCacheCreationJSON rewrites the TTL split of the nested cache_creation
// object inside a JSON usage object (map[string]any). Reports true on change.
func rewriteCacheCreationJSON(usageObj map[string]any, target string) bool {
	ccObj, ok := usageObj["cache_creation"].(map[string]any)
	if !ok {
		return false
	}
	v5m, _ := parseSSEUsageInt(ccObj["ephemeral_5m_input_tokens"])
	v1h, _ := parseSSEUsageInt(ccObj["ephemeral_1h_input_tokens"])
	total := v5m + v1h
	if total == 0 {
		return false
	}
	switch target {
	case "1h":
		if v1h == total {
			return false
		}
		ccObj["ephemeral_1h_input_tokens"] = float64(total)
		ccObj["ephemeral_5m_input_tokens"] = float64(0)
	default: // "5m"
		if v5m == total {
			return false
		}
		ccObj["ephemeral_5m_input_tokens"] = float64(total)
		ccObj["ephemeral_1h_input_tokens"] = float64(0)
	}
	return true
}

// resolveCacheTTLUsageOverrideTarget decides whether cache-creation usage
// should be rewritten for this account, and to which TTL bucket.
// Account-level override wins; otherwise, when the global Anthropic 1h
// injection setting is on (for OAuth/setup-token accounts), billing defaults
// back to the 5m bucket.
func (s *GatewayService) resolveCacheTTLUsageOverrideTarget(ctx context.Context, account *Account) (string, bool) {
	if account == nil {
		return "", false
	}
	if account.IsCacheTTLOverrideEnabled() {
		return account.GetCacheTTLOverrideTarget(), true
	}
	if account.IsAnthropicOAuthOrSetupToken() && s != nil && s.settingService != nil && s.settingService.IsAnthropicCacheTTL1hInjectionEnabled(ctx) {
		return cacheTTLTarget5m, true
	}
	return "", false
}

// handleNonStreamingResponse relays a non-streaming upstream response to the
// client and extracts its usage: refreshes the 5h window, parses usage
// (including the nested cache_creation 5m/1h detail and the Kimi
// cached_tokens compatibility field), applies the cache-TTL override, undoes
// model-name mapping, filters headers, and writes the body.
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
	// Refresh the 5h window state from upstream headers.
	s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)

	body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, anthropicTooLargeError)
	if err != nil {
		return nil, err
	}

	// Parse usage.
	var response struct {
		Usage ClaudeUsage `json:"usage"`
	}
	if err := json.Unmarshal(body, &response); err != nil {
		return nil, fmt.Errorf("parse response: %w", err)
	}

	// Parse the 5m/1h detail from the nested cache_creation object.
	cc5m := gjson.GetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens")
	cc1h := gjson.GetBytes(body, "usage.cache_creation.ephemeral_1h_input_tokens")
	if cc5m.Exists() || cc1h.Exists() {
		response.Usage.CacheCreation5mTokens = int(cc5m.Int())
		response.Usage.CacheCreation1hTokens = int(cc1h.Int())
	}

	// Kimi compatibility: cached_tokens → cache_read_input_tokens.
	if response.Usage.CacheReadInputTokens == 0 {
		cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int()
		if cachedTokens > 0 {
			response.Usage.CacheReadInputTokens = int(cachedTokens)
			if newBody, err := sjson.SetBytes(body, "usage.cache_read_input_tokens", cachedTokens); err == nil {
				body = newBody
			}
		}
	}

	// Cache TTL Override: rewrite the cache_creation split in the
	// non-streaming response. Account-level settings win; when global 1h
	// request injection is on, usage billing defaults back to 5m.
	if overrideTarget, ok := s.resolveCacheTTLUsageOverrideTarget(ctx, account); ok {
		if applyCacheTTLOverride(&response.Usage, overrideTarget) {
			// Keep the nested cache_creation object in the body JSON in sync.
			if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_5m_input_tokens", response.Usage.CacheCreation5mTokens); err == nil {
				body = newBody
			}
			if newBody, err := sjson.SetBytes(body, "usage.cache_creation.ephemeral_1h_input_tokens", response.Usage.CacheCreation1hTokens); err == nil {
				body = newBody
			}
		}
	}

	// If a model mapping was applied, restore the original model name in the body.
	if originalModel != mappedModel {
		body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
	}

	responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)

	contentType := "application/json"
	if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled {
		if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" {
			contentType = upstreamType
		}
	}

	body = reverseToolNamesIfPresent(c, body)

	// Write the response.
	c.Data(resp.StatusCode, contentType, body)

	return &response.Usage, nil
}

// replaceModelInResponseBody replaces the body's "model" field when it equals
// fromModel, using gjson/sjson for a precise patch that avoids a full JSON
// round-trip. Returns the original body on any failure.
func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
	if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel {
		newBody, err := sjson.SetBytes(body, "model", toModel)
		if err != nil {
			return body
		}
		return newBody
	}
	return body
}

// reconcileCachedTokens — compatibility for Kimi-style upstreams:
// maps the OpenAI-style cached_tokens field onto the Claude-standard
// cache_read_input_tokens. Reports true when the map was modified.
func reconcileCachedTokens(usage map[string]any) bool {
	if usage == nil {
		return false
	}
	cacheRead, _ := usage["cache_read_input_tokens"].(float64)
	if cacheRead > 0 {
		return false // standard field already present, nothing to do
	}
	cached, _ := usage["cached_tokens"].(float64)
	if cached <= 0 {
		return false
	}
	usage["cache_read_input_tokens"] = cached
	return true
}

// (new file: backend/internal/service/gateway_upstream.go — continues past this chunk)
package service

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"time"
	"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
	"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
	"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
	"github.com/tidwall/gjson"

	"github.com/gin-gonic/gin"
)

// executeBedrockUpstream performs the Bedrock upstream request with retry
// logic. Retryable upstream errors (status > 400, per shouldRetryUpstreamError)
// are retried with backoff until maxRetryAttempts or maxRetryElapsed is
// exhausted. On a transport-level failure it writes a 502 to the client and
// returns an error; otherwise the (possibly non-2xx) upstream response is
// returned for the caller to handle.
func (s *GatewayService) executeBedrockUpstream(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	body []byte,
	modelID string,
	region string,
	stream bool,
	signer *BedrockSigner,
	apiKey string,
	proxyURL string,
) (*http.Response, error) {
	var resp *http.Response
	var err error
	retryStart := time.Now()
	for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
		// Rebuild the request on every attempt: the body reader is consumed
		// per attempt, and SigV4 signatures are time-sensitive.
		var upstreamReq *http.Request
		if account.IsBedrockAPIKey() {
			upstreamReq, err = s.buildUpstreamRequestBedrockAPIKey(ctx, body, modelID, region, stream, apiKey)
		} else {
			upstreamReq, err = s.buildUpstreamRequestBedrock(ctx, body, modelID, region, stream, signer)
		}
		if err != nil {
			return nil, err
		}

		resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, nil)
		if err != nil {
			// Transport failure: record ops events and answer 502 immediately
			// (no retry at this layer).
			if resp != nil && resp.Body != nil {
				_ = resp.Body.Close()
			}
			safeErr := sanitizeUpstreamErrorMessage(err.Error())
			setOpsUpstreamError(c, 0, safeErr, "")
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: 0,
				UpstreamURL:        safeUpstreamURL(upstreamReq.URL.String()),
				Kind:               "request_error",
				Message:            safeErr,
			})
			c.JSON(http.StatusBadGateway, gin.H{
				"type": "error",
				"error": gin.H{
					"type":    "upstream_error",
					"message": "Upstream request failed",
				},
			})
			return nil, fmt.Errorf("upstream request failed: %s", safeErr)
		}

		// Retry only for statuses strictly greater than 400 that the account
		// policy considers retryable; plain 400s are never retried.
		if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
			if attempt < maxRetryAttempts {
				// Enforce the overall retry time budget in addition to the
				// attempt count; cap the backoff delay to the remaining budget.
				elapsed := time.Since(retryStart)
				if elapsed >= maxRetryElapsed {
					break
				}

				delay := retryBackoffDelay(attempt)
				remaining := maxRetryElapsed - elapsed
				if delay > remaining {
					delay = remaining
				}
				if delay <= 0 {
					break
				}

				// Drain (capped at 2 MiB) and close the failed response before
				// sleeping, so the connection can be reused.
				respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
				_ = resp.Body.Close()
				appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
					Platform:           account.Platform,
					AccountID:          account.ID,
					AccountName:        account.Name,
					UpstreamStatusCode: resp.StatusCode,
					UpstreamURL:        safeUpstreamURL(upstreamReq.URL.String()),
					Kind:               "retry",
					Message:            extractUpstreamErrorMessage(respBody),
					Detail: func() string {
						if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
							return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)
						}
						return ""
					}(),
				})
				logger.LegacyPrintf("service.gateway", "[Bedrock] account %d: upstream error %d, retry %d/%d after %v",
					account.ID, resp.StatusCode, attempt, maxRetryAttempts, delay)
				if err := sleepWithContext(ctx, delay); err != nil {
					return nil, err
				}
				continue
			}
			break
		}

		break
	}
	if resp == nil || resp.Body == nil {
		return nil, errors.New("upstream request failed: empty response")
	}
	return resp, nil
}

// handleBedrockUpstreamErrors handles Bedrock upstream 4xx/5xx responses:
// it decides between failover (returning UpstreamFailoverError so another
// account can be tried), retry-exhausted handling, and a plain error response
// to the client.
func (s *GatewayService) handleBedrockUpstreamErrors(
	ctx context.Context,
	resp *http.Response,
	c *gin.Context,
	account *Account,
) (*ForwardResult, error) {
	// Retry budget exhausted on a retryable status: optionally fail over.
	if s.shouldRetryUpstreamError(account, resp.StatusCode) {
		if s.shouldFailoverUpstreamError(resp.StatusCode) {
			// Read (capped at 2 MiB) then restore the body so downstream
			// consumers of resp can still read it.
			respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
			_ = resp.Body.Close()
			resp.Body = io.NopCloser(bytes.NewReader(respBody))

			logger.LegacyPrintf("service.gateway", "[Bedrock] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d Body=%s",
				account.ID, account.Name, resp.StatusCode, truncateString(string(respBody), 1000))

			s.handleRetryExhaustedSideEffects(ctx, resp, account)
			appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
				Platform:           account.Platform,
				AccountID:          account.ID,
				AccountName:        account.Name,
				UpstreamStatusCode: resp.StatusCode,
				Kind:               "retry_exhausted_failover",
				Message:            extractUpstreamErrorMessage(respBody),
			})
			return nil, &UpstreamFailoverError{
				StatusCode:              resp.StatusCode,
				ResponseBody:            respBody,
				RetryableOnSameAccount:  account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
			}
		}
		return s.handleRetryExhaustedError(ctx, resp, c, account)
	}

	// Non-retryable status that still warrants failing over to another account.
	if s.shouldFailoverUpstreamError(resp.StatusCode) {
		respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
		_ = resp.Body.Close()
		resp.Body = io.NopCloser(bytes.NewReader(respBody))

		s.handleFailoverSideEffects(ctx, resp, account)
		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
			Platform:           account.Platform,
			AccountID:          account.ID,
			AccountName:        account.Name,
			UpstreamStatusCode: resp.StatusCode,
			Kind:               "failover",
			Message:            extractUpstreamErrorMessage(respBody),
		})
		return nil, &UpstreamFailoverError{
			StatusCode:              resp.StatusCode,
			ResponseBody:            respBody,
			RetryableOnSameAccount:  account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode),
		}
	}

	// All other errors: render an error response to the client.
	return s.handleErrorResponse(ctx, resp, c, account)
}

// buildUpstreamRequestBedrock builds a Bedrock upstream request authenticated
// with an AWS SigV4 signature.
func (s *GatewayService) buildUpstreamRequestBedrock(
	ctx context.Context,
	body []byte,
	modelID string,
	region string,
	stream bool,
	signer *BedrockSigner,
) (*http.Request, error) {
	targetURL := BuildBedrockURL(region, modelID, stream)

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}

	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")

	// SigV4 signing must happen after all headers are set.
	if err := signer.SignRequest(ctx, req, body); err != nil {
		return nil, fmt.Errorf("sign bedrock request: %w", err)
	}

	return req, nil
}

// buildUpstreamRequestBedrockAPIKey builds a Bedrock upstream request
// authenticated with a bearer API key instead of SigV4.
func (s *GatewayService) buildUpstreamRequestBedrockAPIKey(
	ctx context.Context,
	body []byte,
	modelID string,
	region string,
	stream bool,
	apiKey string,
) (*http.Request, error) {
	targetURL := BuildBedrockURL(region, modelID, stream)

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}

	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
	req.Header.Set("Authorization", "Bearer "+apiKey)

	return req, nil
}

// handleBedrockNonStreamingResponse handles a Bedrock non-streaming response.
// The Bedrock InvokeModel non-streaming body format is Claude-API compatible;
// the Bedrock-specific amazon-bedrock-invocationMetrics block is converted to
// standard Anthropic usage and stripped before forwarding to the client.
func (s *GatewayService) handleBedrockNonStreamingResponse(
	ctx context.Context,
	resp *http.Response,
	c *gin.Context,
	account *Account,
) (*ClaudeUsage, error) {
	body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, anthropicTooLargeError)
	if err != nil {
		return nil, err
	}

	// Convert amazon-bedrock-invocationMetrics into standard usage and remove
	// it so it is not passed through to the client.
	body = transformBedrockInvocationMetrics(body)

	usage := parseClaudeUsageFromResponseBody(body)

	c.Header("Content-Type", "application/json")
	// Surface AWS's request id to the client under the standard header name.
	if v := resp.Header.Get("x-amzn-requestid"); v != "" {
		c.Header("x-request-id", v)
	}
	c.Data(resp.StatusCode, "application/json", body)
	return usage, nil
}

// buildUpstreamRequest builds the upstream request for Anthropic-compatible
// accounts: it resolves the target URL (official API, custom base URL, custom
// relay, or Vertex), applies OAuth fingerprinting / metadata rewriting, sets
// authentication and beta headers, and records debug snapshots.
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) {
	if account.Platform == PlatformAnthropic && account.Type == AccountTypeServiceAccount {
		return s.buildUpstreamRequestAnthropicVertex(ctx, c, account, body, token, modelID, reqStream)
	}

	// Determine the target URL.
	targetURL := claudeAPIURL
	if account.Type == AccountTypeAPIKey {
		baseURL := account.GetBaseURL()
		if
 baseURL != "" {
			validatedURL, err := s.validateUpstreamBaseURL(baseURL)
			if err != nil {
				return nil, err
			}
			targetURL = validatedURL + "/v1/messages?beta=true"
		}
	} else if account.IsCustomBaseURLEnabled() {
		customURL := account.GetCustomBaseURL()
		if customURL == "" {
			return nil, fmt.Errorf("custom_base_url is enabled but not configured for account %d", account.ID)
		}
		validatedURL, err := s.validateUpstreamBaseURL(customURL)
		if err != nil {
			return nil, err
		}
		targetURL = s.buildCustomRelayURL(validatedURL, "/v1/messages", account)
	}

	clientHeaders := http.Header{}
	if c != nil && c.Request != nil {
		clientHeaders = c.Request.Header
	}

	// OAuth accounts: apply the unified fingerprint and metadata rewriting
	// (both gated by settings switches).
	var fingerprint *Fingerprint
	enableFP, enableMPT, enableCCH := true, false, false
	if s.settingService != nil {
		enableFP, enableMPT, enableCCH = s.settingService.GetGatewayForwardingSettings(ctx)
	}
	if account.IsOAuth() && s.identityService != nil {
		// 1. Get or create the fingerprint (includes a randomly generated
		//    ClientID).
		fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders)
		if err != nil {
			logger.LegacyPrintf("service.gateway", "Warning: failed to get fingerprint for account %d: %v", account.ID, err)
			// On failure, degrade to passing through the original headers.
		} else {
			if enableFP {
				fingerprint = fp
			}

			// 2. Rewrite metadata.user_id (needs the fingerprint's ClientID
			//    and the account's account_uuid). If session-ID masking is
			//    enabled, the session part is replaced with a fixed value
			//    after the rewrite. Skipped when metadata passthrough is on.
			if !enableMPT {
				accountUUID := account.GetExtraString("account_uuid")
				if accountUUID != "" && fp.ClientID != "" {
					if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID, fp.UserAgent); err == nil && len(newBody) > 0 {
						body = newBody
					}
				}
			}
		}
	}

	// Keep the billing header's cc_version in sync with the User-Agent
	// version actually being sent.
	if fingerprint != nil {
		body = syncBillingHeaderVersion(body, fingerprint.UserAgent)
	}
	// CCH signing: replace the cch=00000 placeholder with an xxHash64
	// signature (must run after all other body modifications).
	if enableCCH {
		body = signBillingHeaderCCH(body)
	}

	req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}

	// Set the authentication header (preserving original casing).
	if tokenType == "oauth" {
		setHeaderRaw(req.Header, "authorization", "Bearer "+token)
	} else {
		setHeaderRaw(req.Header, "x-api-key", token)
	}

	// Allowlist passthrough of client headers.
	// OAuth mimicry path: skip client header passthrough, matching Parrot.
	// Parrot's build_upstream_headers sends exactly 9 headers and passes
	// through no client headers. Passing client headers through would inject
	// inconsistent x-stainless-* / anthropic-beta / user-agent /
	// x-claude-code-session-id values that conflict with the masquerade
	// headers and get the request flagged as third-party by Anthropic.
	if tokenType != "oauth" || !mimicClaudeCode {
		for key, values := range clientHeaders {
			lowerKey := strings.ToLower(key)
			if allowedHeaders[lowerKey] {
				wireKey := resolveWireCasing(key)
				for _, v := range values {
					addHeaderRaw(req.Header, wireKey, v)
				}
			}
		}
	}

	// OAuth accounts: apply the cached fingerprint to the request headers
	// (overriding any allowlisted passthrough values).
	if fingerprint != nil {
		s.identityService.ApplyFingerprint(req, fingerprint)
	}

	// Ensure required headers exist (preserving original casing).
	if getHeaderRaw(req.Header, "content-type") == "" {
		setHeaderRaw(req.Header, "content-type", "application/json")
	}
	if getHeaderRaw(req.Header, "anthropic-version") == "" {
		setHeaderRaw(req.Header, "anthropic-version", "2023-06-01")
	}
	if tokenType == "oauth" {
		applyClaudeOAuthHeaderDefaults(req)
	}

	// Build the effective drop set: merge static defaults with dynamic
	// beta-policy filter rules.
	policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account, modelID)
	effectiveDropSet := mergeDropSets(policyFilterSet)

	// Handle the anthropic-beta header (OAuth accounts must include the
	// oauth beta token).
	if tokenType == "oauth" {
		if mimicClaudeCode {
			// Non-Claude-Code client: follow opencode's strategy:
			// - force the Claude Code fingerprint headers (especially
			//   user-agent/x-stainless/x-app)
			// - preserve incoming betas while ensuring the OAuth-required
			//   betas are present
			applyClaudeCodeMimicHeaders(req, reqStream)

			incomingBeta := getHeaderRaw(req.Header, "anthropic-beta")
			// Claude Code OAuth credentials are scoped to Claude Code.
			// Non-haiku models MUST include claude-code beta for Anthropic to recognize
			// this as a legitimate Claude Code request; without it, the request is
			// rejected as third-party ("out of extra usage").
			// Haiku models are exempt from third-party detection and don't need it.
			requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
			if !strings.Contains(strings.ToLower(modelID), "haiku") {
				requiredBetas = claude.FullClaudeCodeMimicryBetas()
			}
			setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropSet))
		} else {
			// Claude Code client: pass the original header through as much as
			// possible, only topping up the oauth beta.
			clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta")
			setHeaderRaw(req.Header, "anthropic-beta", stripBetaTokensWithSet(s.getBetaHeader(modelID, clientBetaHeader), effectiveDropSet))
		}
	} else {
		// API-key accounts: apply beta policy filter to strip controlled tokens
		if existingBeta := getHeaderRaw(req.Header, "anthropic-beta"); existingBeta != "" {
			setHeaderRaw(req.Header, "anthropic-beta", stripBetaTokensWithSet(existingBeta, effectiveDropSet))
		} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey {
			// API key: only top up the beta header when the request explicitly
			// uses beta features and the client did not supply one (off by
			// default).
			if requestNeedsBetaFeatures(body) {
				if beta := defaultAPIKeyBetaHeader(body); beta != "" {
					setHeaderRaw(req.Header, "anthropic-beta", beta)
				}
			}
		}
	}

	// Sync the X-Claude-Code-Session-Id header: override it with the
	// session_id from the already-processed metadata.user_id in the body.
	if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" {
		if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" {
			if parsed := ParseMetadataUserID(uid); parsed != nil {
				setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID)
			}
		}
	}

	// DEBUG: snapshot the outgoing upstream request (headers + body digest)
	// for comparison against CLIENT_ORIGINAL.
	s.debugLogGatewaySnapshot("UPSTREAM_FORWARD", req.Header, body, map[string]string{
		"url":                 req.URL.String(),
		"token_type":          tokenType,
		"mimic_claude_code":   strconv.FormatBool(mimicClaudeCode),
		"fingerprint_applied": strconv.FormatBool(fingerprint != nil),
		"enable_fp":           strconv.FormatBool(enableFP),
		"enable_mpt":          strconv.FormatBool(enableMPT),
	})

	// Always capture a compact fingerprint line for later error diagnostics.
	// We only print it when needed (or when the explicit debug flag is enabled).
	if c != nil && tokenType == "oauth" {
		c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode))
	}
	if s.debugClaudeMimicEnabled() {
		logClaudeMimicDebug(req, body, account, tokenType, mimicClaudeCode)
	}

	return req, nil
}

// buildUpstreamRequestAnthropicVertex builds the upstream request for
// Anthropic-on-Vertex service accounts: it converts the body to the Vertex
// Anthropic format, builds the regional Vertex URL, passes allowlisted client
// headers through (minus anthropic-version), and authenticates with the
// service-account bearer token.
func (s *GatewayService) buildUpstreamRequestAnthropicVertex(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	body []byte,
	token string,
	modelID string,
	reqStream bool,
) (*http.Request, error) {
	vertexBody, err := buildVertexAnthropicRequestBody(body)
	if err != nil {
		return nil, err
	}
	setOpsUpstreamRequestBody(c, vertexBody)
	fullURL, err := buildVertexAnthropicURL(account.VertexProjectID(), account.VertexLocation(modelID), modelID, reqStream)
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(vertexBody))
	if err != nil {
		return nil, err
	}

	if c != nil && c.Request != nil {
		for key, values := range c.Request.Header {
			lowerKey := strings.ToLower(strings.TrimSpace(key))
			// anthropic-version is excluded: Vertex does not take it as a header.
			if !allowedHeaders[lowerKey] || lowerKey == "anthropic-version" {
				continue
			}
			wireKey := resolveWireCasing(key)
			for _, v := range values {
				addHeaderRaw(req.Header, wireKey, v)
			}
		}
	}

	// Strip any client-supplied credentials before setting our own.
	req.Header.Del("authorization")
	req.Header.Del("x-api-key")
	req.Header.Del("x-goog-api-key")
	req.Header.Del("cookie")
	req.Header.Del("anthropic-version")
	setHeaderRaw(req.Header, "authorization", "Bearer "+token)
	setHeaderRaw(req.Header, "content-type", "application/json")

	s.debugLogGatewaySnapshot("UPSTREAM_FORWARD_VERTEX_ANTHROPIC", req.Header, vertexBody, map[string]string{
		"url":        req.URL.String(),
		"token_type": "service_account",
		"model":      modelID,
		"stream":     strconv.FormatBool(reqStream),
	})

	return req, nil
}

// buildCustomRelayURL builds the URL for a custom relay upstream:
// it appends beta=true and an optional proxy query parameter to the path.
func (s *GatewayService) buildCustomRelayURL(baseURL, path string, account
 *Account) string {
	u := strings.TrimRight(baseURL, "/") + path + "?beta=true"
	// Forward the account's proxy (if any) to the relay as a query parameter.
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL := account.Proxy.URL()
		if proxyURL != "" {
			u += "&proxy=" + url.QueryEscape(proxyURL)
		}
	}
	return u
}

// validateUpstreamBaseURL validates a user-configured upstream base URL.
// When the URL allowlist is disabled only the format (and optionally plain
// HTTP) is checked; when enabled, the host must be on the configured
// allowlist. Returns the normalized URL.
func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
	if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
		normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
		if err != nil {
			return "", fmt.Errorf("invalid base_url: %w", err)
		}
		return normalized, nil
	}
	normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
		AllowedHosts:     s.cfg.Security.URLAllowlist.UpstreamHosts,
		RequireAllowlist: true,
		AllowPrivate:     s.cfg.Security.URLAllowlist.AllowPrivateHosts,
	})
	if err != nil {
		return "", fmt.Errorf("invalid base_url: %w", err)
	}
	return normalized, nil
}
diff --git a/backend/internal/service/openai_account_selection_loadaware.go b/backend/internal/service/openai_account_selection_loadaware.go
new file mode 100644
index 00000000000..3c653830f17
--- /dev/null
+++ b/backend/internal/service/openai_account_selection_loadaware.go
@@ -0,0 +1,668 @@
package service

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"time"

	"github.com/Wei-Shaw/sub2api/internal/config"
)

// noAvailableOpenAISelectionError builds the standard "no account available" error
// while preserving the compact-specific error when applicable.
func noAvailableOpenAISelectionError(requestedModel string, compactBlocked bool) error {
	if compactBlocked {
		return ErrNoAvailableCompactAccounts
	}
	if requestedModel != "" {
		return fmt.Errorf("no available OpenAI accounts supporting model: %s", requestedModel)
	}
	return errors.New("no available OpenAI accounts")
}

// openAICompactSupportTier classifies an OpenAI account by compact capability.
// 0 = explicitly unsupported, 1 = unknown / not yet probed, 2 = explicitly supported.
func openAICompactSupportTier(account *Account) int {
	if account == nil || !account.IsOpenAI() {
		return 0
	}
	supported, known := account.OpenAICompactSupportKnown()
	if !known {
		return 1
	}
	if supported {
		return 2
	}
	return 0
}

// isOpenAIAccountEligibleForRequest centralises the schedulable / OpenAI / model /
// compact-support checks used during account selection.
func isOpenAIAccountEligibleForRequest(account *Account, requestedModel string, requireCompact bool) bool {
	if account == nil || !account.IsSchedulable() || !account.IsOpenAI() {
		return false
	}
	if requestedModel != "" && !account.IsModelSupported(requestedModel) {
		return false
	}
	if requireCompact && openAICompactSupportTier(account) == 0 {
		return false
	}
	return true
}

// prioritizeOpenAICompactAccounts re-orders a slice so that accounts with known
// compact support are tried first, followed by unknown, then explicitly unsupported.
// The relative order within each tier is preserved.
func prioritizeOpenAICompactAccounts(accounts []*Account) []*Account {
	if len(accounts) == 0 {
		return nil
	}
	supported := make([]*Account, 0, len(accounts))
	unknown := make([]*Account, 0, len(accounts))
	unsupported := make([]*Account, 0, len(accounts))
	for _, account := range accounts {
		switch openAICompactSupportTier(account) {
		case 2:
			supported = append(supported, account)
		case 1:
			unknown = append(unknown, account)
		default:
			unsupported = append(unsupported, account)
		}
	}
	out := make([]*Account, 0, len(accounts))
	out = append(out, supported...)
	out = append(out, unknown...)
	out = append(out, unsupported...)
	return out
}

// resolveOpenAIAccountUpstreamModelForRequest resolves the upstream model that
// would be sent for a given request, honouring compact-only mappings when the
// caller is on the /responses/compact path.
func resolveOpenAIAccountUpstreamModelForRequest(account *Account, requestedModel string, requireCompact bool) string {
	upstreamModel := resolveOpenAIForwardModel(account, requestedModel, "")
	if upstreamModel == "" {
		return ""
	}
	if requireCompact {
		return resolveOpenAICompactForwardModel(account, upstreamModel)
	}
	return upstreamModel
}

// selectAccountForModelWithExclusions picks an OpenAI account for the request:
// sticky-session hit first, then priority + LRU selection over schedulable
// accounts, binding the winner back to the sticky session.
func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, stickyAccountID int64) (*Account, error) {
	if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
		slog.Warn("channel pricing restriction blocked request",
			"group_id", derefGroupID(groupID),
			"model", requestedModel)
		return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
	}

	// 1. Try sticky session hit
	if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, requireCompact, stickyAccountID); account != nil {
		return account, nil
	}

	// 2. Get schedulable OpenAI accounts
	accounts, err := s.listSchedulableAccounts(ctx, groupID)
	if err != nil {
		return nil, fmt.Errorf("query accounts failed: %w", err)
	}

	// 3. Select by priority + LRU
	selected, compactBlocked := s.selectBestAccount(ctx, groupID, accounts, requestedModel, excludedIDs, requireCompact)

	if selected == nil {
		return nil, noAvailableOpenAISelectionError(requestedModel, compactBlocked)
	}

	// 4. Set sticky session binding
	if sessionHash != "" {
		_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, selected.ID, openaiStickySessionTTL)
	}

	return s.hydrateSelectedAccount(ctx, selected)
}

// tryStickySessionHit attempts to get an account from the sticky session.
// Returns the account if hit and usable; clears the session and returns nil
// if the bound account is unavailable for this request.
func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, stickyAccountID int64) *Account {
	if sessionHash == "" {
		return nil
	}

	// Prefer the caller-resolved sticky account id; fall back to the cache.
	accountID := stickyAccountID
	if accountID <= 0 {
		var err error
		accountID, err = s.getStickySessionAccountID(ctx, groupID, sessionHash)
		if err != nil || accountID <= 0 {
			return nil
		}
	}

	if _, excluded := excludedIDs[accountID]; excluded {
		return nil
	}

	account, err := s.getSchedulableAccount(ctx, accountID)
	if err != nil {
		return nil
	}

	// Check if the sticky session should be cleared for this account.
	if shouldClearStickySession(account, requestedModel) {
		_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
		return nil
	}

	// Verify the account is usable for the current request.
	if !isOpenAIAccountEligibleForRequest(account, requestedModel, false) {
		return nil
	}
	account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
	if account == nil {
		_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
		return nil
	}
	if groupID != nil && s.needsUpstreamChannelRestrictionCheck(ctx, groupID) &&
		s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel, requireCompact) {
		_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
		return nil
	}

	// Refresh the session TTL and return the account.
	_ =
 s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL)
	return account
}

// selectBestAccount selects the best account from candidates (priority + LRU).
// Returns nil if no account is available. The second return reports whether at
// least one candidate was filtered out solely because it lacks compact support
// (only meaningful when requireCompact=true).
func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *int64, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool) (*Account, bool) {
	var selected *Account
	selectedCompactTier := -1
	compactBlocked := false
	// NOTE(review): when needsUpstreamCheck is true, *groupID is dereferenced
	// below — presumably needsUpstreamChannelRestrictionCheck returns false
	// for a nil groupID; confirm.
	needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)

	for i := range accounts {
		acc := &accounts[i]

		// Skip excluded accounts.
		if _, excluded := excludedIDs[acc.ID]; excluded {
			continue
		}

		fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false)
		if fresh == nil {
			continue
		}
		fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, false)
		if fresh == nil {
			continue
		}
		if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) {
			continue
		}
		compactTier := 0
		if requireCompact {
			compactTier = openAICompactSupportTier(fresh)
			if compactTier == 0 {
				compactBlocked = true
				continue
			}
		}

		// Select the highest-priority, least-recently-used account.
		if selected == nil {
			selected = fresh
			selectedCompactTier = compactTier
			continue
		}

		// In compact mode a higher tier wins outright; priority/LRU only
		// breaks ties within the same tier.
		if requireCompact && compactTier != selectedCompactTier {
			if compactTier > selectedCompactTier {
				selected = fresh
				selectedCompactTier = compactTier
			}
			continue
		}

		if s.isBetterAccount(fresh, selected) {
			selected = fresh
			selectedCompactTier = compactTier
		}
	}

	return selected, compactBlocked
}

// isBetterAccount checks whether candidate is better than current.
// Rules: higher priority (lower value) wins; at equal priority a never-used
// account wins, otherwise the least recently used one does.
func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool {
	// Higher priority (lower numeric value).
	if candidate.Priority < current.Priority {
		return true
	}
	if candidate.Priority > current.Priority {
		return false
	}

	// Same priority: compare last-used time.
	switch {
	case candidate.LastUsedAt == nil && current.LastUsedAt != nil:
		// candidate has never been used — prefer it.
		return true
	case candidate.LastUsedAt != nil && current.LastUsedAt == nil:
		// current has never been used — keep it.
		return false
	case candidate.LastUsedAt == nil && current.LastUsedAt == nil:
		// Neither has been used — keep current.
		return false
	default:
		// Both used — pick the least recently used.
		return candidate.LastUsedAt.Before(*current.LastUsedAt)
	}
}

// selectAccountWithLoadAwareness selects an account taking live concurrency
// load into account: sticky session first, then load-aware selection over the
// candidate set, with a wait plan as fallback when no slot can be acquired.
func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool) (*AccountSelectionResult, error) {
	if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) {
		slog.Warn("channel pricing restriction blocked request",
			"group_id", derefGroupID(groupID),
			"model", requestedModel)
		return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel)
	}

	cfg := s.schedulingConfig()
	needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
	var stickyAccountID int64
	if sessionHash != "" && s.cache != nil {
		if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil {
			stickyAccountID = accountID
		}
	}
	// Without a concurrency service or load batching, fall back to the plain
	// priority+LRU selector and try to acquire a slot on its pick.
	if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
		account, err := s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, requireCompact, stickyAccountID)
		if err != nil {
			return nil, err
		}
		result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
		if err == nil && result.Acquired {
			return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
		}
		// Sticky sessions get a (bounded) wait on their bound account before
		// falling back to the generic wait plan.
		if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
			waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
			if waitingCount < cfg.StickySessionMaxWaiting {
				return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
					AccountID:      account.ID,
					MaxConcurrency: account.Concurrency,
					Timeout:        cfg.StickySessionWaitTimeout,
					MaxWaiting:     cfg.StickySessionMaxWaiting,
				})
			}
		}
		return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
			AccountID:      account.ID,
			MaxConcurrency: account.Concurrency,
			Timeout:        cfg.FallbackWaitTimeout,
			MaxWaiting:     cfg.FallbackMaxWaiting,
		})
	}

	accounts, err := s.listSchedulableAccounts(ctx, groupID)
	if err != nil {
		return nil, err
	}
	if len(accounts) == 0 {
		return nil, ErrNoAvailableAccounts
	}

	isExcluded := func(accountID int64) bool {
		if excludedIDs == nil {
			return false
		}
		_, excluded := excludedIDs[accountID]
		return excluded
	}

	// ============ Layer 1: Sticky session ============
	if sessionHash != "" {
		accountID := stickyAccountID
		if accountID > 0 && !isExcluded(accountID) {
			account, err := s.getSchedulableAccount(ctx, accountID)
			if err == nil {
				clearSticky := shouldClearStickySession(account, requestedModel)
				if clearSticky {
					_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
				}
				if !clearSticky && isOpenAIAccountEligibleForRequest(account, requestedModel, false) {
					account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact)
					if account == nil {
						_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
					} else if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel, requireCompact) {
						_ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash)
					} else {
						result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
						if err == nil && result.Acquired {
							_ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL)
							return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
						}

						waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
						if waitingCount < cfg.StickySessionMaxWaiting {
							return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
								AccountID:      accountID,
								MaxConcurrency: account.Concurrency,
								Timeout:        cfg.StickySessionWaitTimeout,
								MaxWaiting:     cfg.StickySessionMaxWaiting,
							})
						}
					}
				}
			}
		}
	}

	// ============ Layer 2: Load-aware selection ============
	baseCandidateCount := 0
	candidates := make([]*Account, 0, len(accounts))
	for i := range accounts {
		acc := &accounts[i]
		if isExcluded(acc.ID) {
			continue
		}
		// Scheduler snapshots can be temporarily stale (bucket rebuild is throttled);
		// re-check schedulability here so recently rate-limited/overloaded accounts
		// are not selected again before the bucket is rebuilt.
		if !acc.IsSchedulable() {
			continue
		}
		if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
			continue
		}
		if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel, requireCompact) {
			continue
		}
		baseCandidateCount++
		candidates = append(candidates, acc)
	}

	if len(candidates) == 0 {
		return nil, ErrNoAvailableAccounts
	}

	accountLoads := make([]AccountWithConcurrency, 0, len(candidates))
	for _, acc := range candidates {
		accountLoads = append(accountLoads, AccountWithConcurrency{
			ID:             acc.ID,
			MaxConcurrency: acc.EffectiveLoadFactor(),
		})
	}

	loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
	if err != nil {
		// Load lookup failed: degrade to ordered priority/LRU acquisition.
		ordered := append([]*Account(nil), candidates...)
		sortAccountsByPriorityAndLastUsed(ordered, false)
		if requireCompact {
			ordered = prioritizeOpenAICompactAccounts(ordered)
		}
		for _, acc := range ordered {
			fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false)
			if fresh == nil {
				continue
			}
			fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact)
			if fresh == nil {
				continue
			}
			if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) {
				continue
			}
			result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency)
			if err == nil && result.Acquired {
				if sessionHash != "" {
					_ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL)
				}
				return s.newSelectionResult(ctx, fresh, true, result.ReleaseFunc, nil)
			}
		}
	} else {
		// Keep only accounts whose current load is under 100%.
		var available []accountWithLoad
		for _, acc := range candidates {
			loadInfo := loadMap[acc.ID]
			if loadInfo == nil {
				loadInfo = &AccountLoadInfo{AccountID: acc.ID}
			}
			if loadInfo.LoadRate < 100 {
				available = append(available, accountWithLoad{
					account:  acc,
					loadInfo: loadInfo,
				})
			}
		}

		if len(available) > 0 {
sortAccountsWithLoadByPriority(available) + + selectionOrder := make([]accountWithLoad, 0, len(available)) + if requireCompact { + appendTier := func(out []accountWithLoad, tier int) []accountWithLoad { + for _, item := range available { + if openAICompactSupportTier(item.account) == tier { + out = append(out, item) + } + } + return out + } + selectionOrder = appendTier(selectionOrder, 2) + selectionOrder = appendTier(selectionOrder, 1) + // tier 0 候选作为兜底追加:DB recheck 时若发现 cache tier 0 实际 + // 已升级为 1/2(探测刚跑完,cache 尚未刷新),仍可正常命中。 + selectionOrder = appendTier(selectionOrder, 0) + } else { + selectionOrder = append(selectionOrder, available...) + } + + for _, item := range selectionOrder { + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel, false) + if fresh == nil { + continue + } + fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact) + if fresh == nil { + continue + } + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) { + continue + } + result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) + if err == nil && result.Acquired { + if sessionHash != "" { + _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL) + } + return s.newSelectionResult(ctx, fresh, true, result.ReleaseFunc, nil) + } + } + } + } + + // ============ Layer 3: Fallback wait ============ + sortAccountsByPriorityAndLastUsed(candidates, false) + if requireCompact { + candidates = prioritizeOpenAICompactAccounts(candidates) + } + for _, acc := range candidates { + fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false) + if fresh == nil { + continue + } + fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact) + if fresh == nil { + continue + } + if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, 
requireCompact) { + continue + } + return s.newSelectionResult(ctx, fresh, false, nil, &AccountWaitPlan{ + AccountID: fresh.ID, + MaxConcurrency: fresh.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }) + } + + if requireCompact && baseCandidateCount > 0 { + return nil, ErrNoAvailableCompactAccounts + } + return nil, ErrNoAvailableAccounts +} + +func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) { + if s.schedulerSnapshot != nil { + accounts, _, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, PlatformOpenAI, false) + return accounts, err + } + var accounts []Account + var err error + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI) + } else if groupID != nil { + accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI) + } else { + accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, PlatformOpenAI) + } + if err != nil { + return nil, fmt.Errorf("query accounts failed: %w", err) + } + return accounts, nil +} + +func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) { + if s.concurrencyService == nil { + return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil + } + return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) +} + +func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.Context, account *Account, requestedModel string, requireCompact bool) *Account { + if account == nil { + return nil + } + + fresh := account + if s.schedulerSnapshot != nil { + current, err := s.getSchedulableAccount(ctx, account.ID) + if err != nil || current == nil { + return nil + } + fresh = current + } + + if !isOpenAIAccountEligibleForRequest(fresh, requestedModel, requireCompact) { + 
return nil + } + return fresh +} + +func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Context, account *Account, requestedModel string, requireCompact bool) *Account { + if account == nil { + return nil + } + if s.schedulerSnapshot == nil || s.accountRepo == nil { + if !isOpenAIAccountEligibleForRequest(account, requestedModel, requireCompact) { + return nil + } + return account + } + + latest, err := s.accountRepo.GetByID(ctx, account.ID) + if err != nil || latest == nil { + return nil + } + if !isOpenAIAccountEligibleForRequest(latest, requestedModel, requireCompact) { + return nil + } + return latest +} + +func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) { + var ( + account *Account + err error + ) + if s.schedulerSnapshot != nil { + account, err = s.schedulerSnapshot.GetAccount(ctx, accountID) + } else { + account, err = s.accountRepo.GetByID(ctx, accountID) + } + if err != nil || account == nil { + return account, err + } + return account, nil +} + +func (s *OpenAIGatewayService) hydrateSelectedAccount(ctx context.Context, account *Account) (*Account, error) { + if account == nil || s.schedulerSnapshot == nil { + return account, nil + } + hydrated, err := s.schedulerSnapshot.GetAccount(ctx, account.ID) + if err != nil { + return nil, err + } + if hydrated == nil { + return nil, fmt.Errorf("selected openai account %d not found during hydration", account.ID) + } + return hydrated, nil +} + +func (s *OpenAIGatewayService) newSelectionResult(ctx context.Context, account *Account, acquired bool, release func(), waitPlan *AccountWaitPlan) (*AccountSelectionResult, error) { + hydrated, err := s.hydrateSelectedAccount(ctx, account) + if err != nil { + return nil, err + } + return &AccountSelectionResult{ + Account: hydrated, + Acquired: acquired, + ReleaseFunc: release, + WaitPlan: waitPlan, + }, nil +} + +func (s *OpenAIGatewayService) schedulingConfig() 
config.GatewaySchedulingConfig { + if s.cfg != nil { + return s.cfg.Gateway.Scheduling + } + return config.GatewaySchedulingConfig{ + StickySessionMaxWaiting: 3, + StickySessionWaitTimeout: 45 * time.Second, + FallbackWaitTimeout: 30 * time.Second, + FallbackMaxWaiting: 100, + LoadBatchEnabled: true, + SlotCleanupInterval: 30 * time.Second, + } +} diff --git a/backend/internal/service/openai_codex.go b/backend/internal/service/openai_codex.go new file mode 100644 index 00000000000..266f8782770 --- /dev/null +++ b/backend/internal/service/openai_codex.go @@ -0,0 +1,336 @@ +package service + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "net/http" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "go.uber.org/zap" +) + +// OpenAICodexUsageSnapshot represents Codex API usage limits from response headers +type OpenAICodexUsageSnapshot struct { + PrimaryUsedPercent *float64 `json:"primary_used_percent,omitempty"` + PrimaryResetAfterSeconds *int `json:"primary_reset_after_seconds,omitempty"` + PrimaryWindowMinutes *int `json:"primary_window_minutes,omitempty"` + SecondaryUsedPercent *float64 `json:"secondary_used_percent,omitempty"` + SecondaryResetAfterSeconds *int `json:"secondary_reset_after_seconds,omitempty"` + SecondaryWindowMinutes *int `json:"secondary_window_minutes,omitempty"` + PrimaryOverSecondaryPercent *float64 `json:"primary_over_secondary_percent,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` +} + +// NormalizedCodexLimits contains normalized 5h/7d rate limit data +type NormalizedCodexLimits struct { + Used5hPercent *float64 + Reset5hSeconds *int + Window5hMinutes *int + Used7dPercent *float64 + Reset7dSeconds *int + Window7dMinutes *int +} + +func (s *OpenAIGatewayService) getCodexClientRestrictionDetector() CodexClientRestrictionDetector { + if s != nil 
&& s.codexDetector != nil { + return s.codexDetector + } + var cfg *config.Config + if s != nil { + cfg = s.cfg + } + return NewOpenAICodexClientRestrictionDetector(cfg) +} + +func (s *OpenAIGatewayService) detectCodexClientRestriction(c *gin.Context, account *Account) CodexClientRestrictionDetectionResult { + return s.getCodexClientRestrictionDetector().Detect(c, account) +} + +func logCodexCLIOnlyDetection(ctx context.Context, c *gin.Context, account *Account, apiKeyID int64, result CodexClientRestrictionDetectionResult, body []byte) { + if !result.Enabled { + return + } + if ctx == nil { + ctx = context.Background() + } + accountID := int64(0) + if account != nil { + accountID = account.ID + } + fields := []zap.Field{ + zap.String("component", "service.openai_gateway"), + zap.Int64("account_id", accountID), + zap.Bool("codex_cli_only_enabled", result.Enabled), + zap.Bool("codex_official_client_match", result.Matched), + zap.String("reject_reason", result.Reason), + } + if apiKeyID > 0 { + fields = append(fields, zap.Int64("api_key_id", apiKeyID)) + } + if !result.Matched { + fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, body) + } + log := logger.FromContext(ctx).With(fields...) 
+ if result.Matched { + return + } + log.Warn("OpenAI codex_cli_only 拒绝非官方客户端请求") +} + +func appendCodexCLIOnlyRejectedRequestFields(fields []zap.Field, c *gin.Context, body []byte) []zap.Field { + if c == nil || c.Request == nil { + return fields + } + + req := c.Request + requestModel, requestStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body) + fields = append(fields, + zap.String("request_method", strings.TrimSpace(req.Method)), + zap.String("request_path", strings.TrimSpace(req.URL.Path)), + zap.String("request_query", strings.TrimSpace(req.URL.RawQuery)), + zap.String("request_host", strings.TrimSpace(req.Host)), + zap.String("request_client_ip", strings.TrimSpace(c.ClientIP())), + zap.String("request_remote_addr", strings.TrimSpace(req.RemoteAddr)), + zap.String("request_user_agent", strings.TrimSpace(req.Header.Get("User-Agent"))), + zap.String("request_content_type", strings.TrimSpace(req.Header.Get("Content-Type"))), + zap.Int64("request_content_length", req.ContentLength), + zap.Bool("request_stream", requestStream), + ) + if requestModel != "" { + fields = append(fields, zap.String("request_model", requestModel)) + } + if promptCacheKey != "" { + fields = append(fields, zap.String("request_prompt_cache_key_sha256", hashSensitiveValueForLog(promptCacheKey))) + } + + if headers := snapshotCodexCLIOnlyHeaders(req.Header); len(headers) > 0 { + fields = append(fields, zap.Any("request_headers", headers)) + } + fields = append(fields, zap.Int("request_body_size", len(body))) + return fields +} + +func snapshotCodexCLIOnlyHeaders(header http.Header) map[string]string { + if len(header) == 0 { + return nil + } + result := make(map[string]string, len(codexCLIOnlyDebugHeaderWhitelist)) + for _, key := range codexCLIOnlyDebugHeaderWhitelist { + value := strings.TrimSpace(header.Get(key)) + if value == "" { + continue + } + result[strings.ToLower(key)] = truncateString(value, codexCLIOnlyHeaderValueMaxBytes) + } + return result +} + +func 
hashSensitiveValueForLog(raw string) string { + value := strings.TrimSpace(raw) + if value == "" { + return "" + } + sum := sha256.Sum256([]byte(value)) + return hex.EncodeToString(sum[:8]) +} + +func logOpenAIInstructionsRequiredDebug( + ctx context.Context, + c *gin.Context, + account *Account, + upstreamStatusCode int, + upstreamMsg string, + requestBody []byte, + upstreamBody []byte, +) { + msg := strings.TrimSpace(upstreamMsg) + if !isOpenAIInstructionsRequiredError(upstreamStatusCode, msg, upstreamBody) { + return + } + if ctx == nil { + ctx = context.Background() + } + + accountID := int64(0) + accountName := "" + if account != nil { + accountID = account.ID + accountName = strings.TrimSpace(account.Name) + } + + userAgent := "" + originator := "" + if c != nil { + userAgent = strings.TrimSpace(c.GetHeader("User-Agent")) + originator = strings.TrimSpace(c.GetHeader("originator")) + } + + fields := []zap.Field{ + zap.String("component", "service.openai_gateway"), + zap.Int64("account_id", accountID), + zap.String("account_name", accountName), + zap.Int("upstream_status_code", upstreamStatusCode), + zap.String("upstream_error_message", msg), + zap.String("request_user_agent", userAgent), + zap.Bool("codex_official_client_match", openai.IsCodexOfficialClientByHeaders(userAgent, originator)), + } + fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, requestBody) + + logger.FromContext(ctx).With(fields...).Warn("OpenAI 上游返回 Instructions are required,已记录请求详情用于排查") +} + +func isOpenAIInstructionsRequiredError(upstreamStatusCode int, upstreamMsg string, upstreamBody []byte) bool { + if upstreamStatusCode != http.StatusBadRequest { + return false + } + + hasInstructionRequired := func(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" { + return false + } + if strings.Contains(lower, "instructions are required") { + return true + } + if strings.Contains(lower, "required parameter: 'instructions'") { + return true + } + if 
strings.Contains(lower, "required parameter: instructions") { + return true + } + if strings.Contains(lower, "missing required parameter") && strings.Contains(lower, "instructions") { + return true + } + return strings.Contains(lower, "instruction") && strings.Contains(lower, "required") + } + + if hasInstructionRequired(upstreamMsg) { + return true + } + if len(upstreamBody) == 0 { + return false + } + + errMsg := gjson.GetBytes(upstreamBody, "error.message").String() + errMsgLower := strings.ToLower(strings.TrimSpace(errMsg)) + errCode := strings.ToLower(strings.TrimSpace(gjson.GetBytes(upstreamBody, "error.code").String())) + errParam := strings.ToLower(strings.TrimSpace(gjson.GetBytes(upstreamBody, "error.param").String())) + errType := strings.ToLower(strings.TrimSpace(gjson.GetBytes(upstreamBody, "error.type").String())) + + if errParam == "instructions" { + return true + } + if hasInstructionRequired(errMsg) { + return true + } + if strings.Contains(errCode, "missing_required_parameter") && strings.Contains(errMsgLower, "instructions") { + return true + } + if strings.Contains(errType, "invalid_request") && strings.Contains(errMsgLower, "instructions") && strings.Contains(errMsgLower, "required") { + return true + } + + return false +} + +func isOpenAITransientProcessingError(upstreamStatusCode int, upstreamMsg string, upstreamBody []byte) bool { + if upstreamStatusCode != http.StatusBadRequest { + return false + } + + match := func(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" { + return false + } + if strings.Contains(lower, "an error occurred while processing your request") { + return true + } + return strings.Contains(lower, "you can retry your request") && + strings.Contains(lower, "help.openai.com") && + strings.Contains(lower, "request id") + } + + if match(upstreamMsg) { + return true + } + if len(upstreamBody) == 0 { + return false + } + if match(gjson.GetBytes(upstreamBody, "error.message").String()) { + 
return true + } + return match(string(upstreamBody)) +} + +func getOpenAIReasoningEffortFromReqBody(reqBody map[string]any) (value string, present bool) { + if reqBody == nil { + return "", false + } + + // Primary: reasoning.effort + if reasoning, ok := reqBody["reasoning"].(map[string]any); ok { + if effort, ok := reasoning["effort"].(string); ok { + return normalizeOpenAIReasoningEffort(effort), true + } + } + + // Fallback: some clients may use a flat field. + if effort, ok := reqBody["reasoning_effort"].(string); ok { + return normalizeOpenAIReasoningEffort(effort), true + } + + return "", false +} + +func deriveOpenAIReasoningEffortFromModel(model string) string { + if strings.TrimSpace(model) == "" { + return "" + } + + modelID := strings.TrimSpace(model) + if strings.Contains(modelID, "/") { + parts := strings.Split(modelID, "/") + modelID = parts[len(parts)-1] + } + + parts := strings.FieldsFunc(strings.ToLower(modelID), func(r rune) bool { + switch r { + case '-', '_', ' ': + return true + default: + return false + } + }) + if len(parts) == 0 { + return "" + } + + return normalizeOpenAIReasoningEffort(parts[len(parts)-1]) +} + +func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string { + reasoningEffort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String()) + if reasoningEffort == "" { + reasoningEffort = strings.TrimSpace(gjson.GetBytes(body, "reasoning_effort").String()) + } + if reasoningEffort != "" { + normalized := normalizeOpenAIReasoningEffort(reasoningEffort) + if normalized == "" { + return nil + } + return &normalized + } + + value := deriveOpenAIReasoningEffortFromModel(requestedModel) + if value == "" { + return nil + } + return &value +} diff --git a/backend/internal/service/openai_config_methods.go b/backend/internal/service/openai_config_methods.go new file mode 100644 index 00000000000..909203802d5 --- /dev/null +++ b/backend/internal/service/openai_config_methods.go @@ -0,0 +1,86 @@ +package 
service + +import ( + "math/rand" + "time" +) + +func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle { + if s != nil && s.codexSnapshotThrottle != nil { + return s.codexSnapshotThrottle + } + return defaultOpenAICodexSnapshotPersistThrottle +} + +func (s *OpenAIGatewayService) openAIWSRetryBackoff(attempt int) time.Duration { + if attempt <= 0 { + return 0 + } + + initial := openAIWSRetryBackoffInitialDefault + maxBackoff := openAIWSRetryBackoffMaxDefault + jitterRatio := openAIWSRetryJitterRatioDefault + if s != nil && s.cfg != nil { + wsCfg := s.cfg.Gateway.OpenAIWS + if wsCfg.RetryBackoffInitialMS > 0 { + initial = time.Duration(wsCfg.RetryBackoffInitialMS) * time.Millisecond + } + if wsCfg.RetryBackoffMaxMS > 0 { + maxBackoff = time.Duration(wsCfg.RetryBackoffMaxMS) * time.Millisecond + } + if wsCfg.RetryJitterRatio >= 0 { + jitterRatio = wsCfg.RetryJitterRatio + } + } + if initial <= 0 { + return 0 + } + if maxBackoff <= 0 { + maxBackoff = initial + } + if maxBackoff < initial { + maxBackoff = initial + } + if jitterRatio < 0 { + jitterRatio = 0 + } + if jitterRatio > 1 { + jitterRatio = 1 + } + + shift := attempt - 1 + if shift < 0 { + shift = 0 + } + backoff := initial + if shift > 0 { + backoff = initial * time.Duration(1< maxBackoff { + backoff = maxBackoff + } + if jitterRatio <= 0 { + return backoff + } + jitter := time.Duration(float64(backoff) * jitterRatio) + if jitter <= 0 { + return backoff + } + delta := time.Duration(rand.Int63n(int64(jitter)*2+1)) - jitter + withJitter := backoff + delta + if withJitter < 0 { + return 0 + } + return withJitter +} + +func (s *OpenAIGatewayService) openAIWSRetryTotalBudget() time.Duration { + if s != nil && s.cfg != nil { + ms := s.cfg.Gateway.OpenAIWS.RetryTotalBudgetMS + if ms <= 0 { + return 0 + } + return time.Duration(ms) * time.Millisecond + } + return 0 +} diff --git a/backend/internal/service/openai_forward_passthrough.go b/backend/internal/service/openai_forward_passthrough.go 
new file mode 100644 index 00000000000..46e56bead5f --- /dev/null +++ b/backend/internal/service/openai_forward_passthrough.go @@ -0,0 +1,1006 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "sort" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.uber.org/zap" +) + +func (s *OpenAIGatewayService) forwardOpenAIPassthrough( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + reqModel string, + reasoningEffort *string, + reqStream bool, + startTime time.Time, +) (*OpenAIForwardResult, error) { + upstreamPassthroughModel := "" + if isOpenAIResponsesCompactPath(c) { + compactMappedModel := resolveOpenAICompactForwardModel(account, reqModel) + if compactMappedModel != "" && compactMappedModel != reqModel { + nextBody, setErr := sjson.SetBytes(body, "model", compactMappedModel) + if setErr != nil { + return nil, fmt.Errorf("set compact passthrough model: %w", setErr) + } + body = nextBody + upstreamPassthroughModel = compactMappedModel + } + } + + if account != nil && account.Type == AccountTypeOAuth { + if rejectReason := detectOpenAIPassthroughInstructionsRejectReason(reqModel, body); rejectReason != "" { + rejectMsg := "OpenAI codex passthrough requires a non-empty instructions field" + setOpsUpstreamError(c, http.StatusForbidden, rejectMsg, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: http.StatusForbidden, + Passthrough: true, + Kind: "request_error", + Message: rejectMsg, + Detail: rejectReason, + }) + logOpenAIPassthroughInstructionsRejected(ctx, c, account, reqModel, rejectReason, body) + c.JSON(http.StatusForbidden, gin.H{ 
+ "error": gin.H{ + "type": "forbidden_error", + "message": rejectMsg, + }, + }) + return nil, fmt.Errorf("openai passthrough rejected before upstream: %s", rejectReason) + } + + normalizedBody, normalized, err := normalizeOpenAIPassthroughOAuthBody(body, isOpenAIResponsesCompactPath(c)) + if err != nil { + return nil, err + } + if normalized { + body = normalizedBody + } + reqStream = gjson.GetBytes(body, "stream").Bool() + } + + sanitizedBody, sanitized, err := sanitizeEmptyBase64InputImagesInOpenAIBody(body) + if err != nil { + return nil, err + } + if sanitized { + body = sanitizedBody + } + + // Apply OpenAI fast policy to the passthrough body (filter/block by service_tier). + // 统一使用 upstream 视角的 model:透传路径下 body 已经过 compact 映射 + + // OAuth normalize,body 中的 model 字段即上游真正会看到的 slug。 + // 这样可以与 chat-completions / messages / native /responses 入口的 + // upstreamModel 保持一致,避免 whitelist 命中差异。当 body 中没有 + // model 字段时退回 reqModel。 + policyModel := strings.TrimSpace(gjson.GetBytes(body, "model").String()) + if policyModel == "" { + policyModel = reqModel + } + updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, policyModel, body) + if policyErr != nil { + var blocked *OpenAIFastBlockedError + if errors.As(policyErr, &blocked) { + writeOpenAIFastPolicyBlockedResponse(c, blocked) + } + return nil, policyErr + } + body = updatedBody + + apiKey := getAPIKeyFromContext(c) + if IsImageGenerationIntent(openAIResponsesEndpoint, reqModel, body) && !GroupAllowsImageGeneration(apiKeyGroup(apiKey)) { + setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "") + c.JSON(http.StatusForbidden, gin.H{ + "error": gin.H{ + "type": "permission_error", + "message": ImageGenerationPermissionMessage(), + }, + }) + return nil, errors.New("image generation disabled for group") + } + imageBillingModel := "" + imageSizeTier := "" + if IsImageGenerationIntent(openAIResponsesEndpoint, reqModel, body) { + var imageCfgErr error + imageBillingModel, 
imageSizeTier, imageCfgErr = resolveOpenAIResponsesImageBillingConfigFromBody(body, reqModel) + if imageCfgErr != nil { + setOpsUpstreamError(c, http.StatusBadRequest, imageCfgErr.Error(), "") + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": imageCfgErr.Error(), + "param": "size", + }, + }) + return nil, imageCfgErr + } + } + + logger.LegacyPrintf("service.openai_gateway", + "[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v", + account.ID, + account.Name, + account.Type, + reqModel, + reqStream, + ) + if reqStream && c != nil && c.Request != nil { + if timeoutHeaders := collectOpenAIPassthroughTimeoutHeaders(c.Request.Header); len(timeoutHeaders) > 0 { + streamWarnLogger := logger.FromContext(ctx).With( + zap.String("component", "service.openai_gateway"), + zap.Int64("account_id", account.ID), + zap.Strings("timeout_headers", timeoutHeaders), + ) + if s.isOpenAIPassthroughTimeoutHeadersAllowed() { + streamWarnLogger.Warn("OpenAI passthrough 透传请求包含超时相关请求头,且当前配置为放行,可能导致上游提前断流") + } else { + streamWarnLogger.Warn("OpenAI passthrough 检测到超时相关请求头,将按配置过滤以降低断流风险") + } + } + } + + // Get access token + token, _, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, err + } + + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) + upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token) + releaseUpstreamCtx() + if err != nil { + return nil, err + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + setOpsUpstreamRequestBody(c, body) + if c != nil { + c.Set("openai_passthrough", true) + } + + upstreamStart := time.Now() + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds()) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + 
setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Passthrough: true, + Kind: "request_error", + Message: safeErr, + }) + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream request failed", + }, + }) + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + // 透传模式默认保持原样代理;但 429/529 属于网关必须兜底的 + // 上游容量类错误,应先触发多账号 failover 以维持基础 SLA。 + if shouldFailoverOpenAIPassthroughResponse(resp.StatusCode) { + return nil, s.handleFailoverErrorResponsePassthrough(ctx, resp, c, account, body) + } + return nil, s.handleErrorResponsePassthrough(ctx, resp, c, account, body) + } + + var usage *OpenAIUsage + var firstTokenMs *int + imageCount := 0 + if reqStream { + result, err := s.handleStreamingResponsePassthrough(ctx, resp, c, account, startTime, reqModel, upstreamPassthroughModel) + if err != nil { + return nil, err + } + usage = result.usage + firstTokenMs = result.firstTokenMs + imageCount = result.imageCount + } else { + result, err := s.handleNonStreamingResponsePassthrough(ctx, resp, c, reqModel, upstreamPassthroughModel) + if err != nil { + return nil, err + } + usage = result.usage + imageCount = result.imageCount + } + + if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { + s.updateCodexUsageSnapshot(ctx, account.ID, snapshot) + } + + if usage == nil { + usage = &OpenAIUsage{} + } + + forwardResult := &OpenAIForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: *usage, + Model: reqModel, + UpstreamModel: upstreamPassthroughModel, + ServiceTier: extractOpenAIServiceTierFromBody(body), + ReasoningEffort: reasoningEffort, + Stream: reqStream, + OpenAIWSMode: false, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + } + if imageCount > 
0 {
		forwardResult.ImageCount = imageCount
		forwardResult.ImageSize = imageSizeTier
		forwardResult.BillingModel = imageBillingModel
	}
	return forwardResult, nil
}

// logOpenAIPassthroughInstructionsRejected emits a structured warning when a
// passthrough Codex request is rejected locally because it lacks valid
// instructions. It tolerates a nil ctx/account, records account identity,
// requested model and reject reason, and appends the shared codex-cli-only
// request debug fields for troubleshooting.
func logOpenAIPassthroughInstructionsRejected(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	reqModel string,
	rejectReason string,
	body []byte,
) {
	if ctx == nil {
		ctx = context.Background()
	}
	accountID := int64(0)
	accountName := ""
	accountType := ""
	if account != nil {
		accountID = account.ID
		accountName = strings.TrimSpace(account.Name)
		accountType = strings.TrimSpace(string(account.Type))
	}
	fields := []zap.Field{
		zap.String("component", "service.openai_gateway"),
		zap.Int64("account_id", accountID),
		zap.String("account_name", accountName),
		zap.String("account_type", accountType),
		zap.String("request_model", strings.TrimSpace(reqModel)),
		zap.String("reject_reason", strings.TrimSpace(rejectReason)),
	}
	fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, body)
	logger.FromContext(ctx).With(fields...).Warn("OpenAI passthrough 本地拦截:Codex 请求缺少有效 instructions")
}

// buildUpstreamRequestOpenAIPassthrough constructs the upstream HTTP request
// for passthrough mode. It resolves the target URL from the account type
// (OAuth → ChatGPT Codex endpoint, APIKey → validated custom base URL),
// forwards allowlisted client headers, injects upstream authentication, and
// for OAuth accounts fills the extra headers the ChatGPT internal API expects,
// including per-API-key isolated session/conversation identifiers.
func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough(
	ctx context.Context,
	c *gin.Context,
	account *Account,
	body []byte,
	token string,
) (*http.Request, error) {
	targetURL := openaiPlatformAPIURL
	switch account.Type {
	case AccountTypeOAuth:
		targetURL = chatgptCodexURL
	case AccountTypeAPIKey:
		baseURL := account.GetOpenAIBaseURL()
		if baseURL != "" {
			validatedURL, err := s.validateUpstreamBaseURL(baseURL)
			if err != nil {
				return nil, err
			}
			targetURL = buildOpenAIResponsesURL(validatedURL)
		}
	}
	targetURL = appendOpenAIResponsesRequestPathSuffix(targetURL, openAIResponsesRequestPathSuffix(c))

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}

	// Forward client request headers through the safety allowlist.
	allowTimeoutHeaders := s.isOpenAIPassthroughTimeoutHeadersAllowed()
	if c != nil && c.Request != nil {
		for key, values := range c.Request.Header {
			lower := strings.ToLower(strings.TrimSpace(key))
			if !isOpenAIPassthroughAllowedRequestHeader(lower, allowTimeoutHeaders) {
				continue
			}
			for _, v := range values {
				req.Header.Add(key, v)
			}
		}
	}

	// Strip any inbound auth residue, then inject the upstream credentials.
	req.Header.Del("authorization")
	req.Header.Del("x-api-key")
	req.Header.Del("x-goog-api-key")
	req.Header.Set("authorization", "Bearer "+token)

	// When passing OAuth traffic through to the ChatGPT internal API, fill in
	// the headers that API requires.
	if account.Type == AccountTypeOAuth {
		promptCacheKey := strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
		req.Host = "chatgpt.com"
		if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" {
			req.Header.Set("chatgpt-account-id", chatgptAccountID)
		}
		apiKeyID := getAPIKeyIDFromContext(c)
		// Capture the client's original values first, then apply compact-path
		// fallbacks, so the later unified isolation step never reads an
		// already-processed value.
		clientSessionID := strings.TrimSpace(req.Header.Get("session_id"))
		clientConversationID := strings.TrimSpace(req.Header.Get("conversation_id"))
		if isOpenAIResponsesCompactPath(c) {
			req.Header.Set("accept", "application/json")
			if req.Header.Get("version") == "" {
				req.Header.Set("version", codexCLIVersion)
			}
			if clientSessionID == "" {
				clientSessionID = resolveOpenAICompactSessionID(c)
			}
		} else if req.Header.Get("accept") == "" {
			req.Header.Set("accept", "text/event-stream")
		}
		if req.Header.Get("OpenAI-Beta") == "" {
			req.Header.Set("OpenAI-Beta", "responses=experimental")
		}
		if req.Header.Get("originator") == "" {
			req.Header.Set("originator", "codex_cli_rs")
		}
		// Override client-passed values with isolated session identifiers to
		// prevent cross-user session collisions.
		if clientSessionID == "" {
			clientSessionID = promptCacheKey
		}
		if clientConversationID == "" {
			clientConversationID = promptCacheKey
		}
		if clientSessionID != "" {
			req.Header.Set("session_id", isolateOpenAISessionID(apiKeyID, clientSessionID))
		}
		if clientConversationID != "" {
			req.Header.Set("conversation_id", isolateOpenAISessionID(apiKeyID, clientConversationID))
		}
	}

	// Passthrough mode also honors the account's custom User-Agent and the
	// ForceCodexCLI fallback.
	customUA := account.GetOpenAIUserAgent()
	if customUA != "" {
		req.Header.Set("user-agent", customUA)
	}
	if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
		req.Header.Set("user-agent", codexCLIUserAgent)
	}
	// OAuth safety net: force a Codex UA for non-Codex clients to reduce the
	// chance of the upstream risk-control rejecting the request.
	if account.Type == AccountTypeOAuth && !openai.IsCodexCLIRequest(req.Header.Get("user-agent")) {
		req.Header.Set("user-agent", codexCLIUserAgent)
	}

	if req.Header.Get("content-type") == "" {
		req.Header.Set("content-type", "application/json")
	}

	return req, nil
}

// shouldFailoverOpenAIPassthroughResponse reports whether an upstream HTTP
// status must trigger multi-account failover even in passthrough mode
// (capacity-class errors: 429 Too Many Requests and 529).
func shouldFailoverOpenAIPassthroughResponse(statusCode int) bool {
	switch statusCode {
	case http.StatusTooManyRequests, 529:
		return true
	default:
		return false
	}
}

// handleFailoverErrorResponsePassthrough consumes a failover-eligible upstream
// error response: it records ops/error telemetry, lets the rate-limit service
// update account state, and returns an UpstreamFailoverError so the caller can
// retry on another account. The body read is capped at 2 MiB.
func (s *OpenAIGatewayService) handleFailoverErrorResponsePassthrough(
	ctx context.Context,
	resp *http.Response,
	c *gin.Context,
	account *Account,
	requestBody []byte,
) error {
	body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))

	upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
	upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
	upstreamDetail := ""
	if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
		maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
		if maxBytes <= 0 {
			maxBytes = 2048
		}
		upstreamDetail = truncateString(string(body), maxBytes)
	}
	setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
	logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body)
	if s.rateLimitService != nil {
		// Best-effort: failure to update rate-limit state must not mask the failover.
		_ = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
	}
	appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
		Platform:           account.Platform,
		AccountID:          account.ID,
		AccountName:        account.Name,
		UpstreamStatusCode: resp.StatusCode,
		UpstreamRequestID:
resp.Header.Get("x-request-id"),
		Passthrough:          true,
		Kind:                 "failover",
		Message:              upstreamMsg,
		Detail:               upstreamDetail,
		UpstreamResponseBody: upstreamDetail,
	})
	return &UpstreamFailoverError{
		StatusCode:      resp.StatusCode,
		ResponseBody:    body,
		ResponseHeaders: resp.Header.Clone(),
	}
}

// handleErrorResponsePassthrough relays a non-failover upstream error response
// to the client verbatim (status, filtered headers, body) while recording ops
// telemetry and updating rate-limit account state. It always returns a non-nil
// error describing the upstream status. The body read is capped at 2 MiB.
func (s *OpenAIGatewayService) handleErrorResponsePassthrough(
	ctx context.Context,
	resp *http.Response,
	c *gin.Context,
	account *Account,
	requestBody []byte,
) error {
	body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))

	upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
	upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
	upstreamDetail := ""
	if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
		maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
		if maxBytes <= 0 {
			maxBytes = 2048
		}
		upstreamDetail = truncateString(string(body), maxBytes)
	}
	setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
	logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body)
	if s.rateLimitService != nil {
		// Passthrough mode preserves the raw upstream error response, but runtime
		// account state still needs to be updated so sticky routing can stop
		// reusing a freshly rate-limited account.
		_ = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
	}
	appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
		Platform:             account.Platform,
		AccountID:            account.ID,
		AccountName:          account.Name,
		UpstreamStatusCode:   resp.StatusCode,
		UpstreamRequestID:    resp.Header.Get("x-request-id"),
		Passthrough:          true,
		Kind:                 "http_error",
		Message:              upstreamMsg,
		Detail:               upstreamDetail,
		UpstreamResponseBody: upstreamDetail,
	})

	writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)
	contentType := resp.Header.Get("Content-Type")
	if contentType == "" {
		contentType = "application/json"
	}
	c.Data(resp.StatusCode, contentType, body)

	if upstreamMsg == "" {
		return fmt.Errorf("upstream error: %d", resp.StatusCode)
	}
	return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
}

// isOpenAIPassthroughAllowedRequestHeader reports whether a lower-cased client
// request header may be forwarded upstream. Timeout headers are gated by the
// allowTimeoutHeaders flag; everything else is checked against the static
// passthrough allowlist.
func isOpenAIPassthroughAllowedRequestHeader(lowerKey string, allowTimeoutHeaders bool) bool {
	if lowerKey == "" {
		return false
	}
	if isOpenAIPassthroughTimeoutHeader(lowerKey) {
		return allowTimeoutHeaders
	}
	return openaiPassthroughAllowedHeaders[lowerKey]
}

// isOpenAIPassthroughTimeoutHeader reports whether the lower-cased header key
// is one of the known client/SDK timeout hint headers.
func isOpenAIPassthroughTimeoutHeader(lowerKey string) bool {
	switch lowerKey {
	case "x-stainless-timeout", "x-stainless-read-timeout", "x-stainless-connect-timeout", "x-request-timeout", "request-timeout", "grpc-timeout":
		return true
	default:
		return false
	}
}

// isOpenAIPassthroughTimeoutHeadersAllowed reports whether configuration
// permits forwarding client timeout headers upstream; nil-safe on s and s.cfg.
func (s *OpenAIGatewayService) isOpenAIPassthroughTimeoutHeadersAllowed() bool {
	return s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIPassthroughAllowTimeoutHeaders
}

// collectOpenAIPassthroughTimeoutHeaders returns a sorted "key=v1|v2" summary
// of every timeout header present in h, for logging/diagnostics; nil when h is
// nil or no timeout headers are present.
func collectOpenAIPassthroughTimeoutHeaders(h http.Header) []string {
	if h == nil {
		return nil
	}
	var matched []string
	for key, values := range h {
		lowerKey := strings.ToLower(strings.TrimSpace(key))
		if isOpenAIPassthroughTimeoutHeader(lowerKey) {
			entry := lowerKey
			if len(values) > 0 {
				entry = fmt.Sprintf("%s=%s", lowerKey, strings.Join(values, "|"))
			}
			matched = append(matched, entry)
		}
	}
	sort.Strings(matched)
	return matched
}

// openaiStreamingResultPassthrough aggregates what the streaming passthrough
// handler learned from the upstream SSE stream.
type openaiStreamingResultPassthrough struct {
	usage        *OpenAIUsage // token usage parsed from SSE data lines
	firstTokenMs *int         // time to first client-visible output; nil if none was emitted
	imageCount   int          // number of image outputs counted in the stream
}

// openaiNonStreamingResultPassthrough aggregates what the non-streaming
// passthrough handler extracted from the upstream response.
// NOTE(review): both constructors in this file set the embedded *OpenAIUsage
// and the named usage field to the same pointer — one of the two looks
// redundant; confirm which one callers rely on before removing either.
type openaiNonStreamingResultPassthrough struct {
	*OpenAIUsage
	usage      *OpenAIUsage
	imageCount int
}

// openAIStreamClientOutputStarted reports whether any bytes have already been
// written to the client, either tracked locally or via the gin writer.
func openAIStreamClientOutputStarted(c *gin.Context, localStarted bool) bool {
	if localStarted {
		return true
	}
	return c != nil && c.Writer != nil && c.Writer.Written()
}

// openAIStreamEventIsPreamble reports whether the SSE event type is a preamble
// event that can safely be buffered before committing output to the client.
func openAIStreamEventIsPreamble(eventType string) bool {
	switch strings.TrimSpace(eventType) {
	case "response.created", "response.in_progress":
		return true
	default:
		return false
	}
}

// openAIStreamDataStartsClientOutput reports whether an SSE data payload
// should start (commit) client output: non-empty, not a failure event, and not
// a preamble event.
func openAIStreamDataStartsClientOutput(data, eventType string) bool {
	trimmed := strings.TrimSpace(data)
	if trimmed == "" {
		return false
	}
	if strings.TrimSpace(eventType) == "response.failed" {
		return false
	}
	return !openAIStreamEventIsPreamble(eventType)
}

// openAIStreamFailedEventShouldFailover decides whether a response.failed SSE
// event warrants trying another account. It inspects the error code/type and
// message; anything matching a non-retryable marker (invalid request, policy
// or safety violations) sticks with the current account, everything else —
// including an empty/unknown error — fails over.
func openAIStreamFailedEventShouldFailover(payload []byte, message string) bool {
	code := strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "response.error.code").String()))
	if code == "" {
		code = strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "error.code").String()))
	}
	errType := strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "response.error.type").String()))
	if errType == "" {
		errType = strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "error.type").String()))
	}
	combined := strings.ToLower(strings.TrimSpace(message + " " + code + " " + errType))
	if combined == "" {
		return true
	}
	nonRetryableMarkers := []string{
		"invalid_request",
		"content_policy",
		"policy",
		"safety",
		"high-risk cyber",
		"not allowed",
		"violat",
	}
	for _, marker := range nonRetryableMarkers {
		if strings.Contains(combined, marker) {
			return false
		}
	}
	return true
}

// newOpenAIStreamFailoverError builds an UpstreamFailoverError (502) for a
// stream that failed before any client output, recording ops telemetry on c
// when available. The returned error carries a minimal JSON error body.
func (s *OpenAIGatewayService) newOpenAIStreamFailoverError(
	c *gin.Context,
	account *Account,
	passthrough bool,
	upstreamRequestID string,
	payload []byte,
	message string,
) *UpstreamFailoverError {
	message = sanitizeUpstreamErrorMessage(strings.TrimSpace(message))
	if message == "" {
		message = "OpenAI stream disconnected before completion"
	}
	detail := ""
	if len(payload) > 0 && s != nil && s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
		maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
		if maxBytes <= 0 {
			maxBytes = 2048
		}
		detail = truncateString(string(payload), maxBytes)
	}
	if c != nil {
		setOpsUpstreamError(c, http.StatusBadGateway, message, detail)
		event := OpsUpstreamErrorEvent{
			Platform:           PlatformOpenAI,
			UpstreamStatusCode: http.StatusBadGateway,
			UpstreamRequestID:  strings.TrimSpace(upstreamRequestID),
			Passthrough:        passthrough,
			Kind:               "failover",
			Message:            message,
			Detail:             detail,
		}
		if account != nil {
			event.Platform = account.Platform
			event.AccountID = account.ID
			event.AccountName = account.Name
		}
		appendOpsUpstreamError(c, event)
	}
	body, _ := json.Marshal(gin.H{
		"error": gin.H{
			"type":    "upstream_error",
			"message": message,
		},
	})
	return &UpstreamFailoverError{
		StatusCode:   http.StatusBadGateway,
		ResponseBody: body,
	}
}

// handleStreamingResponsePassthrough relays the upstream SSE stream to the
// client while tracking usage, first-token latency and image outputs.
// Preamble events are buffered in pendingLines so that an early
// response.failed can still trigger account failover before any bytes reach
// the client; once output has started, failures surface as errors instead.
// If the client disconnects mid-stream the upstream is still drained so usage
// is recorded. When the mapped upstream model differs from the requested one,
// model names in SSE lines are rewritten back to the original.
func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
	ctx context.Context,
	resp *http.Response,
	c *gin.Context,
	account *Account,
	startTime time.Time,
	originalModel string,
	mappedModel string,
) (*openaiStreamingResultPassthrough, error) {
	writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)

	// SSE headers
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("X-Accel-Buffering", "no")
	if v := resp.Header.Get("x-request-id"); v != "" {
		c.Header("x-request-id", v)
	}

	w := c.Writer
	flusher, ok := w.(http.Flusher)
	if !ok {
		return nil, errors.New("streaming not supported")
	}

	usage := &OpenAIUsage{}
	imageCounter := newOpenAIImageOutputCounter()
	var firstTokenMs *int
	clientDisconnected := false
	sawDone := false
	sawTerminalEvent := false
	sawFailedEvent := false
	failedMessage := ""
	clientOutputStarted := false
	upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id"))
	pendingLines := make([]string, 0, 8)
	// writePendingLines flushes buffered preamble lines; returns false (and
	// marks the client disconnected) on the first write failure.
	writePendingLines := func() bool {
		for _, pending := range pendingLines {
			if _, err := fmt.Fprintln(w, pending); err != nil {
				clientDisconnected = true
				logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
				return false
			}
		}
		pendingLines = pendingLines[:0]
		return true
	}

	scanner := bufio.NewScanner(resp.Body)
	maxLineSize := defaultMaxLineSize
	if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
		maxLineSize = s.cfg.Gateway.MaxLineSize
	}
	scanBuf := getSSEScannerBuf64K()
	scanner.Buffer(scanBuf[:0], maxLineSize)
	defer putSSEScannerBuf64K(scanBuf)

	needModelReplace := strings.TrimSpace(originalModel) != "" && strings.TrimSpace(mappedModel) != "" && strings.TrimSpace(originalModel) != strings.TrimSpace(mappedModel)
	resultWithUsage := func() *openaiStreamingResultPassthrough {
		return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs, imageCount: imageCounter.Count()}
	}

	for scanner.Scan() {
		line := scanner.Text()
		lineStartsClientOutput := false
		forceFlushFailedEvent := false
		if data, ok := extractOpenAISSEDataLine(line); ok {
			dataBytes := []byte(data)
			trimmedData := strings.TrimSpace(data)
			if needModelReplace && strings.Contains(data, mappedModel) {
				line = s.replaceModelInSSELine(line, mappedModel, originalModel)
				if replacedData, replaced := extractOpenAISSEDataLine(line); replaced {
					dataBytes = []byte(replacedData)
					trimmedData = strings.TrimSpace(replacedData)
				}
			}
			eventType := strings.TrimSpace(gjson.Get(trimmedData, "type").String())
			if eventType == "response.failed" {
				failedMessage = extractOpenAISSEErrorMessage(dataBytes)
				// Before any client output, a retryable failure becomes a failover.
				if !openAIStreamClientOutputStarted(c, clientOutputStarted) && openAIStreamFailedEventShouldFailover(dataBytes, failedMessage) {
					return resultWithUsage(),
						s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, dataBytes, failedMessage)
				}
				forceFlushFailedEvent = true
				sawFailedEvent = true
			}
			if trimmedData == "[DONE]" {
				sawDone = true
			}
			if openAIStreamEventIsTerminal(trimmedData) {
				sawTerminalEvent = true
			}
			imageCounter.AddSSEData(dataBytes)
			lineStartsClientOutput = forceFlushFailedEvent || openAIStreamDataStartsClientOutput(trimmedData, eventType)
			if firstTokenMs == nil && lineStartsClientOutput && trimmedData != "[DONE]" {
				ms := int(time.Since(startTime).Milliseconds())
				firstTokenMs = &ms
			}
			s.parseSSEUsageBytes(dataBytes, usage)
		}

		if !clientDisconnected {
			// Buffer preamble lines until the first committing line arrives.
			if !clientOutputStarted && !lineStartsClientOutput {
				pendingLines = append(pendingLines, line)
				continue
			}
			if !clientOutputStarted && len(pendingLines) > 0 {
				if !writePendingLines() {
					continue
				}
			}
			if _, err := fmt.Fprintln(w, line); err != nil {
				clientDisconnected = true
				logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID)
			} else {
				clientOutputStarted = true
				flusher.Flush()
			}
		}
	}
	if err := scanner.Err(); err != nil {
		if sawTerminalEvent && !sawFailedEvent {
			return resultWithUsage(), nil
		}
		if sawFailedEvent {
			return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage)
		}
		if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
			return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err)
		}
		if errors.Is(err, bufio.ErrTooLong) {
			logger.LegacyPrintf("service.openai_gateway",
				"[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err)
			return resultWithUsage(), err
		}
		if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
			msg := "OpenAI stream disconnected before completion"
			if errText := strings.TrimSpace(err.Error()); errText != "" {
				msg += ": " + errText
			}
			return resultWithUsage(),
				s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, msg)
		}
		if clientDisconnected {
			return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", err)
		}
		logger.LegacyPrintf("service.openai_gateway",
			"[OpenAI passthrough] 流读取异常中断: account=%d request_id=%s err=%v",
			account.ID,
			upstreamRequestID,
			err,
		)
		return resultWithUsage(), fmt.Errorf("stream read error: %w", err)
	}
	if sawFailedEvent {
		return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage)
	}
	// Stream ended without [DONE] or a terminal event: treat as a truncated stream.
	if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil {
		logger.FromContext(ctx).With(
			zap.String("component", "service.openai_gateway"),
			zap.Int64("account_id", account.ID),
			zap.String("upstream_request_id", upstreamRequestID),
		).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流")
		if !openAIStreamClientOutputStarted(c, clientOutputStarted) {
			return resultWithUsage(),
				s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, "OpenAI stream ended before a terminal event")
		}
		return resultWithUsage(), errors.New("stream usage incomplete: missing terminal event")
	}

	return resultWithUsage(), nil
}

// handleNonStreamingResponsePassthrough relays a non-streaming upstream
// response to the client, extracting usage (with an SSE-text fallback) and
// mapping the model name back to the client-requested one when a remap was
// applied.
func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
	ctx context.Context,
	resp *http.Response,
	c *gin.Context,
	originalModel string,
	mappedModel string,
) (*openaiNonStreamingResultPassthrough, error) {
	body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError)
	if err != nil {
		return nil, err
	}

	// Detect SSE responses from upstream and
	// convert to JSON.
	// Some upstreams (e.g. other sub2api instances) may return SSE even when
	// stream=false was requested. Without this conversion the client would
	// receive raw SSE text or a terminal event with empty output.
	if isEventStreamResponse(resp.Header) {
		return s.handlePassthroughSSEToJSON(resp, c, body, originalModel, mappedModel)
	}

	usage := &OpenAIUsage{}
	usageParsed := false
	if len(body) > 0 {
		if parsedUsage, ok := extractOpenAIUsageFromJSONBytes(body); ok {
			*usage = parsedUsage
			usageParsed = true
		}
	}
	if !usageParsed {
		// Fallback: try to parse usage from the SSE text.
		usage = s.parseSSEUsageFromBody(string(body))
	}

	writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)

	contentType := resp.Header.Get("Content-Type")
	if contentType == "" {
		contentType = "application/json"
	}
	if originalModel != "" && mappedModel != "" && originalModel != mappedModel {
		body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
	}
	c.Data(resp.StatusCode, contentType, body)
	return &openaiNonStreamingResultPassthrough{
		OpenAIUsage: usage,
		usage:       usage,
		imageCount:  countOpenAIResponseImageOutputsFromJSONBytes(body),
	}, nil
}

// handlePassthroughSSEToJSON converts an SSE response body into a JSON
// response for the passthrough path. It mirrors handleSSEToJSON while
// preserving passthrough payloads, except compact-only model remapping may
// rewrite model fields back to the original requested model.
// When no terminal response can be extracted the raw SSE text is relayed
// (after optional model remapping); a response.failed terminal event becomes
// a protocol error instead.
func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel string, mappedModel string) (*openaiNonStreamingResultPassthrough, error) {
	bodyText := string(body)
	finalResponse, ok := extractCodexFinalResponse(bodyText)

	usage := &OpenAIUsage{}
	if ok {
		if parsedUsage, parsed := extractOpenAIUsageFromJSONBytes(finalResponse); parsed {
			*usage = parsedUsage
		}
		// When the terminal event has an empty output array, reconstruct
		// output from accumulated delta events so the client gets full content.
		if len(gjson.GetBytes(finalResponse, "output").Array()) == 0 {
			if outputJSON, reconstructed := reconstructResponseOutputFromSSE(bodyText); reconstructed {
				if patched, err := sjson.SetRawBytes(finalResponse, "output", outputJSON); err == nil {
					finalResponse = patched
				}
			}
		}
		body = finalResponse
		if originalModel != "" && mappedModel != "" && originalModel != mappedModel {
			body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
		}
		// Correct tool calls in final response
		body = s.correctToolCallsInResponseBody(body)
	} else {
		terminalType, terminalPayload, terminalOK := extractOpenAISSETerminalEvent(bodyText)
		if terminalOK && terminalType == "response.failed" {
			msg := extractOpenAISSEErrorMessage(terminalPayload)
			if msg == "" {
				msg = "Upstream compact response failed"
			}
			return nil, s.writeOpenAINonStreamingProtocolError(resp, c, msg)
		}
		usage = s.parseSSEUsageFromBody(bodyText)
		if originalModel != "" && mappedModel != "" && originalModel != mappedModel {
			bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel)
		}
		body = []byte(bodyText)
	}

	writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter)

	// JSON when a final response was extracted; otherwise fall back to the
	// upstream content type (default: SSE).
	contentType := "application/json; charset=utf-8"
	if !ok {
		contentType = resp.Header.Get("Content-Type")
		if contentType == "" {
			contentType = "text/event-stream"
		}
	}
	c.Data(resp.StatusCode, contentType, body)

	return &openaiNonStreamingResultPassthrough{
		OpenAIUsage: usage,
		usage:       usage,
		imageCount:  countOpenAIImageOutputsFromSSEBody(bodyText),
	}, nil
}

// writeOpenAIPassthroughResponseHeaders copies upstream response headers into
// dst, applying the configured header filter when present, and always forwards
// the x-codex-* rate-limit response headers regardless of the filter.
func writeOpenAIPassthroughResponseHeaders(dst http.Header, src http.Header, filter *responseheaders.CompiledHeaderFilter) {
	if dst == nil || src == nil {
		return
	}
	if filter != nil {
		responseheaders.WriteFilteredHeaders(dst, src, filter)
	} else {
		// Fallback: preserve at least the basic content-type.
		if v := strings.TrimSpace(src.Get("Content-Type")); v != "" {
			dst.Set("Content-Type", v)
		}
	}
	// Passthrough mode always forwards the x-codex-* response headers when the
	// upstream returned them.
	// Note: keys of a real http.Response.Header are normally canonicalized, but
	// to stay compatible with tests / hand-built responses we do a single
	// case-insensitive lookup here via EqualFold.
	getCaseInsensitiveValues := func(h http.Header, want string) []string {
		if h == nil {
			return nil
		}
		for k, vals := range h {
			if strings.EqualFold(k, want) {
				return vals
			}
		}
		return nil
	}

	for _, rawKey := range []string{
		"x-codex-primary-used-percent",
		"x-codex-primary-reset-after-seconds",
		"x-codex-primary-window-minutes",
		"x-codex-secondary-used-percent",
		"x-codex-secondary-reset-after-seconds",
		"x-codex-secondary-window-minutes",
		"x-codex-primary-over-secondary-limit-percent",
	} {
		vals := getCaseInsensitiveValues(src, rawKey)
		if len(vals) == 0 {
			continue
		}
		key := http.CanonicalHeaderKey(rawKey)
		dst.Del(key)
		for _, v := range vals {
			dst.Add(key, v)
		}
	}
}
diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index e12b208e372..b09ba61b359 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -4,16 +4,12 @@ import (
 	"bufio"
 	"bytes"
 	"context"
-	"crypto/sha256"
-	"encoding/hex"
 	"encoding/json"
 	"errors"
 	"fmt"
 	"io"
 	"log/slog"
-	"math/rand"
 	"net/http"
-	"sort"
 	"strconv"
 	"strings"
 	"sync"
@@ -26,7 +22,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" - "github.com/cespare/xxhash/v2" "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/tidwall/gjson" @@ -102,28 +97,6 @@ var codexCLIOnlyDebugHeaderWhitelist = []string{ "X-Real-IP", } -// OpenAICodexUsageSnapshot represents Codex API usage limits from response headers -type OpenAICodexUsageSnapshot struct { - PrimaryUsedPercent *float64 `json:"primary_used_percent,omitempty"` - PrimaryResetAfterSeconds *int `json:"primary_reset_after_seconds,omitempty"` - PrimaryWindowMinutes *int `json:"primary_window_minutes,omitempty"` - SecondaryUsedPercent *float64 `json:"secondary_used_percent,omitempty"` - SecondaryResetAfterSeconds *int `json:"secondary_reset_after_seconds,omitempty"` - SecondaryWindowMinutes *int `json:"secondary_window_minutes,omitempty"` - PrimaryOverSecondaryPercent *float64 `json:"primary_over_secondary_percent,omitempty"` - UpdatedAt string `json:"updated_at,omitempty"` -} - -// NormalizedCodexLimits contains normalized 5h/7d rate limit data -type NormalizedCodexLimits struct { - Used5hPercent *float64 - Reset5hSeconds *int - Window5hMinutes *int - Used7dPercent *float64 - Reset7dSeconds *int - Window7dMinutes *int -} - // Normalize converts primary/secondary fields to canonical 5h/7d fields. // Strategy: Compare window_minutes to determine which is 5h vs 7d. // Returns nil if snapshot is nil or has no useful data. @@ -209,35 +182,6 @@ type OpenAIUsage struct { ImageOutputTokens int `json:"image_output_tokens,omitempty"` } -// OpenAIForwardResult represents the result of forwarding -type OpenAIForwardResult struct { - RequestID string - ResponseID string - Usage OpenAIUsage - Model string // 原始模型(用于响应和日志显示) - // BillingModel is the model used for cost calculation. - // When non-empty, CalculateCost uses this instead of Model. 
- // This is set by the Anthropic Messages conversion path where - // the mapped upstream model differs from the client-facing model. - BillingModel string - // UpstreamModel is the actual model sent to the upstream provider after mapping. - // Empty when no mapping was applied (requested model was used as-is). - UpstreamModel string - // ServiceTier records the OpenAI Responses API service tier, e.g. "priority" / "flex". - // Nil means the request did not specify a recognized tier. - ServiceTier *string - // ReasoningEffort is extracted from request body (reasoning.effort) or derived from model suffix. - // Stored for usage records display; nil means not provided / not applicable. - ReasoningEffort *string - Stream bool - OpenAIWSMode bool - ResponseHeaders http.Header - Duration time.Duration - FirstTokenMs *int - ImageCount int - ImageSize string -} - type OpenAIWSRetryMetricsSnapshot struct { RetryAttemptsTotal int64 `json:"retry_attempts_total"` RetryBackoffMsTotal int64 `json:"retry_backoff_ms_total"` @@ -498,13 +442,6 @@ func (s *OpenAIGatewayService) ReplaceModelInBody(body []byte, newModel string) return ReplaceModelInBody(body, newModel) } -func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle { - if s != nil && s.codexSnapshotThrottle != nil { - return s.codexSnapshotThrottle - } - return defaultOpenAICodexSnapshotPersistThrottle -} - func (s *OpenAIGatewayService) billingDeps() *billingDeps { return &billingDeps{ accountRepo: s.accountRepo, @@ -549,17 +486,6 @@ func (s *OpenAIGatewayService) logOpenAIWSModeBootstrap() { ) } -func (s *OpenAIGatewayService) getCodexClientRestrictionDetector() CodexClientRestrictionDetector { - if s != nil && s.codexDetector != nil { - return s.codexDetector - } - var cfg *config.Config - if s != nil { - cfg = s.cfg - } - return NewOpenAICodexClientRestrictionDetector(cfg) -} - func (s *OpenAIGatewayService) getOpenAIWSProtocolResolver() OpenAIWSProtocolResolver { if s != nil && s.openaiWSResolver 
!= nil { return s.openaiWSResolver @@ -745,79 +671,6 @@ func (s *OpenAIGatewayService) writeOpenAIWSFallbackErrorResponse(c *gin.Context return true } -func (s *OpenAIGatewayService) openAIWSRetryBackoff(attempt int) time.Duration { - if attempt <= 0 { - return 0 - } - - initial := openAIWSRetryBackoffInitialDefault - maxBackoff := openAIWSRetryBackoffMaxDefault - jitterRatio := openAIWSRetryJitterRatioDefault - if s != nil && s.cfg != nil { - wsCfg := s.cfg.Gateway.OpenAIWS - if wsCfg.RetryBackoffInitialMS > 0 { - initial = time.Duration(wsCfg.RetryBackoffInitialMS) * time.Millisecond - } - if wsCfg.RetryBackoffMaxMS > 0 { - maxBackoff = time.Duration(wsCfg.RetryBackoffMaxMS) * time.Millisecond - } - if wsCfg.RetryJitterRatio >= 0 { - jitterRatio = wsCfg.RetryJitterRatio - } - } - if initial <= 0 { - return 0 - } - if maxBackoff <= 0 { - maxBackoff = initial - } - if maxBackoff < initial { - maxBackoff = initial - } - if jitterRatio < 0 { - jitterRatio = 0 - } - if jitterRatio > 1 { - jitterRatio = 1 - } - - shift := attempt - 1 - if shift < 0 { - shift = 0 - } - backoff := initial - if shift > 0 { - backoff = initial * time.Duration(1< maxBackoff { - backoff = maxBackoff - } - if jitterRatio <= 0 { - return backoff - } - jitter := time.Duration(float64(backoff) * jitterRatio) - if jitter <= 0 { - return backoff - } - delta := time.Duration(rand.Int63n(int64(jitter)*2+1)) - jitter - withJitter := backoff + delta - if withJitter < 0 { - return 0 - } - return withJitter -} - -func (s *OpenAIGatewayService) openAIWSRetryTotalBudget() time.Duration { - if s != nil && s.cfg != nil { - ms := s.cfg.Gateway.OpenAIWS.RetryTotalBudgetMS - if ms <= 0 { - return 0 - } - return time.Duration(ms) * time.Millisecond - } - return 0 -} - func (s *OpenAIGatewayService) recordOpenAIWSRetryAttempt(backoff time.Duration) { if s == nil { return @@ -880,10 +733,6 @@ func SnapshotOpenAICompatibilityFallbackMetrics() OpenAICompatibilityFallbackMet } } -func (s *OpenAIGatewayService) 
detectCodexClientRestriction(c *gin.Context, account *Account) CodexClientRestrictionDetectionResult { - return s.getCodexClientRestrictionDetector().Detect(c, account) -} - func getAPIKeyIDFromContext(c *gin.Context) int64 { if c == nil { return 0 @@ -899,338 +748,6 @@ func getAPIKeyIDFromContext(c *gin.Context) int64 { return apiKey.ID } -// isolateOpenAISessionID 将 apiKeyID 混入 session 标识符, -// 确保不同 API Key 的用户即使使用相同的原始 session_id/conversation_id, -// 到达上游的标识符也不同,防止跨用户会话碰撞。 -func isolateOpenAISessionID(apiKeyID int64, raw string) string { - raw = strings.TrimSpace(raw) - if raw == "" { - return "" - } - h := xxhash.New() - _, _ = fmt.Fprintf(h, "k%d:", apiKeyID) - _, _ = h.WriteString(raw) - return fmt.Sprintf("%016x", h.Sum64()) -} - -func logCodexCLIOnlyDetection(ctx context.Context, c *gin.Context, account *Account, apiKeyID int64, result CodexClientRestrictionDetectionResult, body []byte) { - if !result.Enabled { - return - } - if ctx == nil { - ctx = context.Background() - } - accountID := int64(0) - if account != nil { - accountID = account.ID - } - fields := []zap.Field{ - zap.String("component", "service.openai_gateway"), - zap.Int64("account_id", accountID), - zap.Bool("codex_cli_only_enabled", result.Enabled), - zap.Bool("codex_official_client_match", result.Matched), - zap.String("reject_reason", result.Reason), - } - if apiKeyID > 0 { - fields = append(fields, zap.Int64("api_key_id", apiKeyID)) - } - if !result.Matched { - fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, body) - } - log := logger.FromContext(ctx).With(fields...) 
- if result.Matched { - return - } - log.Warn("OpenAI codex_cli_only 拒绝非官方客户端请求") -} - -func appendCodexCLIOnlyRejectedRequestFields(fields []zap.Field, c *gin.Context, body []byte) []zap.Field { - if c == nil || c.Request == nil { - return fields - } - - req := c.Request - requestModel, requestStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body) - fields = append(fields, - zap.String("request_method", strings.TrimSpace(req.Method)), - zap.String("request_path", strings.TrimSpace(req.URL.Path)), - zap.String("request_query", strings.TrimSpace(req.URL.RawQuery)), - zap.String("request_host", strings.TrimSpace(req.Host)), - zap.String("request_client_ip", strings.TrimSpace(c.ClientIP())), - zap.String("request_remote_addr", strings.TrimSpace(req.RemoteAddr)), - zap.String("request_user_agent", strings.TrimSpace(req.Header.Get("User-Agent"))), - zap.String("request_content_type", strings.TrimSpace(req.Header.Get("Content-Type"))), - zap.Int64("request_content_length", req.ContentLength), - zap.Bool("request_stream", requestStream), - ) - if requestModel != "" { - fields = append(fields, zap.String("request_model", requestModel)) - } - if promptCacheKey != "" { - fields = append(fields, zap.String("request_prompt_cache_key_sha256", hashSensitiveValueForLog(promptCacheKey))) - } - - if headers := snapshotCodexCLIOnlyHeaders(req.Header); len(headers) > 0 { - fields = append(fields, zap.Any("request_headers", headers)) - } - fields = append(fields, zap.Int("request_body_size", len(body))) - return fields -} - -func snapshotCodexCLIOnlyHeaders(header http.Header) map[string]string { - if len(header) == 0 { - return nil - } - result := make(map[string]string, len(codexCLIOnlyDebugHeaderWhitelist)) - for _, key := range codexCLIOnlyDebugHeaderWhitelist { - value := strings.TrimSpace(header.Get(key)) - if value == "" { - continue - } - result[strings.ToLower(key)] = truncateString(value, codexCLIOnlyHeaderValueMaxBytes) - } - return result -} - -func 
hashSensitiveValueForLog(raw string) string { - value := strings.TrimSpace(raw) - if value == "" { - return "" - } - sum := sha256.Sum256([]byte(value)) - return hex.EncodeToString(sum[:8]) -} - -func logOpenAIInstructionsRequiredDebug( - ctx context.Context, - c *gin.Context, - account *Account, - upstreamStatusCode int, - upstreamMsg string, - requestBody []byte, - upstreamBody []byte, -) { - msg := strings.TrimSpace(upstreamMsg) - if !isOpenAIInstructionsRequiredError(upstreamStatusCode, msg, upstreamBody) { - return - } - if ctx == nil { - ctx = context.Background() - } - - accountID := int64(0) - accountName := "" - if account != nil { - accountID = account.ID - accountName = strings.TrimSpace(account.Name) - } - - userAgent := "" - originator := "" - if c != nil { - userAgent = strings.TrimSpace(c.GetHeader("User-Agent")) - originator = strings.TrimSpace(c.GetHeader("originator")) - } - - fields := []zap.Field{ - zap.String("component", "service.openai_gateway"), - zap.Int64("account_id", accountID), - zap.String("account_name", accountName), - zap.Int("upstream_status_code", upstreamStatusCode), - zap.String("upstream_error_message", msg), - zap.String("request_user_agent", userAgent), - zap.Bool("codex_official_client_match", openai.IsCodexOfficialClientByHeaders(userAgent, originator)), - } - fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, requestBody) - - logger.FromContext(ctx).With(fields...).Warn("OpenAI 上游返回 Instructions are required,已记录请求详情用于排查") -} - -func isOpenAIInstructionsRequiredError(upstreamStatusCode int, upstreamMsg string, upstreamBody []byte) bool { - if upstreamStatusCode != http.StatusBadRequest { - return false - } - - hasInstructionRequired := func(text string) bool { - lower := strings.ToLower(strings.TrimSpace(text)) - if lower == "" { - return false - } - if strings.Contains(lower, "instructions are required") { - return true - } - if strings.Contains(lower, "required parameter: 'instructions'") { - return true - } - if 
strings.Contains(lower, "required parameter: instructions") { - return true - } - if strings.Contains(lower, "missing required parameter") && strings.Contains(lower, "instructions") { - return true - } - return strings.Contains(lower, "instruction") && strings.Contains(lower, "required") - } - - if hasInstructionRequired(upstreamMsg) { - return true - } - if len(upstreamBody) == 0 { - return false - } - - errMsg := gjson.GetBytes(upstreamBody, "error.message").String() - errMsgLower := strings.ToLower(strings.TrimSpace(errMsg)) - errCode := strings.ToLower(strings.TrimSpace(gjson.GetBytes(upstreamBody, "error.code").String())) - errParam := strings.ToLower(strings.TrimSpace(gjson.GetBytes(upstreamBody, "error.param").String())) - errType := strings.ToLower(strings.TrimSpace(gjson.GetBytes(upstreamBody, "error.type").String())) - - if errParam == "instructions" { - return true - } - if hasInstructionRequired(errMsg) { - return true - } - if strings.Contains(errCode, "missing_required_parameter") && strings.Contains(errMsgLower, "instructions") { - return true - } - if strings.Contains(errType, "invalid_request") && strings.Contains(errMsgLower, "instructions") && strings.Contains(errMsgLower, "required") { - return true - } - - return false -} - -func isOpenAITransientProcessingError(upstreamStatusCode int, upstreamMsg string, upstreamBody []byte) bool { - if upstreamStatusCode != http.StatusBadRequest { - return false - } - - match := func(text string) bool { - lower := strings.ToLower(strings.TrimSpace(text)) - if lower == "" { - return false - } - if strings.Contains(lower, "an error occurred while processing your request") { - return true - } - return strings.Contains(lower, "you can retry your request") && - strings.Contains(lower, "help.openai.com") && - strings.Contains(lower, "request id") - } - - if match(upstreamMsg) { - return true - } - if len(upstreamBody) == 0 { - return false - } - if match(gjson.GetBytes(upstreamBody, "error.message").String()) { - 
return true - } - return match(string(upstreamBody)) -} - -// ExtractSessionID extracts the raw session ID from headers or body without hashing. -// Used by ForwardAsAnthropic to pass as prompt_cache_key for upstream cache. -func (s *OpenAIGatewayService) ExtractSessionID(c *gin.Context, body []byte) string { - if c == nil { - return "" - } - sessionID := strings.TrimSpace(c.GetHeader("session_id")) - if sessionID == "" { - sessionID = strings.TrimSpace(c.GetHeader("conversation_id")) - } - if sessionID == "" && len(body) > 0 { - sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) - } - return sessionID -} - -func explicitOpenAISessionID(c *gin.Context, body []byte) string { - if c == nil { - return "" - } - - sessionID := strings.TrimSpace(c.GetHeader("session_id")) - if sessionID == "" { - sessionID = strings.TrimSpace(c.GetHeader("conversation_id")) - } - if sessionID == "" && len(body) > 0 { - sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) - } - return sessionID -} - -// GenerateExplicitSessionHash generates a sticky-session hash only from explicit -// client session signals. It intentionally skips content-derived fallback and is -// used by stateless endpoints such as /v1/images. -func (s *OpenAIGatewayService) GenerateExplicitSessionHash(c *gin.Context, body []byte) string { - sessionID := explicitOpenAISessionID(c, body) - if sessionID == "" { - return "" - } - - currentHash, legacyHash := deriveOpenAISessionHashes(sessionID) - attachOpenAILegacySessionHashToGin(c, legacyHash) - return currentHash -} - -// GenerateSessionHash generates a sticky-session hash for OpenAI requests. -// -// Priority: -// 1. Header: session_id -// 2. Header: conversation_id -// 3. Body: prompt_cache_key (opencode) -// 4. 
Body: content-based fallback (model + system + tools + first user message) -func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte) string { - if c == nil { - return "" - } - - sessionID := explicitOpenAISessionID(c, body) - if sessionID == "" && len(body) > 0 { - sessionID = deriveOpenAIContentSessionSeed(body) - } - if sessionID == "" { - return "" - } - - currentHash, legacyHash := deriveOpenAISessionHashes(sessionID) - attachOpenAILegacySessionHashToGin(c, legacyHash) - return currentHash -} - -// GenerateSessionHashWithFallback 先按常规信号生成会话哈希; -// 当未携带 session_id/conversation_id/prompt_cache_key 时,使用 fallbackSeed 生成稳定哈希。 -// 该方法用于 WS ingress,避免会话信号缺失时发生跨账号漂移。 -func (s *OpenAIGatewayService) GenerateSessionHashWithFallback(c *gin.Context, body []byte, fallbackSeed string) string { - sessionHash := s.GenerateSessionHash(c, body) - if sessionHash != "" { - return sessionHash - } - - seed := strings.TrimSpace(fallbackSeed) - if seed == "" { - return "" - } - - currentHash, legacyHash := deriveOpenAISessionHashes(seed) - attachOpenAILegacySessionHashToGin(c, legacyHash) - return currentHash -} - -func resolveOpenAIUpstreamOriginator(c *gin.Context, isOfficialClient bool) string { - if c != nil { - if originator := strings.TrimSpace(c.GetHeader("originator")); originator != "" { - return originator - } - } - if isOfficialClient { - return "codex_cli_rs" - } - return "opencode" -} - // BindStickySession sets session -> account binding with standard TTL. 
func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error { if sessionHash == "" || accountID <= 0 { @@ -1259,2556 +776,903 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C return s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, false, 0) } -// noAvailableOpenAISelectionError builds the standard "no account available" error -// while preserving the compact-specific error when applicable. -func noAvailableOpenAISelectionError(requestedModel string, compactBlocked bool) error { - if compactBlocked { - return ErrNoAvailableCompactAccounts - } - if requestedModel != "" { - return fmt.Errorf("no available OpenAI accounts supporting model: %s", requestedModel) - } - return errors.New("no available OpenAI accounts") +// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan. +func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) { + return s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs, false) } -// openAICompactSupportTier classifies an OpenAI account by compact capability. -// 0 = explicitly unsupported, 1 = unknown / not yet probed, 2 = explicitly supported. 
-func openAICompactSupportTier(account *Account) int { - if account == nil || !account.IsOpenAI() { - return 0 - } - supported, known := account.OpenAICompactSupportKnown() - if !known { - return 1 - } - if supported { - return 2 +// GetAccessToken gets the access token for an OpenAI account +func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { + switch account.Type { + case AccountTypeOAuth: + // 使用 TokenProvider 获取缓存的 token + if s.openAITokenProvider != nil { + accessToken, err := s.openAITokenProvider.GetAccessToken(ctx, account) + if err != nil { + return "", "", err + } + return accessToken, "oauth", nil + } + // 降级:TokenProvider 未配置时直接从账号读取 + accessToken := account.GetOpenAIAccessToken() + if accessToken == "" { + return "", "", errors.New("access_token not found in credentials") + } + return accessToken, "oauth", nil + case AccountTypeAPIKey: + apiKey := account.GetOpenAIApiKey() + if apiKey == "" { + return "", "", errors.New("api_key not found in credentials") + } + return apiKey, "apikey", nil + default: + return "", "", fmt.Errorf("unsupported account type: %s", account.Type) } - return 0 } -// isOpenAIAccountEligibleForRequest centralises the schedulable / OpenAI / model / -// compact-support checks used during account selection. 
-func isOpenAIAccountEligibleForRequest(account *Account, requestedModel string, requireCompact bool) bool { - if account == nil || !account.IsSchedulable() || !account.IsOpenAI() { - return false - } - if requestedModel != "" && !account.IsModelSupported(requestedModel) { - return false - } - if requireCompact && openAICompactSupportTier(account) == 0 { - return false +func (s *OpenAIGatewayService) shouldFailoverUpstreamError(statusCode int) bool { + switch statusCode { + case 401, 402, 403, 429, 529: + return true + default: + return statusCode >= 500 } - return true } -// prioritizeOpenAICompactAccounts re-orders a slice so that accounts with known -// compact support are tried first, followed by unknown, then explicitly unsupported. -// The relative order within each tier is preserved. -func prioritizeOpenAICompactAccounts(accounts []*Account) []*Account { - if len(accounts) == 0 { - return nil - } - supported := make([]*Account, 0, len(accounts)) - unknown := make([]*Account, 0, len(accounts)) - unsupported := make([]*Account, 0, len(accounts)) - for _, account := range accounts { - switch openAICompactSupportTier(account) { - case 2: - supported = append(supported, account) - case 1: - unknown = append(unknown, account) - default: - unsupported = append(unsupported, account) - } +func (s *OpenAIGatewayService) shouldFailoverOpenAIUpstreamResponse(statusCode int, upstreamMsg string, upstreamBody []byte) bool { + if s.shouldFailoverUpstreamError(statusCode) { + return true } - out := make([]*Account, 0, len(accounts)) - out = append(out, supported...) - out = append(out, unknown...) - out = append(out, unsupported...) - return out + return isOpenAITransientProcessingError(statusCode, upstreamMsg, upstreamBody) } -// resolveOpenAIAccountUpstreamModelForRequest resolves the upstream model that -// would be sent for a given request, honouring compact-only mappings when the -// caller is on the /responses/compact path. 
-func resolveOpenAIAccountUpstreamModelForRequest(account *Account, requestedModel string, requireCompact bool) string { - upstreamModel := resolveOpenAIForwardModel(account, requestedModel, "") - if upstreamModel == "" { - return "" - } - if requireCompact { - return resolveOpenAICompactForwardModel(account, upstreamModel) - } - return upstreamModel +func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) } -func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, stickyAccountID int64) (*Account, error) { - if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { - slog.Warn("channel pricing restriction blocked request", - "group_id", derefGroupID(groupID), - "model", requestedModel) - return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) - } - - // 1. 尝试粘性会话命中 - // Try sticky session hit - if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, requireCompact, stickyAccountID); account != nil { - return account, nil - } +// Forward forwards request to OpenAI API +// Forward 将请求转发到上游 OpenAI API。 +// 流程:访问控制 → 协议选择 → 请求体转换 → 模型映射 → 序列化 → 执行 → 响应处理。 +// 支持 HTTP SSE 和 WebSocket v2 两种传输模式。 +func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*OpenAIForwardResult, error) { + startTime := time.Now() - // 2. 
获取可调度的 OpenAI 账号 - // Get schedulable OpenAI accounts - accounts, err := s.listSchedulableAccounts(ctx, groupID) - if err != nil { - return nil, fmt.Errorf("query accounts failed: %w", err) + // ─── 阶段1: 访问控制(Codex CLI 限制检测)─── + restrictionResult := s.detectCodexClientRestriction(c, account) + apiKeyID := getAPIKeyIDFromContext(c) + logCodexCLIOnlyDetection(ctx, c, account, apiKeyID, restrictionResult, body) + if restrictionResult.Enabled && !restrictionResult.Matched { + c.JSON(http.StatusForbidden, gin.H{ + "error": gin.H{ + "type": "forbidden_error", + "message": "This account only allows Codex official clients", + }, + }) + return nil, errors.New("codex_cli_only restriction: only codex official clients are allowed") } - // 3. 按优先级 + LRU 选择最佳账号 - // Select by priority + LRU - selected, compactBlocked := s.selectBestAccount(ctx, groupID, accounts, requestedModel, excludedIDs, requireCompact) - - if selected == nil { - return nil, noAvailableOpenAISelectionError(requestedModel, compactBlocked) - } + // ─── 阶段2: 请求元数据提取与协议选择 ─── + originalBody := body + reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body) + originalModel := reqModel + compatMessagesBridge := isOpenAICompatMessagesBridgeBody(body) + setOpenAICompatMessagesBridgeContext(c, compatMessagesBridge) - // 4. 
设置粘性会话绑定 - // Set sticky session binding - if sessionHash != "" { - _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, selected.ID, openaiStickySessionTTL) + isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) + wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account) + clientTransport := GetOpenAIClientTransport(c) + // 仅允许 WS 入站请求走 WS 上游,避免出现 HTTP -> WS 协议混用。 + wsDecision = resolveOpenAIWSDecisionByClientTransport(wsDecision, clientTransport) + if c != nil { + c.Set("openai_ws_transport_decision", string(wsDecision.Transport)) + c.Set("openai_ws_transport_reason", wsDecision.Reason) } - - return s.hydrateSelectedAccount(ctx, selected) -} - -// tryStickySessionHit 尝试从粘性会话获取账号。 -// 如果命中且账号可用则返回账号;如果账号不可用则清理会话并返回 nil。 -// -// tryStickySessionHit attempts to get account from sticky session. -// Returns account if hit and usable; clears session and returns nil if account is unavailable. 
-func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool, stickyAccountID int64) *Account { - if sessionHash == "" { - return nil + if wsDecision.Transport == OpenAIUpstreamTransportResponsesWebsocketV2 { + logOpenAIWSModeDebug( + "selected account_id=%d account_type=%s transport=%s reason=%s model=%s stream=%v", + account.ID, + account.Type, + normalizeOpenAIWSLogValue(string(wsDecision.Transport)), + normalizeOpenAIWSLogValue(wsDecision.Reason), + reqModel, + reqStream, + ) } - - accountID := stickyAccountID - if accountID <= 0 { - var err error - accountID, err = s.getStickySessionAccountID(ctx, groupID, sessionHash) - if err != nil || accountID <= 0 { - return nil + // 当前仅支持 WSv2;WSv1 命中时直接返回错误,避免出现“配置可开但行为不确定”。 + if wsDecision.Transport == OpenAIUpstreamTransportResponsesWebsocket { + if c != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": "OpenAI WSv1 is temporarily unsupported. 
Please enable responses_websockets_v2.", + }, + }) } + return nil, errors.New("openai ws v1 is temporarily unsupported; use ws v2") } - - if _, excluded := excludedIDs[accountID]; excluded { - return nil + passthroughEnabled := account.IsOpenAIPassthroughEnabled() + if passthroughEnabled { + // 透传分支只需要轻量提取字段,避免热路径全量 Unmarshal。 + reasoningEffort := extractOpenAIReasoningEffortFromBody(body, reqModel) + return s.forwardOpenAIPassthrough(ctx, c, account, originalBody, reqModel, reasoningEffort, reqStream, startTime) } - account, err := s.getSchedulableAccount(ctx, accountID) + // ─── 阶段3: 请求体解析与字段转换 ─── + reqBody, err := getOpenAIRequestBodyMap(c, body) if err != nil { - return nil + return nil, err } - // 检查账号是否需要清理粘性会话 - // Check if sticky session should be cleared - if shouldClearStickySession(account, requestedModel) { - _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) - return nil + if v, ok := reqBody["model"].(string); ok { + reqModel = v + originalModel = reqModel } - - // 验证账号是否可用于当前请求 - // Verify account is usable for current request - if !isOpenAIAccountEligibleForRequest(account, requestedModel, false) { - return nil + if v, ok := reqBody["stream"].(bool); ok { + reqStream = v } - account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact) - if account == nil { - _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) - return nil + if promptCacheKey == "" { + if v, ok := reqBody["prompt_cache_key"].(string); ok { + promptCacheKey = strings.TrimSpace(v) + } } - if groupID != nil && s.needsUpstreamChannelRestrictionCheck(ctx, groupID) && - s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel, requireCompact) { - _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) - return nil + apiKey := getAPIKeyFromContext(c) + imageGenerationAllowed := GroupAllowsImageGeneration(nil) + if apiKey != nil { + imageGenerationAllowed = GroupAllowsImageGeneration(apiKey.Group) + } + 
codexImageGenerationBridgeEnabled := isCodexCLI && imageGenerationAllowed && s.isCodexImageGenerationBridgeEnabled(ctx, account, apiKey) + if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) && !imageGenerationAllowed { + setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "") + c.JSON(http.StatusForbidden, gin.H{ + "error": gin.H{ + "type": "permission_error", + "message": ImageGenerationPermissionMessage(), + }, + }) + return nil, errors.New("image generation disabled for group") } - // 刷新会话 TTL 并返回账号 - // Refresh session TTL and return account - _ = s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL) - return account -} - -// selectBestAccount 从候选账号中选择最佳账号(优先级 + LRU)。 -// 返回 nil 表示无可用账号。 -// -// selectBestAccount selects the best account from candidates (priority + LRU). -// Returns nil if no available account. The second return reports whether at -// least one candidate was filtered out solely because it lacks compact support -// (only meaningful when requireCompact=true). 
-func (s *OpenAIGatewayService) selectBestAccount(ctx context.Context, groupID *int64, accounts []Account, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool) (*Account, bool) { - var selected *Account - selectedCompactTier := -1 - compactBlocked := false - needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) - - for i := range accounts { - acc := &accounts[i] - - // 跳过被排除的账号 - // Skip excluded accounts - if _, excluded := excludedIDs[acc.ID]; excluded { - continue + // Track if body needs re-serialization + bodyModified := false + // 单字段补丁快速路径:只要整个变更集最终可归约为同一路径的 set/delete,就避免全量 Marshal。 + patchDisabled := false + patchHasOp := false + patchDelete := false + patchPath := "" + var patchValue any + markPatchSet := func(path string, value any) { + if strings.TrimSpace(path) == "" { + patchDisabled = true + return } - - fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false) - if fresh == nil { - continue + if patchDisabled { + return } - fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, false) - if fresh == nil { - continue + if !patchHasOp { + patchHasOp = true + patchDelete = false + patchPath = path + patchValue = value + return } - if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) { - continue + if patchDelete || patchPath != path { + patchDisabled = true + return } - compactTier := 0 - if requireCompact { - compactTier = openAICompactSupportTier(fresh) - if compactTier == 0 { - compactBlocked = true - continue - } + patchValue = value + } + markPatchDelete := func(path string) { + if strings.TrimSpace(path) == "" { + patchDisabled = true + return } - - // 选择优先级最高且最久未使用的账号 - // Select highest priority and least recently used - if selected == nil { - selected = fresh - selectedCompactTier = compactTier - continue + if patchDisabled { + return } - - // compact 模式下高 tier 优先;同 tier 内才比较 priority/LRU。 - if 
requireCompact && compactTier != selectedCompactTier { - if compactTier > selectedCompactTier { - selected = fresh - selectedCompactTier = compactTier - } - continue + if !patchHasOp { + patchHasOp = true + patchDelete = true + patchPath = path + return } - - if s.isBetterAccount(fresh, selected) { - selected = fresh - selectedCompactTier = compactTier + if !patchDelete || patchPath != path { + patchDisabled = true } } - - return selected, compactBlocked -} - -// isBetterAccount 判断 candidate 是否比 current 更优。 -// 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先,其次是最久未使用的。 -// -// isBetterAccount checks if candidate is better than current. -// Rules: higher priority (lower value) wins; same priority: never used > least recently used. -func (s *OpenAIGatewayService) isBetterAccount(candidate, current *Account) bool { - // 优先级更高(数值更小) - // Higher priority (lower value) - if candidate.Priority < current.Priority { - return true - } - if candidate.Priority > current.Priority { - return false + disablePatch := func() { + patchDisabled = true } - // 同优先级,比较最后使用时间 - // Same priority, compare last used time - switch { - case candidate.LastUsedAt == nil && current.LastUsedAt != nil: - // candidate 从未使用,优先 - return true - case candidate.LastUsedAt != nil && current.LastUsedAt == nil: - // current 从未使用,保持 - return false - case candidate.LastUsedAt == nil && current.LastUsedAt == nil: - // 都未使用,保持 - return false - default: - // 都使用过,选择最久未使用的 - return candidate.LastUsedAt.Before(*current.LastUsedAt) + // 非透传模式下,instructions 为空时注入默认指令。 + if isInstructionsEmpty(reqBody) && !compatMessagesBridge { + reqBody["instructions"] = "You are a helpful coding assistant." + bodyModified = true + markPatchSet("instructions", "You are a helpful coding assistant.") } -} -// SelectAccountWithLoadAwareness selects an account with load-awareness and wait plan. 
-func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) { - return s.selectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs, false) -} + if codexImageGenerationBridgeEnabled && ensureOpenAIResponsesImageGenerationTool(reqBody) { + bodyModified = true + disablePatch() + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Injected /responses image_generation tool for Codex client") + } -func (s *OpenAIGatewayService) selectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, requireCompact bool) (*AccountSelectionResult, error) { - if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { - slog.Warn("channel pricing restriction blocked request", - "group_id", derefGroupID(groupID), - "model", requestedModel) - return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) + if normalizeOpenAIResponsesImageGenerationTools(reqBody) { + bodyModified = true + disablePatch() + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized /responses image_generation tool payload") + } + if codexImageGenerationBridgeEnabled && applyCodexImageGenerationBridgeInstructions(reqBody) { + bodyModified = true + disablePatch() + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Added Codex image_generation bridge instructions") } - cfg := s.schedulingConfig() - needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) - var stickyAccountID int64 - if sessionHash != "" && s.cache != nil { - if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil { - stickyAccountID = accountID - } + // ─── 阶段4: 模型映射与 OAuth 转换 ─── + billingModel := account.GetMappedModel(reqModel) + if billingModel != 
reqModel { + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, billingModel, account.Name, isCodexCLI) + reqBody["model"] = billingModel + bodyModified = true + markPatchSet("model", billingModel) } - if s.concurrencyService == nil || !cfg.LoadBatchEnabled { - account, err := s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, requireCompact, stickyAccountID) - if err != nil { - return nil, err - } - result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) - if err == nil && result.Acquired { - return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil) - } - if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { - waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) - if waitingCount < cfg.StickySessionMaxWaiting { - return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }) - } + upstreamModel := billingModel + if imageGenerationAllowed && normalizeOpenAIResponsesImageOnlyModel(reqBody) { + bodyModified = true + disablePatch() + if model, ok := reqBody["model"].(string); ok { + upstreamModel = strings.TrimSpace(model) } - return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }) + logger.LegacyPrintf( + "service.openai_gateway", + "[OpenAI] Normalized /responses image-only model request inbound_model=%s image_model=%s upstream_model=%s", + reqModel, + billingModel, + upstreamModel, + ) } - - accounts, err := s.listSchedulableAccounts(ctx, groupID) - if err != nil { + if err := validateOpenAIResponsesImageModel(reqBody, upstreamModel); 
err != nil { + setOpsUpstreamError(c, http.StatusBadRequest, err.Error(), "") + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": err.Error(), + "param": "model", + }, + }) return nil, err } - if len(accounts) == 0 { - return nil, ErrNoAvailableAccounts + if hasOpenAIImageGenerationTool(reqBody) { + logger.LegacyPrintf( + "service.openai_gateway", + "[OpenAI] /responses image_generation request inbound_model=%s mapped_model=%s account_type=%s", + reqModel, + upstreamModel, + account.Type, + ) } - - isExcluded := func(accountID int64) bool { - if excludedIDs == nil { - return false - } - _, excluded := excludedIDs[accountID] - return excluded + if err := validateCodexSparkInput(reqBody, upstreamModel); err != nil { + setOpsUpstreamError(c, http.StatusBadRequest, err.Error(), "") + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": err.Error(), + "param": "input", + }, + }) + return nil, err } - // ============ Layer 1: Sticky session ============ - if sessionHash != "" { - accountID := stickyAccountID - if accountID > 0 && !isExcluded(accountID) { - account, err := s.getSchedulableAccount(ctx, accountID) - if err == nil { - clearSticky := shouldClearStickySession(account, requestedModel) - if clearSticky { - _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) - } - if !clearSticky && isOpenAIAccountEligibleForRequest(account, requestedModel, false) { - account = s.recheckSelectedOpenAIAccountFromDB(ctx, account, requestedModel, requireCompact) - if account == nil { - _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) - } else if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel, requireCompact) { - _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) - } else { - result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) - if err == nil && result.Acquired { - _ = 
s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL) - return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil) - } - - waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) - if waitingCount < cfg.StickySessionMaxWaiting { - return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ - AccountID: accountID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }) - } - } - } - } + // Compact-only model 映射:仅在 /responses/compact 路径生效,且优先级高于 + // OAuth 模型规范化(避免 OAuth 规范化覆盖 compact-only 自定义模型)。 + isCompactRequest := isOpenAIResponsesCompactPath(c) + compactMapped := false + if isCompactRequest { + compactMappedModel := resolveOpenAICompactForwardModel(account, billingModel) + if compactMappedModel != "" && compactMappedModel != billingModel { + compactMapped = true + upstreamModel = compactMappedModel + reqBody["model"] = compactMappedModel + bodyModified = true + markPatchSet("model", compactMappedModel) + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Compact model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", billingModel, compactMappedModel, account.Name, isCodexCLI) } } - // ============ Layer 2: Load-aware selection ============ - baseCandidateCount := 0 - candidates := make([]*Account, 0, len(accounts)) - for i := range accounts { - acc := &accounts[i] - if isExcluded(acc.ID) { - continue - } - // Scheduler snapshots can be temporarily stale (bucket rebuild is throttled); - // re-check schedulability here so recently rate-limited/overloaded accounts - // are not selected again before the bucket is rebuilt. 
- if !acc.IsSchedulable() { - continue - } - if requestedModel != "" && !acc.IsModelSupported(requestedModel) { - continue + // OpenAI OAuth 账号走 ChatGPT internal Codex endpoint,需要将模型名规范化为 + // 上游可识别的 Codex/GPT 系列。API Key 账号则应保留原始/映射后的模型名, + // 以兼容自定义 base_url 的 OpenAI-compatible 上游。 + if model, ok := reqBody["model"].(string); ok { + if !compactMapped { + upstreamModel = normalizeOpenAIModelForUpstream(account, model) + if upstreamModel != "" && upstreamModel != model { + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Upstream model resolved: %s -> %s (account: %s, type: %s, isCodexCLI: %v)", + model, upstreamModel, account.Name, account.Type, isCodexCLI) + reqBody["model"] = upstreamModel + bodyModified = true + markPatchSet("model", upstreamModel) + } } - if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel, requireCompact) { - continue + + // 移除 gpt-5.2-codex 以下的版本 verbosity 参数 + // 确保高版本模型向低版本模型映射不报错 + if !SupportsVerbosity(upstreamModel) { + if text, ok := reqBody["text"].(map[string]any); ok { + delete(text, "verbosity") + } } - baseCandidateCount++ - candidates = append(candidates, acc) } - if len(candidates) == 0 { - return nil, ErrNoAvailableAccounts + // 规范化 reasoning.effort 参数(minimal -> none),与上游允许值对齐。 + if reasoning, ok := reqBody["reasoning"].(map[string]any); ok { + if effort, ok := reasoning["effort"].(string); ok && effort == "minimal" { + reasoning["effort"] = "none" + bodyModified = true + markPatchSet("reasoning.effort", "none") + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized reasoning.effort: minimal -> none (account: %s)", account.Name) + } } - accountLoads := make([]AccountWithConcurrency, 0, len(candidates)) - for _, acc := range candidates { - accountLoads = append(accountLoads, AccountWithConcurrency{ - ID: acc.ID, - MaxConcurrency: acc.EffectiveLoadFactor(), - }) + if account.Type == AccountTypeOAuth { + codexResult := codexTransformResult{} + if 
compatMessagesBridge { + codexResult = applyCodexOAuthTransformWithOptions(reqBody, codexOAuthTransformOptions{ + IsCodexCLI: isCodexCLI, + IsCompact: isCompactRequest, + SkipDefaultInstructions: true, + PreserveToolCallIDs: true, + }) + ensureCodexOAuthInstructionsField(reqBody) + bodyModified = true + disablePatch() + } else { + codexResult = applyCodexOAuthTransform(reqBody, isCodexCLI, isCompactRequest) + } + if codexResult.Modified { + bodyModified = true + disablePatch() + } + if codexResult.NormalizedModel != "" { + upstreamModel = codexResult.NormalizedModel + } + if codexResult.PromptCacheKey != "" { + promptCacheKey = codexResult.PromptCacheKey + } } - loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) - if err != nil { - ordered := append([]*Account(nil), candidates...) - sortAccountsByPriorityAndLastUsed(ordered, false) - if requireCompact { - ordered = prioritizeOpenAICompactAccounts(ordered) - } - for _, acc := range ordered { - fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false) - if fresh == nil { - continue - } - fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact) - if fresh == nil { - continue - } - if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) { - continue - } - result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) - if err == nil && result.Acquired { - if sessionHash != "" { - _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL) + // Handle max_output_tokens based on platform and account type + if !isCodexCLI { + if maxOutputTokens, hasMaxOutputTokens := reqBody["max_output_tokens"]; hasMaxOutputTokens { + switch account.Platform { + case PlatformOpenAI: + // For OpenAI API Key, remove max_output_tokens (not supported) + // For OpenAI OAuth (Responses API), keep it (supported) + if account.Type == AccountTypeAPIKey { + 
delete(reqBody, "max_output_tokens") + bodyModified = true + markPatchDelete("max_output_tokens") + } + case PlatformAnthropic: + // For Anthropic (Claude), convert to max_tokens + delete(reqBody, "max_output_tokens") + markPatchDelete("max_output_tokens") + if _, hasMaxTokens := reqBody["max_tokens"]; !hasMaxTokens { + reqBody["max_tokens"] = maxOutputTokens + disablePatch() } - return s.newSelectionResult(ctx, fresh, true, result.ReleaseFunc, nil) + bodyModified = true + case PlatformGemini: + // For Gemini, remove (will be handled by Gemini-specific transform) + delete(reqBody, "max_output_tokens") + bodyModified = true + markPatchDelete("max_output_tokens") + default: + // For unknown platforms, remove to be safe + delete(reqBody, "max_output_tokens") + bodyModified = true + markPatchDelete("max_output_tokens") } } - } else { - var available []accountWithLoad - for _, acc := range candidates { - loadInfo := loadMap[acc.ID] - if loadInfo == nil { - loadInfo = &AccountLoadInfo{AccountID: acc.ID} - } - if loadInfo.LoadRate < 100 { - available = append(available, accountWithLoad{ - account: acc, - loadInfo: loadInfo, - }) + + // Also handle max_completion_tokens (similar logic) + if _, hasMaxCompletionTokens := reqBody["max_completion_tokens"]; hasMaxCompletionTokens { + if account.Type == AccountTypeAPIKey || account.Platform != PlatformOpenAI { + delete(reqBody, "max_completion_tokens") + bodyModified = true + markPatchDelete("max_completion_tokens") } } - if len(available) > 0 { - sort.SliceStable(available, func(i, j int) bool { - a, b := available[i], available[j] - if a.account.Priority != b.account.Priority { - return a.account.Priority < b.account.Priority - } - if a.loadInfo.LoadRate != b.loadInfo.LoadRate { - return a.loadInfo.LoadRate < b.loadInfo.LoadRate - } - switch { - case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: - return true - case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil: - return false - case 
a.account.LastUsedAt == nil && b.account.LastUsedAt == nil: - return false - default: - return a.account.LastUsedAt.Before(*b.account.LastUsedAt) - } - }) - shuffleWithinSortGroups(available) - - selectionOrder := make([]accountWithLoad, 0, len(available)) - if requireCompact { - appendTier := func(out []accountWithLoad, tier int) []accountWithLoad { - for _, item := range available { - if openAICompactSupportTier(item.account) == tier { - out = append(out, item) - } - } - return out - } - selectionOrder = appendTier(selectionOrder, 2) - selectionOrder = appendTier(selectionOrder, 1) - // tier 0 候选作为兜底追加:DB recheck 时若发现 cache tier 0 实际 - // 已升级为 1/2(探测刚跑完,cache 尚未刷新),仍可正常命中。 - selectionOrder = appendTier(selectionOrder, 0) - } else { - selectionOrder = append(selectionOrder, available...) + // Remove unsupported fields (not supported by upstream OpenAI API) + unsupportedFields := []string{"prompt_cache_retention", "safety_identifier"} + for _, unsupportedField := range unsupportedFields { + if _, has := reqBody[unsupportedField]; has { + delete(reqBody, unsupportedField) + bodyModified = true + markPatchDelete(unsupportedField) } + } + } - for _, item := range selectionOrder { - fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, item.account, requestedModel, false) - if fresh == nil { - continue - } - fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact) - if fresh == nil { - continue - } - if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) { - continue + // 仅在 WSv2 模式保留 previous_response_id,其他模式(HTTP/WSv1)统一过滤。 + // 注意:该规则同样适用于 Codex CLI 请求,避免 WSv1 向上游透传不支持字段。 + if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { + if _, has := reqBody["previous_response_id"]; has { + delete(reqBody, "previous_response_id") + bodyModified = true + markPatchDelete("previous_response_id") + } + } + + if 
sanitizeEmptyBase64InputImagesInOpenAIRequestBodyMap(reqBody) { + bodyModified = true + disablePatch() + } + + // Apply OpenAI fast policy (参照 Claude BetaPolicy 的 fast-mode 过滤): + // 针对 body 的 service_tier 字段("priority" 即 fast,"flex"),按策略 + // 执行 filter(删除字段)或 block(拒绝请求)。对 gpt-5.5 等模型屏蔽 + // fast 时在此生效。 + // + // 注意: + // 1. 此处统一使用 upstreamModel(已经过 GetMappedModel + + // normalizeOpenAIModelForUpstream + Codex OAuth normalize),与 + // chat-completions / messages 入口保持一致,避免不同入口因为模型 + // 维度不同而出现 whitelist 命中差异。 + // 2. action=pass 时也要把 raw "fast" 归一化为 "priority" 写回 body, + // 否则 native /responses 入口透传 "fast" 给上游会被拒。chat- + // completions 入口由 normalizeResponsesBodyServiceTier 完成同一 + // 行为,这里手工实现等效逻辑。 + if rawTier, ok := reqBody["service_tier"].(string); ok { + if normTier := normalizedOpenAIServiceTierValue(rawTier); normTier != "" { + action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, upstreamModel, normTier) + switch action { + case BetaPolicyActionBlock: + msg := errMsg + if msg == "" { + msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, upstreamModel) } - result, err := s.tryAcquireAccountSlot(ctx, fresh.ID, fresh.Concurrency) - if err == nil && result.Acquired { - if sessionHash != "" { - _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, fresh.ID, openaiStickySessionTTL) - } - return s.newSelectionResult(ctx, fresh, true, result.ReleaseFunc, nil) + blocked := &OpenAIFastBlockedError{Message: msg} + writeOpenAIFastPolicyBlockedResponse(c, blocked) + return nil, blocked + case BetaPolicyActionFilter: + delete(reqBody, "service_tier") + bodyModified = true + disablePatch() + default: + // pass:若客户端传的是别名 "fast",归一化为 "priority" + // 后写回 body,确保上游收到的是其能识别的规范值。 + if normTier != rawTier { + reqBody["service_tier"] = normTier + bodyModified = true + markPatchSet("service_tier", normTier) } } } } - // ============ Layer 3: Fallback wait ============ - sortAccountsByPriorityAndLastUsed(candidates, false) - if requireCompact { 
- candidates = prioritizeOpenAICompactAccounts(candidates) + if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) && !imageGenerationAllowed { + setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "") + c.JSON(http.StatusForbidden, gin.H{ + "error": gin.H{ + "type": "permission_error", + "message": ImageGenerationPermissionMessage(), + }, + }) + return nil, errors.New("image generation disabled for group") } - for _, acc := range candidates { - fresh := s.resolveFreshSchedulableOpenAIAccount(ctx, acc, requestedModel, false) - if fresh == nil { - continue - } - fresh = s.recheckSelectedOpenAIAccountFromDB(ctx, fresh, requestedModel, requireCompact) - if fresh == nil { - continue - } - if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, fresh, requestedModel, requireCompact) { - continue + imageBillingModel := "" + imageSizeTier := "" + if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) { + var imageCfgErr error + imageBillingModel, imageSizeTier, imageCfgErr = resolveOpenAIResponsesImageBillingConfig(reqBody, billingModel) + if imageCfgErr != nil { + setOpsUpstreamError(c, http.StatusBadRequest, imageCfgErr.Error(), "") + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": imageCfgErr.Error(), + "param": "size", + }, + }) + return nil, imageCfgErr } - return s.newSelectionResult(ctx, fresh, false, nil, &AccountWaitPlan{ - AccountID: fresh.ID, - MaxConcurrency: fresh.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }) } - if requireCompact && baseCandidateCount > 0 { - return nil, ErrNoAvailableCompactAccounts + // Re-serialize body only if modified + if bodyModified { + serializedByPatch := false + if !patchDisabled && patchHasOp { + var patchErr error + if patchDelete { + body, patchErr = sjson.DeleteBytes(body, patchPath) + } else { + body, patchErr = sjson.SetBytes(body, 
patchPath, patchValue) + } + if patchErr == nil { + serializedByPatch = true + } + } + if !serializedByPatch { + var marshalErr error + body, marshalErr = json.Marshal(reqBody) + if marshalErr != nil { + return nil, fmt.Errorf("serialize request body: %w", marshalErr) + } + } } - return nil, ErrNoAvailableAccounts -} -func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) { - if s.schedulerSnapshot != nil { - accounts, _, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, PlatformOpenAI, false) - return accounts, err - } - var accounts []Account - var err error - if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { - accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI) - } else if groupID != nil { - accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI) - } else { - accounts, err = s.accountRepo.ListSchedulableUngroupedByPlatform(ctx, PlatformOpenAI) - } + // ─── 阶段5: 凭证获取与协议执行 ─── + token, _, err := s.GetAccessToken(ctx, account) if err != nil { - return nil, fmt.Errorf("query accounts failed: %w", err) + return nil, err } - return accounts, nil -} -func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) { - if s.concurrencyService == nil { - return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil - } - return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) -} - -func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context.Context, account *Account, requestedModel string, requireCompact bool) *Account { - if account == nil { - return nil - } - - fresh := account - if s.schedulerSnapshot != nil { - current, err := s.getSchedulableAccount(ctx, account.ID) - if err != nil || current == nil { - return nil - } - fresh = current - } - - if !isOpenAIAccountEligibleForRequest(fresh, requestedModel, 
requireCompact) { - return nil - } - return fresh -} - -func (s *OpenAIGatewayService) recheckSelectedOpenAIAccountFromDB(ctx context.Context, account *Account, requestedModel string, requireCompact bool) *Account { - if account == nil { - return nil - } - if s.schedulerSnapshot == nil || s.accountRepo == nil { - if !isOpenAIAccountEligibleForRequest(account, requestedModel, requireCompact) { - return nil - } - return account - } - - latest, err := s.accountRepo.GetByID(ctx, account.ID) - if err != nil || latest == nil { - return nil - } - if !isOpenAIAccountEligibleForRequest(latest, requestedModel, requireCompact) { - return nil - } - return latest -} - -func (s *OpenAIGatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) { - var ( - account *Account - err error - ) - if s.schedulerSnapshot != nil { - account, err = s.schedulerSnapshot.GetAccount(ctx, accountID) - } else { - account, err = s.accountRepo.GetByID(ctx, accountID) - } - if err != nil || account == nil { - return account, err - } - return account, nil -} - -func (s *OpenAIGatewayService) hydrateSelectedAccount(ctx context.Context, account *Account) (*Account, error) { - if account == nil || s.schedulerSnapshot == nil { - return account, nil - } - hydrated, err := s.schedulerSnapshot.GetAccount(ctx, account.ID) - if err != nil { - return nil, err - } - if hydrated == nil { - return nil, fmt.Errorf("selected openai account %d not found during hydration", account.ID) - } - return hydrated, nil -} - -func (s *OpenAIGatewayService) newSelectionResult(ctx context.Context, account *Account, acquired bool, release func(), waitPlan *AccountWaitPlan) (*AccountSelectionResult, error) { - hydrated, err := s.hydrateSelectedAccount(ctx, account) - if err != nil { - return nil, err - } - return &AccountSelectionResult{ - Account: hydrated, - Acquired: acquired, - ReleaseFunc: release, - WaitPlan: waitPlan, - }, nil -} - -func (s *OpenAIGatewayService) schedulingConfig() 
config.GatewaySchedulingConfig { - if s.cfg != nil { - return s.cfg.Gateway.Scheduling - } - return config.GatewaySchedulingConfig{ - StickySessionMaxWaiting: 3, - StickySessionWaitTimeout: 45 * time.Second, - FallbackWaitTimeout: 30 * time.Second, - FallbackMaxWaiting: 100, - LoadBatchEnabled: true, - SlotCleanupInterval: 30 * time.Second, - } -} + // Capture upstream request body for ops retry of this attempt. + setOpsUpstreamRequestBody(c, body) -// GetAccessToken gets the access token for an OpenAI account -func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { - switch account.Type { - case AccountTypeOAuth: - // 使用 TokenProvider 获取缓存的 token - if s.openAITokenProvider != nil { - accessToken, err := s.openAITokenProvider.GetAccessToken(ctx, account) - if err != nil { - return "", "", err + // 命中 WS 时仅走 WebSocket Mode;不再自动回退 HTTP。 + if wsDecision.Transport == OpenAIUpstreamTransportResponsesWebsocketV2 { + wsReqBody := reqBody + if len(reqBody) > 0 { + wsReqBody = make(map[string]any, len(reqBody)) + for k, v := range reqBody { + wsReqBody[k] = v } - return accessToken, "oauth", nil - } - // 降级:TokenProvider 未配置时直接从账号读取 - accessToken := account.GetOpenAIAccessToken() - if accessToken == "" { - return "", "", errors.New("access_token not found in credentials") } - return accessToken, "oauth", nil - case AccountTypeAPIKey: - apiKey := account.GetOpenAIApiKey() - if apiKey == "" { - return "", "", errors.New("api_key not found in credentials") - } - return apiKey, "apikey", nil - default: - return "", "", fmt.Errorf("unsupported account type: %s", account.Type) - } -} - -func (s *OpenAIGatewayService) shouldFailoverUpstreamError(statusCode int) bool { - switch statusCode { - case 401, 402, 403, 429, 529: - return true - default: - return statusCode >= 500 - } -} - -func (s *OpenAIGatewayService) shouldFailoverOpenAIUpstreamResponse(statusCode int, upstreamMsg string, upstreamBody []byte) bool { - if 
s.shouldFailoverUpstreamError(statusCode) { - return true - } - return isOpenAITransientProcessingError(statusCode, upstreamMsg, upstreamBody) -} - -func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) -} - -// Forward forwards request to OpenAI API -func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*OpenAIForwardResult, error) { - startTime := time.Now() - - restrictionResult := s.detectCodexClientRestriction(c, account) - apiKeyID := getAPIKeyIDFromContext(c) - logCodexCLIOnlyDetection(ctx, c, account, apiKeyID, restrictionResult, body) - if restrictionResult.Enabled && !restrictionResult.Matched { - c.JSON(http.StatusForbidden, gin.H{ - "error": gin.H{ - "type": "forbidden_error", - "message": "This account only allows Codex official clients", - }, - }) - return nil, errors.New("codex_cli_only restriction: only codex official clients are allowed") - } - - originalBody := body - reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body) - originalModel := reqModel - compatMessagesBridge := isOpenAICompatMessagesBridgeBody(body) - setOpenAICompatMessagesBridgeContext(c, compatMessagesBridge) - - isCodexCLI := openai.IsCodexOfficialClientByHeaders(c.GetHeader("User-Agent"), c.GetHeader("originator")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) - wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account) - clientTransport := GetOpenAIClientTransport(c) - // 仅允许 WS 入站请求走 WS 上游,避免出现 HTTP -> WS 协议混用。 - wsDecision = resolveOpenAIWSDecisionByClientTransport(wsDecision, clientTransport) - if c != nil { - c.Set("openai_ws_transport_decision", string(wsDecision.Transport)) - c.Set("openai_ws_transport_reason", wsDecision.Reason) - } - if wsDecision.Transport == 
OpenAIUpstreamTransportResponsesWebsocketV2 { + _, hasPreviousResponseID := wsReqBody["previous_response_id"] logOpenAIWSModeDebug( - "selected account_id=%d account_type=%s transport=%s reason=%s model=%s stream=%v", + "forward_start account_id=%d account_type=%s model=%s stream=%v has_previous_response_id=%v", account.ID, account.Type, - normalizeOpenAIWSLogValue(string(wsDecision.Transport)), - normalizeOpenAIWSLogValue(wsDecision.Reason), - reqModel, + upstreamModel, reqStream, + hasPreviousResponseID, ) - } - // 当前仅支持 WSv2;WSv1 命中时直接返回错误,避免出现“配置可开但行为不确定”。 - if wsDecision.Transport == OpenAIUpstreamTransportResponsesWebsocket { - if c != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "type": "invalid_request_error", - "message": "OpenAI WSv1 is temporarily unsupported. Please enable responses_websockets_v2.", - }, - }) + maxAttempts := openAIWSReconnectRetryLimit + 1 + wsAttempts := 0 + var wsResult *OpenAIForwardResult + var wsErr error + wsLastFailureReason := "" + wsPrevResponseRecoveryTried := false + wsInvalidEncryptedContentRecoveryTried := false + recoverPrevResponseNotFound := func(attempt int) bool { + if wsPrevResponseRecoveryTried { + return false + } + previousResponseID := openAIWSPayloadString(wsReqBody, "previous_response_id") + if previousResponseID == "" { + logOpenAIWSModeInfo( + "reconnect_prev_response_recovery_skip account_id=%d attempt=%d reason=missing_previous_response_id previous_response_id_present=false", + account.ID, + attempt, + ) + return false + } + if HasFunctionCallOutput(wsReqBody) { + logOpenAIWSModeInfo( + "reconnect_prev_response_recovery_skip account_id=%d attempt=%d reason=has_function_call_output previous_response_id_present=true", + account.ID, + attempt, + ) + return false + } + delete(wsReqBody, "previous_response_id") + wsPrevResponseRecoveryTried = true + logOpenAIWSModeInfo( + "reconnect_prev_response_recovery account_id=%d attempt=%d action=drop_previous_response_id retry=1 
previous_response_id=%s previous_response_id_kind=%s", + account.ID, + attempt, + truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(ClassifyOpenAIPreviousResponseIDKind(previousResponseID)), + ) + return true } - return nil, errors.New("openai ws v1 is temporarily unsupported; use ws v2") - } - passthroughEnabled := account.IsOpenAIPassthroughEnabled() - if passthroughEnabled { - // 透传分支只需要轻量提取字段,避免热路径全量 Unmarshal。 - reasoningEffort := extractOpenAIReasoningEffortFromBody(body, reqModel) - return s.forwardOpenAIPassthrough(ctx, c, account, originalBody, reqModel, reasoningEffort, reqStream, startTime) - } - - reqBody, err := getOpenAIRequestBodyMap(c, body) - if err != nil { - return nil, err - } - - if v, ok := reqBody["model"].(string); ok { - reqModel = v - originalModel = reqModel - } - if v, ok := reqBody["stream"].(bool); ok { - reqStream = v - } - if promptCacheKey == "" { - if v, ok := reqBody["prompt_cache_key"].(string); ok { - promptCacheKey = strings.TrimSpace(v) + recoverInvalidEncryptedContent := func(attempt int) bool { + if wsInvalidEncryptedContentRecoveryTried { + return false + } + removedReasoningItems := trimOpenAIEncryptedReasoningItems(wsReqBody) + if !removedReasoningItems { + logOpenAIWSModeInfo( + "reconnect_invalid_encrypted_content_recovery_skip account_id=%d attempt=%d reason=missing_encrypted_reasoning_items", + account.ID, + attempt, + ) + return false + } + previousResponseID := openAIWSPayloadString(wsReqBody, "previous_response_id") + hasFunctionCallOutput := HasFunctionCallOutput(wsReqBody) + if previousResponseID != "" && !hasFunctionCallOutput { + delete(wsReqBody, "previous_response_id") + } + wsInvalidEncryptedContentRecoveryTried = true + logOpenAIWSModeInfo( + "reconnect_invalid_encrypted_content_recovery account_id=%d attempt=%d action=drop_encrypted_reasoning_items retry=1 previous_response_id_present=%v previous_response_id=%s previous_response_id_kind=%s 
has_function_call_output=%v dropped_previous_response_id=%v", + account.ID, + attempt, + previousResponseID != "", + truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(ClassifyOpenAIPreviousResponseIDKind(previousResponseID)), + hasFunctionCallOutput, + previousResponseID != "" && !hasFunctionCallOutput, + ) + return true } - } - apiKey := getAPIKeyFromContext(c) - imageGenerationAllowed := GroupAllowsImageGeneration(nil) - if apiKey != nil { - imageGenerationAllowed = GroupAllowsImageGeneration(apiKey.Group) - } - codexImageGenerationBridgeEnabled := isCodexCLI && imageGenerationAllowed && s.isCodexImageGenerationBridgeEnabled(ctx, account, apiKey) - if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) && !imageGenerationAllowed { - setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "") - c.JSON(http.StatusForbidden, gin.H{ - "error": gin.H{ - "type": "permission_error", - "message": ImageGenerationPermissionMessage(), - }, - }) - return nil, errors.New("image generation disabled for group") - } - - // Track if body needs re-serialization - bodyModified := false - // 单字段补丁快速路径:只要整个变更集最终可归约为同一路径的 set/delete,就避免全量 Marshal。 - patchDisabled := false - patchHasOp := false - patchDelete := false - patchPath := "" - var patchValue any - markPatchSet := func(path string, value any) { - if strings.TrimSpace(path) == "" { - patchDisabled = true - return - } - if patchDisabled { - return - } - if !patchHasOp { - patchHasOp = true - patchDelete = false - patchPath = path - patchValue = value - return - } - if patchDelete || patchPath != path { - patchDisabled = true - return - } - patchValue = value - } - markPatchDelete := func(path string) { - if strings.TrimSpace(path) == "" { - patchDisabled = true - return - } - if patchDisabled { - return - } - if !patchHasOp { - patchHasOp = true - patchDelete = true - patchPath = path - return - } - if !patchDelete || patchPath != path { - 
patchDisabled = true - } - } - disablePatch := func() { - patchDisabled = true - } - - // 非透传模式下,instructions 为空时注入默认指令。 - if isInstructionsEmpty(reqBody) && !compatMessagesBridge { - reqBody["instructions"] = "You are a helpful coding assistant." - bodyModified = true - markPatchSet("instructions", "You are a helpful coding assistant.") - } - - if codexImageGenerationBridgeEnabled && ensureOpenAIResponsesImageGenerationTool(reqBody) { - bodyModified = true - disablePatch() - logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Injected /responses image_generation tool for Codex client") - } - - if normalizeOpenAIResponsesImageGenerationTools(reqBody) { - bodyModified = true - disablePatch() - logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized /responses image_generation tool payload") - } - if codexImageGenerationBridgeEnabled && applyCodexImageGenerationBridgeInstructions(reqBody) { - bodyModified = true - disablePatch() - logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Added Codex image_generation bridge instructions") - } - - // 对所有请求执行模型映射(包含 Codex CLI)。 - billingModel := account.GetMappedModel(reqModel) - if billingModel != reqModel { - logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, billingModel, account.Name, isCodexCLI) - reqBody["model"] = billingModel - bodyModified = true - markPatchSet("model", billingModel) - } - upstreamModel := billingModel - if imageGenerationAllowed && normalizeOpenAIResponsesImageOnlyModel(reqBody) { - bodyModified = true - disablePatch() - if model, ok := reqBody["model"].(string); ok { - upstreamModel = strings.TrimSpace(model) - } - logger.LegacyPrintf( - "service.openai_gateway", - "[OpenAI] Normalized /responses image-only model request inbound_model=%s image_model=%s upstream_model=%s", - reqModel, - billingModel, - upstreamModel, - ) - } - if err := validateOpenAIResponsesImageModel(reqBody, upstreamModel); err != 
nil { - setOpsUpstreamError(c, http.StatusBadRequest, err.Error(), "") - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "type": "invalid_request_error", - "message": err.Error(), - "param": "model", - }, - }) - return nil, err - } - if hasOpenAIImageGenerationTool(reqBody) { - logger.LegacyPrintf( - "service.openai_gateway", - "[OpenAI] /responses image_generation request inbound_model=%s mapped_model=%s account_type=%s", - reqModel, - upstreamModel, - account.Type, - ) - } - if err := validateCodexSparkInput(reqBody, upstreamModel); err != nil { - setOpsUpstreamError(c, http.StatusBadRequest, err.Error(), "") - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "type": "invalid_request_error", - "message": err.Error(), - "param": "input", - }, - }) - return nil, err - } - - // Compact-only model 映射:仅在 /responses/compact 路径生效,且优先级高于 - // OAuth 模型规范化(避免 OAuth 规范化覆盖 compact-only 自定义模型)。 - isCompactRequest := isOpenAIResponsesCompactPath(c) - compactMapped := false - if isCompactRequest { - compactMappedModel := resolveOpenAICompactForwardModel(account, billingModel) - if compactMappedModel != "" && compactMappedModel != billingModel { - compactMapped = true - upstreamModel = compactMappedModel - reqBody["model"] = compactMappedModel - bodyModified = true - markPatchSet("model", compactMappedModel) - logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Compact model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", billingModel, compactMappedModel, account.Name, isCodexCLI) - } - } - - // OpenAI OAuth 账号走 ChatGPT internal Codex endpoint,需要将模型名规范化为 - // 上游可识别的 Codex/GPT 系列。API Key 账号则应保留原始/映射后的模型名, - // 以兼容自定义 base_url 的 OpenAI-compatible 上游。 - if model, ok := reqBody["model"].(string); ok { - if !compactMapped { - upstreamModel = normalizeOpenAIModelForUpstream(account, model) - if upstreamModel != "" && upstreamModel != model { - logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Upstream model resolved: %s -> %s (account: %s, type: 
%s, isCodexCLI: %v)", - model, upstreamModel, account.Name, account.Type, isCodexCLI) - reqBody["model"] = upstreamModel - bodyModified = true - markPatchSet("model", upstreamModel) - } - } - - // 移除 gpt-5.2-codex 以下的版本 verbosity 参数 - // 确保高版本模型向低版本模型映射不报错 - if !SupportsVerbosity(upstreamModel) { - if text, ok := reqBody["text"].(map[string]any); ok { - delete(text, "verbosity") - } - } - } - - // 规范化 reasoning.effort 参数(minimal -> none),与上游允许值对齐。 - if reasoning, ok := reqBody["reasoning"].(map[string]any); ok { - if effort, ok := reasoning["effort"].(string); ok && effort == "minimal" { - reasoning["effort"] = "none" - bodyModified = true - markPatchSet("reasoning.effort", "none") - logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized reasoning.effort: minimal -> none (account: %s)", account.Name) - } - } - - if account.Type == AccountTypeOAuth { - codexResult := codexTransformResult{} - if compatMessagesBridge { - codexResult = applyCodexOAuthTransformWithOptions(reqBody, codexOAuthTransformOptions{ - IsCodexCLI: isCodexCLI, - IsCompact: isCompactRequest, - SkipDefaultInstructions: true, - PreserveToolCallIDs: true, - }) - ensureCodexOAuthInstructionsField(reqBody) - bodyModified = true - disablePatch() - } else { - codexResult = applyCodexOAuthTransform(reqBody, isCodexCLI, isCompactRequest) - } - if codexResult.Modified { - bodyModified = true - disablePatch() - } - if codexResult.NormalizedModel != "" { - upstreamModel = codexResult.NormalizedModel - } - if codexResult.PromptCacheKey != "" { - promptCacheKey = codexResult.PromptCacheKey - } - } - - // Handle max_output_tokens based on platform and account type - if !isCodexCLI { - if maxOutputTokens, hasMaxOutputTokens := reqBody["max_output_tokens"]; hasMaxOutputTokens { - switch account.Platform { - case PlatformOpenAI: - // For OpenAI API Key, remove max_output_tokens (not supported) - // For OpenAI OAuth (Responses API), keep it (supported) - if account.Type == AccountTypeAPIKey { - 
delete(reqBody, "max_output_tokens") - bodyModified = true - markPatchDelete("max_output_tokens") - } - case PlatformAnthropic: - // For Anthropic (Claude), convert to max_tokens - delete(reqBody, "max_output_tokens") - markPatchDelete("max_output_tokens") - if _, hasMaxTokens := reqBody["max_tokens"]; !hasMaxTokens { - reqBody["max_tokens"] = maxOutputTokens - disablePatch() - } - bodyModified = true - case PlatformGemini: - // For Gemini, remove (will be handled by Gemini-specific transform) - delete(reqBody, "max_output_tokens") - bodyModified = true - markPatchDelete("max_output_tokens") - default: - // For unknown platforms, remove to be safe - delete(reqBody, "max_output_tokens") - bodyModified = true - markPatchDelete("max_output_tokens") - } - } - - // Also handle max_completion_tokens (similar logic) - if _, hasMaxCompletionTokens := reqBody["max_completion_tokens"]; hasMaxCompletionTokens { - if account.Type == AccountTypeAPIKey || account.Platform != PlatformOpenAI { - delete(reqBody, "max_completion_tokens") - bodyModified = true - markPatchDelete("max_completion_tokens") - } - } - - // Remove unsupported fields (not supported by upstream OpenAI API) - unsupportedFields := []string{"prompt_cache_retention", "safety_identifier"} - for _, unsupportedField := range unsupportedFields { - if _, has := reqBody[unsupportedField]; has { - delete(reqBody, unsupportedField) - bodyModified = true - markPatchDelete(unsupportedField) - } - } - } - - // 仅在 WSv2 模式保留 previous_response_id,其他模式(HTTP/WSv1)统一过滤。 - // 注意:该规则同样适用于 Codex CLI 请求,避免 WSv1 向上游透传不支持字段。 - if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { - if _, has := reqBody["previous_response_id"]; has { - delete(reqBody, "previous_response_id") - bodyModified = true - markPatchDelete("previous_response_id") - } - } - - if sanitizeEmptyBase64InputImagesInOpenAIRequestBodyMap(reqBody) { - bodyModified = true - disablePatch() - } - - // Apply OpenAI fast policy (参照 Claude BetaPolicy 的 
fast-mode 过滤): - // 针对 body 的 service_tier 字段("priority" 即 fast,"flex"),按策略 - // 执行 filter(删除字段)或 block(拒绝请求)。对 gpt-5.5 等模型屏蔽 - // fast 时在此生效。 - // - // 注意: - // 1. 此处统一使用 upstreamModel(已经过 GetMappedModel + - // normalizeOpenAIModelForUpstream + Codex OAuth normalize),与 - // chat-completions / messages 入口保持一致,避免不同入口因为模型 - // 维度不同而出现 whitelist 命中差异。 - // 2. action=pass 时也要把 raw "fast" 归一化为 "priority" 写回 body, - // 否则 native /responses 入口透传 "fast" 给上游会被拒。chat- - // completions 入口由 normalizeResponsesBodyServiceTier 完成同一 - // 行为,这里手工实现等效逻辑。 - if rawTier, ok := reqBody["service_tier"].(string); ok { - if normTier := normalizedOpenAIServiceTierValue(rawTier); normTier != "" { - action, errMsg := s.evaluateOpenAIFastPolicy(ctx, account, upstreamModel, normTier) - switch action { - case BetaPolicyActionBlock: - msg := errMsg - if msg == "" { - msg = fmt.Sprintf("openai service_tier=%s is not allowed for model %s", normTier, upstreamModel) - } - blocked := &OpenAIFastBlockedError{Message: msg} - writeOpenAIFastPolicyBlockedResponse(c, blocked) - return nil, blocked - case BetaPolicyActionFilter: - delete(reqBody, "service_tier") - bodyModified = true - disablePatch() - default: - // pass:若客户端传的是别名 "fast",归一化为 "priority" - // 后写回 body,确保上游收到的是其能识别的规范值。 - if normTier != rawTier { - reqBody["service_tier"] = normTier - bodyModified = true - markPatchSet("service_tier", normTier) - } - } - } - } - - if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) && !imageGenerationAllowed { - setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "") - c.JSON(http.StatusForbidden, gin.H{ - "error": gin.H{ - "type": "permission_error", - "message": ImageGenerationPermissionMessage(), - }, - }) - return nil, errors.New("image generation disabled for group") - } - imageBillingModel := "" - imageSizeTier := "" - if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) { - var imageCfgErr error - imageBillingModel, imageSizeTier, 
imageCfgErr = resolveOpenAIResponsesImageBillingConfig(reqBody, billingModel) - if imageCfgErr != nil { - setOpsUpstreamError(c, http.StatusBadRequest, imageCfgErr.Error(), "") - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "type": "invalid_request_error", - "message": imageCfgErr.Error(), - "param": "size", - }, - }) - return nil, imageCfgErr - } - } - - // Re-serialize body only if modified - if bodyModified { - serializedByPatch := false - if !patchDisabled && patchHasOp { - var patchErr error - if patchDelete { - body, patchErr = sjson.DeleteBytes(body, patchPath) - } else { - body, patchErr = sjson.SetBytes(body, patchPath, patchValue) - } - if patchErr == nil { - serializedByPatch = true - } - } - if !serializedByPatch { - var marshalErr error - body, marshalErr = json.Marshal(reqBody) - if marshalErr != nil { - return nil, fmt.Errorf("serialize request body: %w", marshalErr) - } - } - } - - // Get access token - token, _, err := s.GetAccessToken(ctx, account) - if err != nil { - return nil, err - } - - // Capture upstream request body for ops retry of this attempt. 
- setOpsUpstreamRequestBody(c, body) - - // 命中 WS 时仅走 WebSocket Mode;不再自动回退 HTTP。 - if wsDecision.Transport == OpenAIUpstreamTransportResponsesWebsocketV2 { - wsReqBody := reqBody - if len(reqBody) > 0 { - wsReqBody = make(map[string]any, len(reqBody)) - for k, v := range reqBody { - wsReqBody[k] = v - } - } - _, hasPreviousResponseID := wsReqBody["previous_response_id"] - logOpenAIWSModeDebug( - "forward_start account_id=%d account_type=%s model=%s stream=%v has_previous_response_id=%v", - account.ID, - account.Type, - upstreamModel, - reqStream, - hasPreviousResponseID, - ) - maxAttempts := openAIWSReconnectRetryLimit + 1 - wsAttempts := 0 - var wsResult *OpenAIForwardResult - var wsErr error - wsLastFailureReason := "" - wsPrevResponseRecoveryTried := false - wsInvalidEncryptedContentRecoveryTried := false - recoverPrevResponseNotFound := func(attempt int) bool { - if wsPrevResponseRecoveryTried { - return false - } - previousResponseID := openAIWSPayloadString(wsReqBody, "previous_response_id") - if previousResponseID == "" { - logOpenAIWSModeInfo( - "reconnect_prev_response_recovery_skip account_id=%d attempt=%d reason=missing_previous_response_id previous_response_id_present=false", - account.ID, - attempt, - ) - return false - } - if HasFunctionCallOutput(wsReqBody) { - logOpenAIWSModeInfo( - "reconnect_prev_response_recovery_skip account_id=%d attempt=%d reason=has_function_call_output previous_response_id_present=true", - account.ID, - attempt, - ) - return false - } - delete(wsReqBody, "previous_response_id") - wsPrevResponseRecoveryTried = true - logOpenAIWSModeInfo( - "reconnect_prev_response_recovery account_id=%d attempt=%d action=drop_previous_response_id retry=1 previous_response_id=%s previous_response_id_kind=%s", - account.ID, - attempt, - truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), - normalizeOpenAIWSLogValue(ClassifyOpenAIPreviousResponseIDKind(previousResponseID)), - ) - return true - } - recoverInvalidEncryptedContent 
:= func(attempt int) bool { - if wsInvalidEncryptedContentRecoveryTried { - return false - } - removedReasoningItems := trimOpenAIEncryptedReasoningItems(wsReqBody) - if !removedReasoningItems { - logOpenAIWSModeInfo( - "reconnect_invalid_encrypted_content_recovery_skip account_id=%d attempt=%d reason=missing_encrypted_reasoning_items", - account.ID, - attempt, - ) - return false - } - previousResponseID := openAIWSPayloadString(wsReqBody, "previous_response_id") - hasFunctionCallOutput := HasFunctionCallOutput(wsReqBody) - if previousResponseID != "" && !hasFunctionCallOutput { - delete(wsReqBody, "previous_response_id") - } - wsInvalidEncryptedContentRecoveryTried = true - logOpenAIWSModeInfo( - "reconnect_invalid_encrypted_content_recovery account_id=%d attempt=%d action=drop_encrypted_reasoning_items retry=1 previous_response_id_present=%v previous_response_id=%s previous_response_id_kind=%s has_function_call_output=%v dropped_previous_response_id=%v", - account.ID, - attempt, - previousResponseID != "", - truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), - normalizeOpenAIWSLogValue(ClassifyOpenAIPreviousResponseIDKind(previousResponseID)), - hasFunctionCallOutput, - previousResponseID != "" && !hasFunctionCallOutput, - ) - return true - } - retryBudget := s.openAIWSRetryTotalBudget() - retryStartedAt := time.Now() - wsRetryLoop: - for attempt := 1; attempt <= maxAttempts; attempt++ { - wsAttempts = attempt - wsResult, wsErr = s.forwardOpenAIWSV2( - ctx, - c, - account, - wsReqBody, - token, - wsDecision, - isCodexCLI, - reqStream, - originalModel, - upstreamModel, - startTime, - attempt, - wsLastFailureReason, - ) - if wsErr == nil { - break - } - if c != nil && c.Writer != nil && c.Writer.Written() { - break - } - - reason, retryable := classifyOpenAIWSReconnectReason(wsErr) - if reason != "" { - wsLastFailureReason = reason - } - // previous_response_not_found 说明续链锚点不可用: - // 对非 function_call_output 场景,允许一次“去掉 previous_response_id 后重放”。 - 
if reason == "previous_response_not_found" && recoverPrevResponseNotFound(attempt) { - continue - } - if reason == "invalid_encrypted_content" && recoverInvalidEncryptedContent(attempt) { - continue - } - if retryable && attempt < maxAttempts { - backoff := s.openAIWSRetryBackoff(attempt) - if retryBudget > 0 && time.Since(retryStartedAt)+backoff > retryBudget { - s.recordOpenAIWSRetryExhausted() - logOpenAIWSModeInfo( - "reconnect_budget_exhausted account_id=%d attempts=%d max_retries=%d reason=%s elapsed_ms=%d budget_ms=%d", - account.ID, - attempt, - openAIWSReconnectRetryLimit, - normalizeOpenAIWSLogValue(reason), - time.Since(retryStartedAt).Milliseconds(), - retryBudget.Milliseconds(), - ) - break - } - s.recordOpenAIWSRetryAttempt(backoff) - logOpenAIWSModeInfo( - "reconnect_retry account_id=%d retry=%d max_retries=%d reason=%s backoff_ms=%d", - account.ID, - attempt, - openAIWSReconnectRetryLimit, - normalizeOpenAIWSLogValue(reason), - backoff.Milliseconds(), - ) - if backoff > 0 { - timer := time.NewTimer(backoff) - select { - case <-ctx.Done(): - if !timer.Stop() { - <-timer.C - } - wsErr = wrapOpenAIWSFallback("retry_backoff_canceled", ctx.Err()) - break wsRetryLoop - case <-timer.C: - } - } - continue - } - if retryable { - s.recordOpenAIWSRetryExhausted() - logOpenAIWSModeInfo( - "reconnect_exhausted account_id=%d attempts=%d max_retries=%d reason=%s", - account.ID, - attempt, - openAIWSReconnectRetryLimit, - normalizeOpenAIWSLogValue(reason), - ) - } else if reason != "" { - s.recordOpenAIWSNonRetryableFastFallback() - logOpenAIWSModeInfo( - "reconnect_stop account_id=%d attempt=%d reason=%s", - account.ID, - attempt, - normalizeOpenAIWSLogValue(reason), - ) - } - break - } - if wsErr == nil { - firstTokenMs := int64(0) - hasFirstTokenMs := wsResult != nil && wsResult.FirstTokenMs != nil - if hasFirstTokenMs { - firstTokenMs = int64(*wsResult.FirstTokenMs) - } - requestID := "" - if wsResult != nil { - requestID = strings.TrimSpace(wsResult.RequestID) 
- } - logOpenAIWSModeDebug( - "forward_succeeded account_id=%d request_id=%s stream=%v has_first_token_ms=%v first_token_ms=%d ws_attempts=%d", - account.ID, - requestID, - reqStream, - hasFirstTokenMs, - firstTokenMs, - wsAttempts, - ) - wsResult.UpstreamModel = upstreamModel - if wsResult.ImageCount > 0 { - wsResult.ImageSize = imageSizeTier - wsResult.BillingModel = imageBillingModel - } - return wsResult, nil - } - s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr) - return nil, wsErr - } - - httpInvalidEncryptedContentRetryTried := false - for { - // Build upstream request - upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) - upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI) - releaseUpstreamCtx() - if err != nil { - return nil, err - } - - // Get proxy URL - proxyURL := "" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - // Send request - upstreamStart := time.Now() - resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) - SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds()) - if err != nil { - // Ensure the client receives an error response (handlers assume Forward writes on non-failover errors). 
- safeErr := sanitizeUpstreamErrorMessage(err.Error()) - setOpsUpstreamError(c, 0, safeErr, "") - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: 0, - Kind: "request_error", - Message: safeErr, - }) - c.JSON(http.StatusBadGateway, gin.H{ - "error": gin.H{ - "type": "upstream_error", - "message": "Upstream request failed", - }, - }) - return nil, fmt.Errorf("upstream request failed: %s", safeErr) - } - - // Handle error response - if resp.StatusCode >= 400 { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - resp.Body = io.NopCloser(bytes.NewReader(respBody)) - - upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - upstreamCode := extractUpstreamErrorCode(respBody) - if !httpInvalidEncryptedContentRetryTried && resp.StatusCode == http.StatusBadRequest && upstreamCode == "invalid_encrypted_content" { - if trimOpenAIEncryptedReasoningItems(reqBody) { - body, err = json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("serialize invalid_encrypted_content retry body: %w", err) - } - setOpsUpstreamRequestBody(c, body) - httpInvalidEncryptedContentRetryTried = true - logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Retrying non-WSv2 request once after invalid_encrypted_content (account: %s)", account.Name) - continue - } - logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Skip non-WSv2 invalid_encrypted_content retry because encrypted reasoning items are missing (account: %s)", account.Name) - } - if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) { - upstreamDetail := "" - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - if maxBytes <= 0 { - maxBytes = 2048 - } - upstreamDetail = truncateString(string(respBody), maxBytes) - } - 
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: resp.Header.Get("x-request-id"), - Kind: "failover", - Message: upstreamMsg, - Detail: upstreamDetail, - }) - - s.handleFailoverSideEffects(ctx, resp, account) - return nil, &UpstreamFailoverError{ - StatusCode: resp.StatusCode, - ResponseBody: respBody, - RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), - } - } - return s.handleErrorResponse(ctx, resp, c, account, body) - } - defer func() { _ = resp.Body.Close() }() - - // Handle normal response - var usage *OpenAIUsage - var firstTokenMs *int - imageCount := 0 - if reqStream { - streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, upstreamModel) - if err != nil { - return nil, err - } - usage = streamResult.usage - firstTokenMs = streamResult.firstTokenMs - imageCount = streamResult.imageCount - } else { - nonStreamResult, err := s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, upstreamModel) - if err != nil { - return nil, err - } - usage = nonStreamResult.usage - imageCount = nonStreamResult.imageCount - } - - // Extract and save Codex usage snapshot from response headers (for OAuth accounts) - if account.Type == AccountTypeOAuth { - if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { - s.updateCodexUsageSnapshot(ctx, account.ID, snapshot) - } - } - - if usage == nil { - usage = &OpenAIUsage{} - } - - reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel) - serviceTier := extractOpenAIServiceTier(reqBody) - - forwardResult := &OpenAIForwardResult{ - RequestID: resp.Header.Get("x-request-id"), - Usage: *usage, - Model: originalModel, - UpstreamModel: upstreamModel, - ServiceTier: serviceTier, - 
ReasoningEffort: reasoningEffort, - Stream: reqStream, - OpenAIWSMode: false, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - } - if imageCount > 0 { - forwardResult.ImageCount = imageCount - forwardResult.ImageSize = imageSizeTier - forwardResult.BillingModel = imageBillingModel - } - return forwardResult, nil - } -} - -func (s *OpenAIGatewayService) forwardOpenAIPassthrough( - ctx context.Context, - c *gin.Context, - account *Account, - body []byte, - reqModel string, - reasoningEffort *string, - reqStream bool, - startTime time.Time, -) (*OpenAIForwardResult, error) { - upstreamPassthroughModel := "" - if isOpenAIResponsesCompactPath(c) { - compactMappedModel := resolveOpenAICompactForwardModel(account, reqModel) - if compactMappedModel != "" && compactMappedModel != reqModel { - nextBody, setErr := sjson.SetBytes(body, "model", compactMappedModel) - if setErr != nil { - return nil, fmt.Errorf("set compact passthrough model: %w", setErr) - } - body = nextBody - upstreamPassthroughModel = compactMappedModel - } - } - - if account != nil && account.Type == AccountTypeOAuth { - if rejectReason := detectOpenAIPassthroughInstructionsRejectReason(reqModel, body); rejectReason != "" { - rejectMsg := "OpenAI codex passthrough requires a non-empty instructions field" - setOpsUpstreamError(c, http.StatusForbidden, rejectMsg, "") - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: http.StatusForbidden, - Passthrough: true, - Kind: "request_error", - Message: rejectMsg, - Detail: rejectReason, - }) - logOpenAIPassthroughInstructionsRejected(ctx, c, account, reqModel, rejectReason, body) - c.JSON(http.StatusForbidden, gin.H{ - "error": gin.H{ - "type": "forbidden_error", - "message": rejectMsg, - }, - }) - return nil, fmt.Errorf("openai passthrough rejected before upstream: %s", rejectReason) - } - - normalizedBody, normalized, err := 
normalizeOpenAIPassthroughOAuthBody(body, isOpenAIResponsesCompactPath(c)) - if err != nil { - return nil, err - } - if normalized { - body = normalizedBody - } - reqStream = gjson.GetBytes(body, "stream").Bool() - } - - sanitizedBody, sanitized, err := sanitizeEmptyBase64InputImagesInOpenAIBody(body) - if err != nil { - return nil, err - } - if sanitized { - body = sanitizedBody - } - - // Apply OpenAI fast policy to the passthrough body (filter/block by service_tier). - // 统一使用 upstream 视角的 model:透传路径下 body 已经过 compact 映射 + - // OAuth normalize,body 中的 model 字段即上游真正会看到的 slug。 - // 这样可以与 chat-completions / messages / native /responses 入口的 - // upstreamModel 保持一致,避免 whitelist 命中差异。当 body 中没有 - // model 字段时退回 reqModel。 - policyModel := strings.TrimSpace(gjson.GetBytes(body, "model").String()) - if policyModel == "" { - policyModel = reqModel - } - updatedBody, policyErr := s.applyOpenAIFastPolicyToBody(ctx, account, policyModel, body) - if policyErr != nil { - var blocked *OpenAIFastBlockedError - if errors.As(policyErr, &blocked) { - writeOpenAIFastPolicyBlockedResponse(c, blocked) - } - return nil, policyErr - } - body = updatedBody - - apiKey := getAPIKeyFromContext(c) - if IsImageGenerationIntent(openAIResponsesEndpoint, reqModel, body) && !GroupAllowsImageGeneration(apiKeyGroup(apiKey)) { - setOpsUpstreamError(c, http.StatusForbidden, ImageGenerationPermissionMessage(), "") - c.JSON(http.StatusForbidden, gin.H{ - "error": gin.H{ - "type": "permission_error", - "message": ImageGenerationPermissionMessage(), - }, - }) - return nil, errors.New("image generation disabled for group") - } - imageBillingModel := "" - imageSizeTier := "" - if IsImageGenerationIntent(openAIResponsesEndpoint, reqModel, body) { - var imageCfgErr error - imageBillingModel, imageSizeTier, imageCfgErr = resolveOpenAIResponsesImageBillingConfigFromBody(body, reqModel) - if imageCfgErr != nil { - setOpsUpstreamError(c, http.StatusBadRequest, imageCfgErr.Error(), "") - 
c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "type": "invalid_request_error", - "message": imageCfgErr.Error(), - "param": "size", - }, - }) - return nil, imageCfgErr - } - } - - logger.LegacyPrintf("service.openai_gateway", - "[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v", - account.ID, - account.Name, - account.Type, - reqModel, - reqStream, - ) - if reqStream && c != nil && c.Request != nil { - if timeoutHeaders := collectOpenAIPassthroughTimeoutHeaders(c.Request.Header); len(timeoutHeaders) > 0 { - streamWarnLogger := logger.FromContext(ctx).With( - zap.String("component", "service.openai_gateway"), - zap.Int64("account_id", account.ID), - zap.Strings("timeout_headers", timeoutHeaders), - ) - if s.isOpenAIPassthroughTimeoutHeadersAllowed() { - streamWarnLogger.Warn("OpenAI passthrough 透传请求包含超时相关请求头,且当前配置为放行,可能导致上游提前断流") - } else { - streamWarnLogger.Warn("OpenAI passthrough 检测到超时相关请求头,将按配置过滤以降低断流风险") - } - } - } - - // Get access token - token, _, err := s.GetAccessToken(ctx, account) - if err != nil { - return nil, err - } - - upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) - upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(upstreamCtx, c, account, body, token) - releaseUpstreamCtx() - if err != nil { - return nil, err - } - - proxyURL := "" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - setOpsUpstreamRequestBody(c, body) - if c != nil { - c.Set("openai_passthrough", true) - } - - upstreamStart := time.Now() - resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) - SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds()) - if err != nil { - safeErr := sanitizeUpstreamErrorMessage(err.Error()) - setOpsUpstreamError(c, 0, safeErr, "") - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: 0, - 
Passthrough: true, - Kind: "request_error", - Message: safeErr, - }) - c.JSON(http.StatusBadGateway, gin.H{ - "error": gin.H{ - "type": "upstream_error", - "message": "Upstream request failed", - }, - }) - return nil, fmt.Errorf("upstream request failed: %s", safeErr) - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode >= 400 { - // 透传模式默认保持原样代理;但 429/529 属于网关必须兜底的 - // 上游容量类错误,应先触发多账号 failover 以维持基础 SLA。 - if shouldFailoverOpenAIPassthroughResponse(resp.StatusCode) { - return nil, s.handleFailoverErrorResponsePassthrough(ctx, resp, c, account, body) - } - return nil, s.handleErrorResponsePassthrough(ctx, resp, c, account, body) - } - - var usage *OpenAIUsage - var firstTokenMs *int - imageCount := 0 - if reqStream { - result, err := s.handleStreamingResponsePassthrough(ctx, resp, c, account, startTime, reqModel, upstreamPassthroughModel) - if err != nil { - return nil, err - } - usage = result.usage - firstTokenMs = result.firstTokenMs - imageCount = result.imageCount - } else { - result, err := s.handleNonStreamingResponsePassthrough(ctx, resp, c, reqModel, upstreamPassthroughModel) - if err != nil { - return nil, err - } - usage = result.usage - imageCount = result.imageCount - } - - if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { - s.updateCodexUsageSnapshot(ctx, account.ID, snapshot) - } - - if usage == nil { - usage = &OpenAIUsage{} - } - - forwardResult := &OpenAIForwardResult{ - RequestID: resp.Header.Get("x-request-id"), - Usage: *usage, - Model: reqModel, - UpstreamModel: upstreamPassthroughModel, - ServiceTier: extractOpenAIServiceTierFromBody(body), - ReasoningEffort: reasoningEffort, - Stream: reqStream, - OpenAIWSMode: false, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - } - if imageCount > 0 { - forwardResult.ImageCount = imageCount - forwardResult.ImageSize = imageSizeTier - forwardResult.BillingModel = imageBillingModel - } - return forwardResult, nil -} - -func 
logOpenAIPassthroughInstructionsRejected( - ctx context.Context, - c *gin.Context, - account *Account, - reqModel string, - rejectReason string, - body []byte, -) { - if ctx == nil { - ctx = context.Background() - } - accountID := int64(0) - accountName := "" - accountType := "" - if account != nil { - accountID = account.ID - accountName = strings.TrimSpace(account.Name) - accountType = strings.TrimSpace(string(account.Type)) - } - fields := []zap.Field{ - zap.String("component", "service.openai_gateway"), - zap.Int64("account_id", accountID), - zap.String("account_name", accountName), - zap.String("account_type", accountType), - zap.String("request_model", strings.TrimSpace(reqModel)), - zap.String("reject_reason", strings.TrimSpace(rejectReason)), - } - fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, body) - logger.FromContext(ctx).With(fields...).Warn("OpenAI passthrough 本地拦截:Codex 请求缺少有效 instructions") -} - -func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough( - ctx context.Context, - c *gin.Context, - account *Account, - body []byte, - token string, -) (*http.Request, error) { - targetURL := openaiPlatformAPIURL - switch account.Type { - case AccountTypeOAuth: - targetURL = chatgptCodexURL - case AccountTypeAPIKey: - baseURL := account.GetOpenAIBaseURL() - if baseURL != "" { - validatedURL, err := s.validateUpstreamBaseURL(baseURL) - if err != nil { - return nil, err - } - targetURL = buildOpenAIResponsesURL(validatedURL) - } - } - targetURL = appendOpenAIResponsesRequestPathSuffix(targetURL, openAIResponsesRequestPathSuffix(c)) - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) - if err != nil { - return nil, err - } - - // 透传客户端请求头(安全白名单)。 - allowTimeoutHeaders := s.isOpenAIPassthroughTimeoutHeadersAllowed() - if c != nil && c.Request != nil { - for key, values := range c.Request.Header { - lower := strings.ToLower(strings.TrimSpace(key)) - if 
!isOpenAIPassthroughAllowedRequestHeader(lower, allowTimeoutHeaders) { - continue - } - for _, v := range values { - req.Header.Add(key, v) - } - } - } - - // 覆盖入站鉴权残留,并注入上游认证 - req.Header.Del("authorization") - req.Header.Del("x-api-key") - req.Header.Del("x-goog-api-key") - req.Header.Set("authorization", "Bearer "+token) - - // OAuth 透传到 ChatGPT internal API 时补齐必要头。 - if account.Type == AccountTypeOAuth { - promptCacheKey := strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) - req.Host = "chatgpt.com" - if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" { - req.Header.Set("chatgpt-account-id", chatgptAccountID) - } - apiKeyID := getAPIKeyIDFromContext(c) - // 先保存客户端原始值,再做 compact 补充,避免后续统一隔离时读到已处理的值。 - clientSessionID := strings.TrimSpace(req.Header.Get("session_id")) - clientConversationID := strings.TrimSpace(req.Header.Get("conversation_id")) - if isOpenAIResponsesCompactPath(c) { - req.Header.Set("accept", "application/json") - if req.Header.Get("version") == "" { - req.Header.Set("version", codexCLIVersion) - } - if clientSessionID == "" { - clientSessionID = resolveOpenAICompactSessionID(c) - } - } else if req.Header.Get("accept") == "" { - req.Header.Set("accept", "text/event-stream") - } - if req.Header.Get("OpenAI-Beta") == "" { - req.Header.Set("OpenAI-Beta", "responses=experimental") - } - if req.Header.Get("originator") == "" { - req.Header.Set("originator", "codex_cli_rs") - } - // 用隔离后的 session 标识符覆盖客户端透传值,防止跨用户会话碰撞。 - if clientSessionID == "" { - clientSessionID = promptCacheKey - } - if clientConversationID == "" { - clientConversationID = promptCacheKey - } - if clientSessionID != "" { - req.Header.Set("session_id", isolateOpenAISessionID(apiKeyID, clientSessionID)) - } - if clientConversationID != "" { - req.Header.Set("conversation_id", isolateOpenAISessionID(apiKeyID, clientConversationID)) - } - } - - // 透传模式也支持账户自定义 User-Agent 与 ForceCodexCLI 兜底。 - customUA := account.GetOpenAIUserAgent() - if 
customUA != "" { - req.Header.Set("user-agent", customUA) - } - if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI { - req.Header.Set("user-agent", codexCLIUserAgent) - } - // OAuth 安全透传:对非 Codex UA 统一兜底,降低被上游风控拦截概率。 - if account.Type == AccountTypeOAuth && !openai.IsCodexCLIRequest(req.Header.Get("user-agent")) { - req.Header.Set("user-agent", codexCLIUserAgent) - } - - if req.Header.Get("content-type") == "" { - req.Header.Set("content-type", "application/json") - } - - return req, nil -} - -func shouldFailoverOpenAIPassthroughResponse(statusCode int) bool { - switch statusCode { - case http.StatusTooManyRequests, 529: - return true - default: - return false - } -} - -func (s *OpenAIGatewayService) handleFailoverErrorResponsePassthrough( - ctx context.Context, - resp *http.Response, - c *gin.Context, - account *Account, - requestBody []byte, -) error { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - - upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - upstreamDetail := "" - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - if maxBytes <= 0 { - maxBytes = 2048 - } - upstreamDetail = truncateString(string(body), maxBytes) - } - setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) - logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body) - if s.rateLimitService != nil { - _ = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) - } - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: resp.Header.Get("x-request-id"), - Passthrough: true, - Kind: "failover", - Message: upstreamMsg, - Detail: upstreamDetail, - UpstreamResponseBody: upstreamDetail, - }) - return &UpstreamFailoverError{ - 
StatusCode: resp.StatusCode, - ResponseBody: body, - ResponseHeaders: resp.Header.Clone(), - } -} - -func (s *OpenAIGatewayService) handleErrorResponsePassthrough( - ctx context.Context, - resp *http.Response, - c *gin.Context, - account *Account, - requestBody []byte, -) error { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - - upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - upstreamDetail := "" - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - if maxBytes <= 0 { - maxBytes = 2048 - } - upstreamDetail = truncateString(string(body), maxBytes) - } - setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) - logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body) - if s.rateLimitService != nil { - // Passthrough mode preserves the raw upstream error response, but runtime - // account state still needs to be updated so sticky routing can stop - // reusing a freshly rate-limited account. 
- _ = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) - } - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: resp.Header.Get("x-request-id"), - Passthrough: true, - Kind: "http_error", - Message: upstreamMsg, - Detail: upstreamDetail, - UpstreamResponseBody: upstreamDetail, - }) - - writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) - contentType := resp.Header.Get("Content-Type") - if contentType == "" { - contentType = "application/json" - } - c.Data(resp.StatusCode, contentType, body) - - if upstreamMsg == "" { - return fmt.Errorf("upstream error: %d", resp.StatusCode) - } - return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) -} - -func isOpenAIPassthroughAllowedRequestHeader(lowerKey string, allowTimeoutHeaders bool) bool { - if lowerKey == "" { - return false - } - if isOpenAIPassthroughTimeoutHeader(lowerKey) { - return allowTimeoutHeaders - } - return openaiPassthroughAllowedHeaders[lowerKey] -} - -func isOpenAIPassthroughTimeoutHeader(lowerKey string) bool { - switch lowerKey { - case "x-stainless-timeout", "x-stainless-read-timeout", "x-stainless-connect-timeout", "x-request-timeout", "request-timeout", "grpc-timeout": - return true - default: - return false - } -} - -func (s *OpenAIGatewayService) isOpenAIPassthroughTimeoutHeadersAllowed() bool { - return s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIPassthroughAllowTimeoutHeaders -} - -func collectOpenAIPassthroughTimeoutHeaders(h http.Header) []string { - if h == nil { - return nil - } - var matched []string - for key, values := range h { - lowerKey := strings.ToLower(strings.TrimSpace(key)) - if isOpenAIPassthroughTimeoutHeader(lowerKey) { - entry := lowerKey - if len(values) > 0 { - entry = fmt.Sprintf("%s=%s", lowerKey, strings.Join(values, "|")) 
- } - matched = append(matched, entry) - } - } - sort.Strings(matched) - return matched -} - -type openaiStreamingResultPassthrough struct { - usage *OpenAIUsage - firstTokenMs *int - imageCount int -} - -type openaiNonStreamingResultPassthrough struct { - *OpenAIUsage - usage *OpenAIUsage - imageCount int -} - -func openAIStreamClientOutputStarted(c *gin.Context, localStarted bool) bool { - if localStarted { - return true - } - return c != nil && c.Writer != nil && c.Writer.Written() -} - -func openAIStreamEventIsPreamble(eventType string) bool { - switch strings.TrimSpace(eventType) { - case "response.created", "response.in_progress": - return true - default: - return false - } -} - -func openAIStreamDataStartsClientOutput(data, eventType string) bool { - trimmed := strings.TrimSpace(data) - if trimmed == "" { - return false - } - if strings.TrimSpace(eventType) == "response.failed" { - return false - } - return !openAIStreamEventIsPreamble(eventType) -} - -func openAIStreamFailedEventShouldFailover(payload []byte, message string) bool { - code := strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "response.error.code").String())) - if code == "" { - code = strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "error.code").String())) - } - errType := strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "response.error.type").String())) - if errType == "" { - errType = strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "error.type").String())) - } - combined := strings.ToLower(strings.TrimSpace(message + " " + code + " " + errType)) - if combined == "" { - return true - } - nonRetryableMarkers := []string{ - "invalid_request", - "content_policy", - "policy", - "safety", - "high-risk cyber", - "not allowed", - "violat", - } - for _, marker := range nonRetryableMarkers { - if strings.Contains(combined, marker) { - return false - } - } - return true -} - -func (s *OpenAIGatewayService) newOpenAIStreamFailoverError( - c *gin.Context, - account 
*Account, - passthrough bool, - upstreamRequestID string, - payload []byte, - message string, -) *UpstreamFailoverError { - message = sanitizeUpstreamErrorMessage(strings.TrimSpace(message)) - if message == "" { - message = "OpenAI stream disconnected before completion" - } - detail := "" - if len(payload) > 0 && s != nil && s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - if maxBytes <= 0 { - maxBytes = 2048 - } - detail = truncateString(string(payload), maxBytes) - } - if c != nil { - setOpsUpstreamError(c, http.StatusBadGateway, message, detail) - event := OpsUpstreamErrorEvent{ - Platform: PlatformOpenAI, - UpstreamStatusCode: http.StatusBadGateway, - UpstreamRequestID: strings.TrimSpace(upstreamRequestID), - Passthrough: passthrough, - Kind: "failover", - Message: message, - Detail: detail, - } - if account != nil { - event.Platform = account.Platform - event.AccountID = account.ID - event.AccountName = account.Name - } - appendOpsUpstreamError(c, event) - } - body, _ := json.Marshal(gin.H{ - "error": gin.H{ - "type": "upstream_error", - "message": message, - }, - }) - return &UpstreamFailoverError{ - StatusCode: http.StatusBadGateway, - ResponseBody: body, - } -} - -func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( - ctx context.Context, - resp *http.Response, - c *gin.Context, - account *Account, - startTime time.Time, - originalModel string, - mappedModel string, -) (*openaiStreamingResultPassthrough, error) { - writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) - - // SSE headers - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("X-Accel-Buffering", "no") - if v := resp.Header.Get("x-request-id"); v != "" { - c.Header("x-request-id", v) - } - - w := c.Writer - flusher, ok := w.(http.Flusher) - if !ok { - return nil, errors.New("streaming not 
supported") - } - - usage := &OpenAIUsage{} - imageCounter := newOpenAIImageOutputCounter() - var firstTokenMs *int - clientDisconnected := false - sawDone := false - sawTerminalEvent := false - sawFailedEvent := false - failedMessage := "" - clientOutputStarted := false - upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id")) - pendingLines := make([]string, 0, 8) - writePendingLines := func() bool { - for _, pending := range pendingLines { - if _, err := fmt.Fprintln(w, pending); err != nil { - clientDisconnected = true - logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID) - return false - } - } - pendingLines = pendingLines[:0] - return true - } - - scanner := bufio.NewScanner(resp.Body) - maxLineSize := defaultMaxLineSize - if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { - maxLineSize = s.cfg.Gateway.MaxLineSize - } - scanBuf := getSSEScannerBuf64K() - scanner.Buffer(scanBuf[:0], maxLineSize) - defer putSSEScannerBuf64K(scanBuf) - - needModelReplace := strings.TrimSpace(originalModel) != "" && strings.TrimSpace(mappedModel) != "" && strings.TrimSpace(originalModel) != strings.TrimSpace(mappedModel) - resultWithUsage := func() *openaiStreamingResultPassthrough { - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs, imageCount: imageCounter.Count()} - } - - for scanner.Scan() { - line := scanner.Text() - lineStartsClientOutput := false - forceFlushFailedEvent := false - if data, ok := extractOpenAISSEDataLine(line); ok { - dataBytes := []byte(data) - trimmedData := strings.TrimSpace(data) - if needModelReplace && strings.Contains(data, mappedModel) { - line = s.replaceModelInSSELine(line, mappedModel, originalModel) - if replacedData, replaced := extractOpenAISSEDataLine(line); replaced { - dataBytes = []byte(replacedData) - trimmedData = strings.TrimSpace(replacedData) - } - } - eventType := 
strings.TrimSpace(gjson.Get(trimmedData, "type").String()) - if eventType == "response.failed" { - failedMessage = extractOpenAISSEErrorMessage(dataBytes) - if !openAIStreamClientOutputStarted(c, clientOutputStarted) && openAIStreamFailedEventShouldFailover(dataBytes, failedMessage) { - return resultWithUsage(), - s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, dataBytes, failedMessage) - } - forceFlushFailedEvent = true - sawFailedEvent = true + retryBudget := s.openAIWSRetryTotalBudget() + retryStartedAt := time.Now() + wsRetryLoop: + for attempt := 1; attempt <= maxAttempts; attempt++ { + wsAttempts = attempt + wsResult, wsErr = s.forwardOpenAIWSV2( + ctx, + c, + account, + wsReqBody, + token, + wsDecision, + isCodexCLI, + reqStream, + originalModel, + upstreamModel, + startTime, + attempt, + wsLastFailureReason, + ) + if wsErr == nil { + break } - if trimmedData == "[DONE]" { - sawDone = true + if c != nil && c.Writer != nil && c.Writer.Written() { + break } - if openAIStreamEventIsTerminal(trimmedData) { - sawTerminalEvent = true + + reason, retryable := classifyOpenAIWSReconnectReason(wsErr) + if reason != "" { + wsLastFailureReason = reason } - imageCounter.AddSSEData(dataBytes) - lineStartsClientOutput = forceFlushFailedEvent || openAIStreamDataStartsClientOutput(trimmedData, eventType) - if firstTokenMs == nil && lineStartsClientOutput && trimmedData != "[DONE]" { - ms := int(time.Since(startTime).Milliseconds()) - firstTokenMs = &ms + // previous_response_not_found 说明续链锚点不可用: + // 对非 function_call_output 场景,允许一次“去掉 previous_response_id 后重放”。 + if reason == "previous_response_not_found" && recoverPrevResponseNotFound(attempt) { + continue } - s.parseSSEUsageBytes(dataBytes, usage) - } - - if !clientDisconnected { - if !clientOutputStarted && !lineStartsClientOutput { - pendingLines = append(pendingLines, line) + if reason == "invalid_encrypted_content" && recoverInvalidEncryptedContent(attempt) { continue } - if !clientOutputStarted && 
len(pendingLines) > 0 { - if !writePendingLines() { - continue + if retryable && attempt < maxAttempts { + backoff := s.openAIWSRetryBackoff(attempt) + if retryBudget > 0 && time.Since(retryStartedAt)+backoff > retryBudget { + s.recordOpenAIWSRetryExhausted() + logOpenAIWSModeInfo( + "reconnect_budget_exhausted account_id=%d attempts=%d max_retries=%d reason=%s elapsed_ms=%d budget_ms=%d", + account.ID, + attempt, + openAIWSReconnectRetryLimit, + normalizeOpenAIWSLogValue(reason), + time.Since(retryStartedAt).Milliseconds(), + retryBudget.Milliseconds(), + ) + break + } + s.recordOpenAIWSRetryAttempt(backoff) + logOpenAIWSModeInfo( + "reconnect_retry account_id=%d retry=%d max_retries=%d reason=%s backoff_ms=%d", + account.ID, + attempt, + openAIWSReconnectRetryLimit, + normalizeOpenAIWSLogValue(reason), + backoff.Milliseconds(), + ) + if backoff > 0 { + timer := time.NewTimer(backoff) + select { + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + wsErr = wrapOpenAIWSFallback("retry_backoff_canceled", ctx.Err()) + break wsRetryLoop + case <-timer.C: + } } + continue } - if _, err := fmt.Fprintln(w, line); err != nil { - clientDisconnected = true - logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID) - } else { - clientOutputStarted = true - flusher.Flush() + if retryable { + s.recordOpenAIWSRetryExhausted() + logOpenAIWSModeInfo( + "reconnect_exhausted account_id=%d attempts=%d max_retries=%d reason=%s", + account.ID, + attempt, + openAIWSReconnectRetryLimit, + normalizeOpenAIWSLogValue(reason), + ) + } else if reason != "" { + s.recordOpenAIWSNonRetryableFastFallback() + logOpenAIWSModeInfo( + "reconnect_stop account_id=%d attempt=%d reason=%s", + account.ID, + attempt, + normalizeOpenAIWSLogValue(reason), + ) } + break } - } - if err := scanner.Err(); err != nil { - if sawTerminalEvent && !sawFailedEvent { - return resultWithUsage(), nil - } 
- if sawFailedEvent { - return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage) - } - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return resultWithUsage(), fmt.Errorf("stream usage incomplete: %w", err) - } - if errors.Is(err, bufio.ErrTooLong) { - logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err) - return resultWithUsage(), err - } - if !openAIStreamClientOutputStarted(c, clientOutputStarted) { - msg := "OpenAI stream disconnected before completion" - if errText := strings.TrimSpace(err.Error()); errText != "" { - msg += ": " + errText + if wsErr == nil { + firstTokenMs := int64(0) + hasFirstTokenMs := wsResult != nil && wsResult.FirstTokenMs != nil + if hasFirstTokenMs { + firstTokenMs = int64(*wsResult.FirstTokenMs) } - return resultWithUsage(), - s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, msg) - } - if clientDisconnected { - return resultWithUsage(), fmt.Errorf("stream usage incomplete after disconnect: %w", err) - } - logger.LegacyPrintf("service.openai_gateway", - "[OpenAI passthrough] 流读取异常中断: account=%d request_id=%s err=%v", - account.ID, - upstreamRequestID, - err, - ) - return resultWithUsage(), fmt.Errorf("stream read error: %w", err) - } - if sawFailedEvent { - return resultWithUsage(), fmt.Errorf("upstream response failed: %s", failedMessage) - } - if !clientDisconnected && !sawDone && !sawTerminalEvent && ctx.Err() == nil { - logger.FromContext(ctx).With( - zap.String("component", "service.openai_gateway"), - zap.Int64("account_id", account.ID), - zap.String("upstream_request_id", upstreamRequestID), - ).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流") - if !openAIStreamClientOutputStarted(c, clientOutputStarted) { - return resultWithUsage(), - s.newOpenAIStreamFailoverError(c, account, true, upstreamRequestID, nil, "OpenAI stream ended before a 
terminal event") + requestID := "" + if wsResult != nil { + requestID = strings.TrimSpace(wsResult.RequestID) + } + logOpenAIWSModeDebug( + "forward_succeeded account_id=%d request_id=%s stream=%v has_first_token_ms=%v first_token_ms=%d ws_attempts=%d", + account.ID, + requestID, + reqStream, + hasFirstTokenMs, + firstTokenMs, + wsAttempts, + ) + wsResult.UpstreamModel = upstreamModel + if wsResult.ImageCount > 0 { + wsResult.ImageSize = imageSizeTier + wsResult.BillingModel = imageBillingModel + } + return wsResult, nil } - return resultWithUsage(), errors.New("stream usage incomplete: missing terminal event") - } - - return resultWithUsage(), nil -} - -func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough( - ctx context.Context, - resp *http.Response, - c *gin.Context, - originalModel string, - mappedModel string, -) (*openaiNonStreamingResultPassthrough, error) { - body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) - if err != nil { - return nil, err - } - - // Detect SSE responses from upstream and convert to JSON. - // Some upstreams (e.g. other sub2api instances) may return SSE even when - // stream=false was requested. Without this conversion the client would - // receive raw SSE text or a terminal event with empty output. 
- if isEventStreamResponse(resp.Header) { - return s.handlePassthroughSSEToJSON(resp, c, body, originalModel, mappedModel) + s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr) + return nil, wsErr } - usage := &OpenAIUsage{} - usageParsed := false - if len(body) > 0 { - if parsedUsage, ok := extractOpenAIUsageFromJSONBytes(body); ok { - *usage = parsedUsage - usageParsed = true + httpInvalidEncryptedContentRetryTried := false + for { + // Build upstream request + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) + upstreamReq, err := s.buildUpstreamRequest(upstreamCtx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI) + releaseUpstreamCtx() + if err != nil { + return nil, err } - } - if !usageParsed { - // 兜底:尝试从 SSE 文本中解析 usage - usage = s.parseSSEUsageFromBody(string(body)) - } - writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + // Get proxy URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } - contentType := resp.Header.Get("Content-Type") - if contentType == "" { - contentType = "application/json" - } - if originalModel != "" && mappedModel != "" && originalModel != mappedModel { - body = s.replaceModelInResponseBody(body, mappedModel, originalModel) - } - c.Data(resp.StatusCode, contentType, body) - return &openaiNonStreamingResultPassthrough{ - OpenAIUsage: usage, - usage: usage, - imageCount: countOpenAIResponseImageOutputsFromJSONBytes(body), - }, nil -} + // Send request + upstreamStart := time.Now() + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds()) + if err != nil { + // Ensure the client receives an error response (handlers assume Forward writes on non-failover errors). 
+ safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream request failed", + }, + }) + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } -// handlePassthroughSSEToJSON converts an SSE response body into a JSON -// response for the passthrough path. It mirrors handleSSEToJSON while -// preserving passthrough payloads, except compact-only model remapping may -// rewrite model fields back to the original requested model. -func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel string, mappedModel string) (*openaiNonStreamingResultPassthrough, error) { - bodyText := string(body) - finalResponse, ok := extractCodexFinalResponse(bodyText) + // Handle error response + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) - usage := &OpenAIUsage{} - if ok { - if parsedUsage, parsed := extractOpenAIUsageFromJSONBytes(finalResponse); parsed { - *usage = parsedUsage - } - // When the terminal event has an empty output array, reconstruct - // output from accumulated delta events so the client gets full content. 
- if len(gjson.GetBytes(finalResponse, "output").Array()) == 0 { - if outputJSON, reconstructed := reconstructResponseOutputFromSSE(bodyText); reconstructed { - if patched, err := sjson.SetRawBytes(finalResponse, "output", outputJSON); err == nil { - finalResponse = patched + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamCode := extractUpstreamErrorCode(respBody) + if !httpInvalidEncryptedContentRetryTried && resp.StatusCode == http.StatusBadRequest && upstreamCode == "invalid_encrypted_content" { + if trimOpenAIEncryptedReasoningItems(reqBody) { + body, err = json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("serialize invalid_encrypted_content retry body: %w", err) + } + setOpsUpstreamRequestBody(c, body) + httpInvalidEncryptedContentRetryTried = true + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Retrying non-WSv2 request once after invalid_encrypted_content (account: %s)", account.Name) + continue } + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Skip non-WSv2 invalid_encrypted_content retry because encrypted reasoning items are missing (account: %s)", account.Name) } - } - body = finalResponse - if originalModel != "" && mappedModel != "" && originalModel != mappedModel { - body = s.replaceModelInResponseBody(body, mappedModel, originalModel) - } - // Correct tool calls in final response - body = s.correctToolCallsInResponseBody(body) - } else { - terminalType, terminalPayload, terminalOK := extractOpenAISSETerminalEvent(bodyText) - if terminalOK && terminalType == "response.failed" { - msg := extractOpenAISSEErrorMessage(terminalPayload) - if msg == "" { - msg = "Upstream compact response failed" + if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) { + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes 
<= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + + s.handleFailoverSideEffects(ctx, resp, account) + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && (isPoolModeRetryableStatus(resp.StatusCode) || isOpenAITransientProcessingError(resp.StatusCode, upstreamMsg, respBody)), + } } - return nil, s.writeOpenAINonStreamingProtocolError(resp, c, msg) - } - usage = s.parseSSEUsageFromBody(bodyText) - if originalModel != "" && mappedModel != "" && originalModel != mappedModel { - bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel) + return s.handleErrorResponse(ctx, resp, c, account, body) } - body = []byte(bodyText) - } - - writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + defer func() { _ = resp.Body.Close() }() - contentType := "application/json; charset=utf-8" - if !ok { - contentType = resp.Header.Get("Content-Type") - if contentType == "" { - contentType = "text/event-stream" + // Handle normal response + var usage *OpenAIUsage + var firstTokenMs *int + imageCount := 0 + if reqStream { + streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, upstreamModel) + if err != nil { + return nil, err + } + usage = streamResult.usage + firstTokenMs = streamResult.firstTokenMs + imageCount = streamResult.imageCount + } else { + nonStreamResult, err := s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, upstreamModel) + if err != nil { + return nil, err + } + usage = nonStreamResult.usage + imageCount = 
nonStreamResult.imageCount } - } - c.Data(resp.StatusCode, contentType, body) - - return &openaiNonStreamingResultPassthrough{ - OpenAIUsage: usage, - usage: usage, - imageCount: countOpenAIImageOutputsFromSSEBody(bodyText), - }, nil -} -func writeOpenAIPassthroughResponseHeaders(dst http.Header, src http.Header, filter *responseheaders.CompiledHeaderFilter) { - if dst == nil || src == nil { - return - } - if filter != nil { - responseheaders.WriteFilteredHeaders(dst, src, filter) - } else { - // 兜底:尽量保留最基础的 content-type - if v := strings.TrimSpace(src.Get("Content-Type")); v != "" { - dst.Set("Content-Type", v) - } - } - // 透传模式强制放行 x-codex-* 响应头(若上游返回)。 - // 注意:真实 http.Response.Header 的 key 一般会被 canonicalize;但为了兼容测试/自建响应, - // 这里用 EqualFold 做一次大小写不敏感的查找。 - getCaseInsensitiveValues := func(h http.Header, want string) []string { - if h == nil { - return nil - } - for k, vals := range h { - if strings.EqualFold(k, want) { - return vals + // Extract and save Codex usage snapshot from response headers (for OAuth accounts) + if account.Type == AccountTypeOAuth { + if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { + s.updateCodexUsageSnapshot(ctx, account.ID, snapshot) } } - return nil - } - for _, rawKey := range []string{ - "x-codex-primary-used-percent", - "x-codex-primary-reset-after-seconds", - "x-codex-primary-window-minutes", - "x-codex-secondary-used-percent", - "x-codex-secondary-reset-after-seconds", - "x-codex-secondary-window-minutes", - "x-codex-primary-over-secondary-limit-percent", - } { - vals := getCaseInsensitiveValues(src, rawKey) - if len(vals) == 0 { - continue + if usage == nil { + usage = &OpenAIUsage{} + } + + reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel) + serviceTier := extractOpenAIServiceTier(reqBody) + + forwardResult := &OpenAIForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: *usage, + Model: originalModel, + UpstreamModel: upstreamModel, + ServiceTier: serviceTier, + 
ReasoningEffort: reasoningEffort, + Stream: reqStream, + OpenAIWSMode: false, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, } - key := http.CanonicalHeaderKey(rawKey) - dst.Del(key) - for _, v := range vals { - dst.Add(key, v) + if imageCount > 0 { + forwardResult.ImageCount = imageCount + forwardResult.ImageSize = imageSizeTier + forwardResult.BillingModel = imageBillingModel } + return forwardResult, nil } } @@ -4625,45 +2489,6 @@ func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) { s.parseSSEUsageBytes([]byte(data), usage) } -func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsage) { - if usage == nil || len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) { - return - } - // 选择性解析:仅在数据中包含终止事件标识时才进入字段提取。 - if len(data) < 72 { - return - } - eventType := gjson.GetBytes(data, "type").String() - if eventType != "response.completed" && eventType != "response.done" && - eventType != "response.incomplete" && eventType != "response.cancelled" && eventType != "response.canceled" { - return - } - - usage.InputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens").Int()) - usage.OutputTokens = int(gjson.GetBytes(data, "response.usage.output_tokens").Int()) - usage.CacheReadInputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens_details.cached_tokens").Int()) - usage.ImageOutputTokens = int(gjson.GetBytes(data, "response.usage.output_tokens_details.image_tokens").Int()) -} - -func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) { - if len(body) == 0 || !gjson.ValidBytes(body) { - return OpenAIUsage{}, false - } - values := gjson.GetManyBytes( - body, - "usage.input_tokens", - "usage.output_tokens", - "usage.input_tokens_details.cached_tokens", - "usage.output_tokens_details.image_tokens", - ) - return OpenAIUsage{ - InputTokens: int(values[0].Int()), - OutputTokens: int(values[1].Int()), - CacheReadInputTokens: int(values[2].Int()), - ImageOutputTokens: 
int(values[3].Int()), - }, true -} - func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*openaiNonStreamingResult, error) { body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) if err != nil { @@ -4915,14 +2740,6 @@ func extractImageGenerationOutputFromSSEData(data []byte, seen map[string]struct return json.RawMessage(item.Raw), true } -func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage { - usage := &OpenAIUsage{} - forEachOpenAISSEDataPayload(body, func(data []byte) { - s.parseSSEUsageBytes(data, usage) - }) - return usage -} - func (s *OpenAIGatewayService) replaceModelInSSEBody(body, fromModel, toModel string) string { lines := strings.Split(body, "\n") for i, line := range lines { @@ -5183,22 +3000,6 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel return body } -// OpenAIRecordUsageInput input for recording usage -type OpenAIRecordUsageInput struct { - Result *OpenAIForwardResult - APIKey *APIKey - User *User - Account *Account - Subscription *UserSubscription - InboundEndpoint string - UpstreamEndpoint string - UserAgent string // 请求的 User-Agent - IPAddress string // 请求的客户端 IP 地址 - RequestPayloadHash string - APIKeyService APIKeyQuotaUpdater - ChannelUsageFields -} - // RecordUsage records usage and deducts balance func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error { if input == nil { @@ -5416,362 +3217,6 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec return nil } -func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost( - ctx context.Context, - result *OpenAIForwardResult, - apiKey *APIKey, - billingModels []string, - multiplier float64, - imageMultiplier float64, - tokens UsageTokens, - serviceTier string, -) (*CostBreakdown, error) { - billingModel := 
firstUsageBillingModel(billingModels) - if result != nil && result.ImageCount > 0 { - return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, imageMultiplier), nil - } - if len(billingModels) == 0 || billingModel == "" { - return nil, errors.New("openai usage billing model is empty") - } - var lastErr error - for _, candidate := range billingModels { - candidate = strings.TrimSpace(candidate) - if candidate == "" { - continue - } - cost, err := s.calculateOpenAIRecordUsageTokenCost(ctx, apiKey, candidate, multiplier, tokens, serviceTier) - if err == nil { - return cost, nil - } - lastErr = err - } - if lastErr == nil { - lastErr = errors.New("no non-empty billing model candidates") - } - return nil, fmt.Errorf("calculate OpenAI usage cost failed for billing models %s: %w", strings.Join(billingModels, ","), lastErr) -} - -func isUsagePricingUnavailableError(err error) bool { - if err == nil { - return false - } - if errors.Is(err, ErrModelPricingUnavailable) { - return true - } - msg := strings.ToLower(err.Error()) - return strings.Contains(msg, "no pricing available") || strings.Contains(msg, "pricing not found") -} - -func (s *OpenAIGatewayService) calculateOpenAIRecordUsageTokenCost( - ctx context.Context, - apiKey *APIKey, - billingModel string, - multiplier float64, - tokens UsageTokens, - serviceTier string, -) (*CostBreakdown, error) { - if s.resolver != nil && apiKey.Group != nil { - gid := apiKey.Group.ID - return s.billingService.CalculateCostUnified(CostInput{ - Ctx: ctx, - Model: billingModel, - GroupID: &gid, - Tokens: tokens, - RequestCount: 1, - RateMultiplier: multiplier, - ServiceTier: serviceTier, - Resolver: s.resolver, - }) - } - return s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier) -} - -func (s *OpenAIGatewayService) calculateOpenAIImageCost( - ctx context.Context, - billingModel string, - apiKey *APIKey, - result *OpenAIForwardResult, - multiplier float64, -) *CostBreakdown { - if 
resolved := s.resolveOpenAIChannelPricing(ctx, billingModel, apiKey); resolved != nil && - (resolved.Mode == BillingModePerRequest || resolved.Mode == BillingModeImage) { - gid := apiKey.Group.ID - cost, err := s.billingService.CalculateCostUnified(CostInput{ - Ctx: ctx, - Model: billingModel, - GroupID: &gid, - RequestCount: result.ImageCount, - SizeTier: result.ImageSize, - RateMultiplier: multiplier, - Resolver: s.resolver, - Resolved: resolved, - }) - if err == nil { - return cost - } - logger.LegacyPrintf("service.openai_gateway", "Calculate image channel cost failed: %v", err) - } - - var groupConfig *ImagePriceConfig - if apiKey != nil && apiKey.Group != nil { - groupConfig = &ImagePriceConfig{ - Price1K: apiKey.Group.ImagePrice1K, - Price2K: apiKey.Group.ImagePrice2K, - Price4K: apiKey.Group.ImagePrice4K, - } - } - return s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier) -} - -func (s *OpenAIGatewayService) resolveOpenAIChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing { - if s.resolver == nil || apiKey == nil || apiKey.Group == nil { - return nil - } - gid := apiKey.Group.ID - resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid}) - if resolved.Source == PricingSourceChannel { - return resolved - } - return nil -} - -// ParseCodexRateLimitHeaders extracts Codex usage limits from response headers. -// Exported for use in ratelimit_service when handling OpenAI 429 responses. 
-func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot { - snapshot := &OpenAICodexUsageSnapshot{} - hasData := false - - // Helper to parse float64 from header - parseFloat := func(key string) *float64 { - if v := headers.Get(key); v != "" { - if f, err := strconv.ParseFloat(v, 64); err == nil { - return &f - } - } - return nil - } - - // Helper to parse int from header - parseInt := func(key string) *int { - if v := headers.Get(key); v != "" { - if i, err := strconv.Atoi(v); err == nil { - return &i - } - } - return nil - } - - // Primary (weekly) limits - if v := parseFloat("x-codex-primary-used-percent"); v != nil { - snapshot.PrimaryUsedPercent = v - hasData = true - } - if v := parseInt("x-codex-primary-reset-after-seconds"); v != nil { - snapshot.PrimaryResetAfterSeconds = v - hasData = true - } - if v := parseInt("x-codex-primary-window-minutes"); v != nil { - snapshot.PrimaryWindowMinutes = v - hasData = true - } - - // Secondary (5h) limits - if v := parseFloat("x-codex-secondary-used-percent"); v != nil { - snapshot.SecondaryUsedPercent = v - hasData = true - } - if v := parseInt("x-codex-secondary-reset-after-seconds"); v != nil { - snapshot.SecondaryResetAfterSeconds = v - hasData = true - } - if v := parseInt("x-codex-secondary-window-minutes"); v != nil { - snapshot.SecondaryWindowMinutes = v - hasData = true - } - - // Overflow ratio - if v := parseFloat("x-codex-primary-over-secondary-limit-percent"); v != nil { - snapshot.PrimaryOverSecondaryPercent = v - hasData = true - } - - if !hasData { - return nil - } - - snapshot.UpdatedAt = time.Now().Format(time.RFC3339) - return snapshot -} - -func codexSnapshotBaseTime(snapshot *OpenAICodexUsageSnapshot, fallback time.Time) time.Time { - if snapshot == nil { - return fallback - } - if snapshot.UpdatedAt == "" { - return fallback - } - base, err := time.Parse(time.RFC3339, snapshot.UpdatedAt) - if err != nil { - return fallback - } - return base -} - -func 
codexResetAtRFC3339(base time.Time, resetAfterSeconds *int) *string { - if resetAfterSeconds == nil { - return nil - } - sec := *resetAfterSeconds - if sec < 0 { - sec = 0 - } - resetAt := base.Add(time.Duration(sec) * time.Second).Format(time.RFC3339) - return &resetAt -} - -func buildCodexUsageExtraUpdates(snapshot *OpenAICodexUsageSnapshot, fallbackNow time.Time) map[string]any { - if snapshot == nil { - return nil - } - - baseTime := codexSnapshotBaseTime(snapshot, fallbackNow) - updates := make(map[string]any) - - // 保存原始 primary/secondary 字段,便于排查问题 - if snapshot.PrimaryUsedPercent != nil { - updates["codex_primary_used_percent"] = *snapshot.PrimaryUsedPercent - } - if snapshot.PrimaryResetAfterSeconds != nil { - updates["codex_primary_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds - } - if snapshot.PrimaryWindowMinutes != nil { - updates["codex_primary_window_minutes"] = *snapshot.PrimaryWindowMinutes - } - if snapshot.SecondaryUsedPercent != nil { - updates["codex_secondary_used_percent"] = *snapshot.SecondaryUsedPercent - } - if snapshot.SecondaryResetAfterSeconds != nil { - updates["codex_secondary_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds - } - if snapshot.SecondaryWindowMinutes != nil { - updates["codex_secondary_window_minutes"] = *snapshot.SecondaryWindowMinutes - } - if snapshot.PrimaryOverSecondaryPercent != nil { - updates["codex_primary_over_secondary_percent"] = *snapshot.PrimaryOverSecondaryPercent - } - updates["codex_usage_updated_at"] = baseTime.Format(time.RFC3339) - - // 归一化到 5h/7d 规范字段 - if normalized := snapshot.Normalize(); normalized != nil { - if normalized.Used5hPercent != nil { - updates["codex_5h_used_percent"] = *normalized.Used5hPercent - } - if normalized.Reset5hSeconds != nil { - updates["codex_5h_reset_after_seconds"] = *normalized.Reset5hSeconds - } - if normalized.Window5hMinutes != nil { - updates["codex_5h_window_minutes"] = *normalized.Window5hMinutes - } - if normalized.Used7dPercent != nil { 
- updates["codex_7d_used_percent"] = *normalized.Used7dPercent - } - if normalized.Reset7dSeconds != nil { - updates["codex_7d_reset_after_seconds"] = *normalized.Reset7dSeconds - } - if normalized.Window7dMinutes != nil { - updates["codex_7d_window_minutes"] = *normalized.Window7dMinutes - } - if reset5hAt := codexResetAtRFC3339(baseTime, normalized.Reset5hSeconds); reset5hAt != nil { - updates["codex_5h_reset_at"] = *reset5hAt - } - if reset7dAt := codexResetAtRFC3339(baseTime, normalized.Reset7dSeconds); reset7dAt != nil { - updates["codex_7d_reset_at"] = *reset7dAt - } - } - - return updates -} - -// updateCodexUsageSnapshot saves the Codex usage snapshot to account's Extra field -func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, accountID int64, snapshot *OpenAICodexUsageSnapshot) { - if snapshot == nil { - return - } - if s == nil || s.accountRepo == nil { - return - } - - now := time.Now() - updates := buildCodexUsageExtraUpdates(snapshot, now) - if len(updates) == 0 { - return - } - if !s.getCodexSnapshotThrottle().Allow(accountID, now) { - return - } - - go func() { - updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates) - }() -} - -func (s *OpenAIGatewayService) UpdateCodexUsageSnapshotFromHeaders(ctx context.Context, accountID int64, headers http.Header) { - if accountID <= 0 || headers == nil { - return - } - if snapshot := ParseCodexRateLimitHeaders(headers); snapshot != nil { - s.updateCodexUsageSnapshot(ctx, accountID, snapshot) - } -} - -func getOpenAIReasoningEffortFromReqBody(reqBody map[string]any) (value string, present bool) { - if reqBody == nil { - return "", false - } - - // Primary: reasoning.effort - if reasoning, ok := reqBody["reasoning"].(map[string]any); ok { - if effort, ok := reasoning["effort"].(string); ok { - return normalizeOpenAIReasoningEffort(effort), true - } - } - - // Fallback: some clients may 
use a flat field. - if effort, ok := reqBody["reasoning_effort"].(string); ok { - return normalizeOpenAIReasoningEffort(effort), true - } - - return "", false -} - -func deriveOpenAIReasoningEffortFromModel(model string) string { - if strings.TrimSpace(model) == "" { - return "" - } - - modelID := strings.TrimSpace(model) - if strings.Contains(modelID, "/") { - parts := strings.Split(modelID, "/") - modelID = parts[len(parts)-1] - } - - parts := strings.FieldsFunc(strings.ToLower(modelID), func(r rune) bool { - switch r { - case '-', '_', ' ': - return true - default: - return false - } - }) - if len(parts) == 0 { - return "" - } - - return normalizeOpenAIReasoningEffort(parts[len(parts)-1]) -} - func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, promptCacheKey string) { if len(body) == 0 { return "", false, "" @@ -5864,26 +3309,6 @@ func detectOpenAIPassthroughInstructionsRejectReason(reqModel string, body []byt return "" } -func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string { - reasoningEffort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String()) - if reasoningEffort == "" { - reasoningEffort = strings.TrimSpace(gjson.GetBytes(body, "reasoning_effort").String()) - } - if reasoningEffort != "" { - normalized := normalizeOpenAIReasoningEffort(reasoningEffort) - if normalized == "" { - return nil - } - return &normalized - } - - value := deriveOpenAIReasoningEffortFromModel(requestedModel) - if value == "" { - return nil - } - return &value -} - func extractOpenAIServiceTier(reqBody map[string]any) *string { if reqBody == nil { return nil diff --git a/backend/internal/service/openai_messages_dispatch_test.go b/backend/internal/service/openai_messages_dispatch_test.go index a625aaddd43..7e3350f4029 100644 --- a/backend/internal/service/openai_messages_dispatch_test.go +++ b/backend/internal/service/openai_messages_dispatch_test.go @@ -1,8 +1,10 @@ package service -import "testing" 
+import ( + "testing" -import "github.com/stretchr/testify/require" + "github.com/stretchr/testify/require" +) func TestNormalizeOpenAIMessagesDispatchModelConfig(t *testing.T) { t.Parallel() diff --git a/backend/internal/service/openai_record_usage.go b/backend/internal/service/openai_record_usage.go new file mode 100644 index 00000000000..da6c7475937 --- /dev/null +++ b/backend/internal/service/openai_record_usage.go @@ -0,0 +1,417 @@ +package service + +import ( + "bytes" + "context" + "errors" + "fmt" + "net/http" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/tidwall/gjson" +) + +// OpenAIForwardResult represents the result of forwarding +type OpenAIForwardResult struct { + RequestID string + ResponseID string + Usage OpenAIUsage + Model string // 原始模型(用于响应和日志显示) + // BillingModel is the model used for cost calculation. + // When non-empty, CalculateCost uses this instead of Model. + // This is set by the Anthropic Messages conversion path where + // the mapped upstream model differs from the client-facing model. + BillingModel string + // UpstreamModel is the actual model sent to the upstream provider after mapping. + // Empty when no mapping was applied (requested model was used as-is). + UpstreamModel string + // ServiceTier records the OpenAI Responses API service tier, e.g. "priority" / "flex". + // Nil means the request did not specify a recognized tier. + ServiceTier *string + // ReasoningEffort is extracted from request body (reasoning.effort) or derived from model suffix. + // Stored for usage records display; nil means not provided / not applicable. 
+ ReasoningEffort *string + Stream bool + OpenAIWSMode bool + ResponseHeaders http.Header + Duration time.Duration + FirstTokenMs *int + ImageCount int + ImageSize string +} + +func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsage) { + if usage == nil || len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) { + return + } + // 选择性解析:仅在数据中包含终止事件标识时才进入字段提取。 + if len(data) < 72 { + return + } + eventType := gjson.GetBytes(data, "type").String() + if eventType != "response.completed" && eventType != "response.done" && + eventType != "response.incomplete" && eventType != "response.cancelled" && eventType != "response.canceled" { + return + } + + usage.InputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens").Int()) + usage.OutputTokens = int(gjson.GetBytes(data, "response.usage.output_tokens").Int()) + usage.CacheReadInputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens_details.cached_tokens").Int()) + usage.ImageOutputTokens = int(gjson.GetBytes(data, "response.usage.output_tokens_details.image_tokens").Int()) +} + +func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) { + if len(body) == 0 || !gjson.ValidBytes(body) { + return OpenAIUsage{}, false + } + values := gjson.GetManyBytes( + body, + "usage.input_tokens", + "usage.output_tokens", + "usage.input_tokens_details.cached_tokens", + "usage.output_tokens_details.image_tokens", + ) + return OpenAIUsage{ + InputTokens: int(values[0].Int()), + OutputTokens: int(values[1].Int()), + CacheReadInputTokens: int(values[2].Int()), + ImageOutputTokens: int(values[3].Int()), + }, true +} + +func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage { + usage := &OpenAIUsage{} + forEachOpenAISSEDataPayload(body, func(data []byte) { + s.parseSSEUsageBytes(data, usage) + }) + return usage +} + +// OpenAIRecordUsageInput input for recording usage +type OpenAIRecordUsageInput struct { + Result *OpenAIForwardResult + APIKey *APIKey + User *User 
+ Account *Account + Subscription *UserSubscription + InboundEndpoint string + UpstreamEndpoint string + UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 + RequestPayloadHash string + APIKeyService APIKeyQuotaUpdater + ChannelUsageFields +} + +func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost( + ctx context.Context, + result *OpenAIForwardResult, + apiKey *APIKey, + billingModels []string, + multiplier float64, + imageMultiplier float64, + tokens UsageTokens, + serviceTier string, +) (*CostBreakdown, error) { + billingModel := firstUsageBillingModel(billingModels) + if result != nil && result.ImageCount > 0 { + return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, imageMultiplier), nil + } + if len(billingModels) == 0 || billingModel == "" { + return nil, errors.New("openai usage billing model is empty") + } + var lastErr error + for _, candidate := range billingModels { + candidate = strings.TrimSpace(candidate) + if candidate == "" { + continue + } + cost, err := s.calculateOpenAIRecordUsageTokenCost(ctx, apiKey, candidate, multiplier, tokens, serviceTier) + if err == nil { + return cost, nil + } + lastErr = err + } + if lastErr == nil { + lastErr = errors.New("no non-empty billing model candidates") + } + return nil, fmt.Errorf("calculate OpenAI usage cost failed for billing models %s: %w", strings.Join(billingModels, ","), lastErr) +} + +func isUsagePricingUnavailableError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, ErrModelPricingUnavailable) { + return true + } + msg := strings.ToLower(err.Error()) + return strings.Contains(msg, "no pricing available") || strings.Contains(msg, "pricing not found") +} + +func (s *OpenAIGatewayService) calculateOpenAIRecordUsageTokenCost( + ctx context.Context, + apiKey *APIKey, + billingModel string, + multiplier float64, + tokens UsageTokens, + serviceTier string, +) (*CostBreakdown, error) { + if s.resolver != nil && apiKey.Group != nil { + gid := 
apiKey.Group.ID + return s.billingService.CalculateCostUnified(CostInput{ + Ctx: ctx, + Model: billingModel, + GroupID: &gid, + Tokens: tokens, + RequestCount: 1, + RateMultiplier: multiplier, + ServiceTier: serviceTier, + Resolver: s.resolver, + }) + } + return s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier) +} + +func (s *OpenAIGatewayService) calculateOpenAIImageCost( + ctx context.Context, + billingModel string, + apiKey *APIKey, + result *OpenAIForwardResult, + multiplier float64, +) *CostBreakdown { + if resolved := s.resolveOpenAIChannelPricing(ctx, billingModel, apiKey); resolved != nil && + (resolved.Mode == BillingModePerRequest || resolved.Mode == BillingModeImage) { + gid := apiKey.Group.ID + cost, err := s.billingService.CalculateCostUnified(CostInput{ + Ctx: ctx, + Model: billingModel, + GroupID: &gid, + RequestCount: result.ImageCount, + SizeTier: result.ImageSize, + RateMultiplier: multiplier, + Resolver: s.resolver, + Resolved: resolved, + }) + if err == nil { + return cost + } + logger.LegacyPrintf("service.openai_gateway", "Calculate image channel cost failed: %v", err) + } + + var groupConfig *ImagePriceConfig + if apiKey != nil && apiKey.Group != nil { + groupConfig = &ImagePriceConfig{ + Price1K: apiKey.Group.ImagePrice1K, + Price2K: apiKey.Group.ImagePrice2K, + Price4K: apiKey.Group.ImagePrice4K, + } + } + return s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier) +} + +func (s *OpenAIGatewayService) resolveOpenAIChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing { + if s.resolver == nil || apiKey == nil || apiKey.Group == nil { + return nil + } + gid := apiKey.Group.ID + resolved := s.resolver.Resolve(ctx, PricingInput{Model: billingModel, GroupID: &gid}) + if resolved.Source == PricingSourceChannel { + return resolved + } + return nil +} + +// ParseCodexRateLimitHeaders extracts Codex usage limits 
from response headers. +// Exported for use in ratelimit_service when handling OpenAI 429 responses. +func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot { + snapshot := &OpenAICodexUsageSnapshot{} + hasData := false + + // Helper to parse float64 from header + parseFloat := func(key string) *float64 { + if v := headers.Get(key); v != "" { + if f, err := strconv.ParseFloat(v, 64); err == nil { + return &f + } + } + return nil + } + + // Helper to parse int from header + parseInt := func(key string) *int { + if v := headers.Get(key); v != "" { + if i, err := strconv.Atoi(v); err == nil { + return &i + } + } + return nil + } + + // Primary (weekly) limits + if v := parseFloat("x-codex-primary-used-percent"); v != nil { + snapshot.PrimaryUsedPercent = v + hasData = true + } + if v := parseInt("x-codex-primary-reset-after-seconds"); v != nil { + snapshot.PrimaryResetAfterSeconds = v + hasData = true + } + if v := parseInt("x-codex-primary-window-minutes"); v != nil { + snapshot.PrimaryWindowMinutes = v + hasData = true + } + + // Secondary (5h) limits + if v := parseFloat("x-codex-secondary-used-percent"); v != nil { + snapshot.SecondaryUsedPercent = v + hasData = true + } + if v := parseInt("x-codex-secondary-reset-after-seconds"); v != nil { + snapshot.SecondaryResetAfterSeconds = v + hasData = true + } + if v := parseInt("x-codex-secondary-window-minutes"); v != nil { + snapshot.SecondaryWindowMinutes = v + hasData = true + } + + // Overflow ratio + if v := parseFloat("x-codex-primary-over-secondary-limit-percent"); v != nil { + snapshot.PrimaryOverSecondaryPercent = v + hasData = true + } + + if !hasData { + return nil + } + + snapshot.UpdatedAt = time.Now().Format(time.RFC3339) + return snapshot +} + +func codexSnapshotBaseTime(snapshot *OpenAICodexUsageSnapshot, fallback time.Time) time.Time { + if snapshot == nil { + return fallback + } + if snapshot.UpdatedAt == "" { + return fallback + } + base, err := time.Parse(time.RFC3339, 
snapshot.UpdatedAt) + if err != nil { + return fallback + } + return base +} + +func codexResetAtRFC3339(base time.Time, resetAfterSeconds *int) *string { + if resetAfterSeconds == nil { + return nil + } + sec := *resetAfterSeconds + if sec < 0 { + sec = 0 + } + resetAt := base.Add(time.Duration(sec) * time.Second).Format(time.RFC3339) + return &resetAt +} + +func buildCodexUsageExtraUpdates(snapshot *OpenAICodexUsageSnapshot, fallbackNow time.Time) map[string]any { + if snapshot == nil { + return nil + } + + baseTime := codexSnapshotBaseTime(snapshot, fallbackNow) + updates := make(map[string]any) + + // 保存原始 primary/secondary 字段,便于排查问题 + if snapshot.PrimaryUsedPercent != nil { + updates["codex_primary_used_percent"] = *snapshot.PrimaryUsedPercent + } + if snapshot.PrimaryResetAfterSeconds != nil { + updates["codex_primary_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds + } + if snapshot.PrimaryWindowMinutes != nil { + updates["codex_primary_window_minutes"] = *snapshot.PrimaryWindowMinutes + } + if snapshot.SecondaryUsedPercent != nil { + updates["codex_secondary_used_percent"] = *snapshot.SecondaryUsedPercent + } + if snapshot.SecondaryResetAfterSeconds != nil { + updates["codex_secondary_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds + } + if snapshot.SecondaryWindowMinutes != nil { + updates["codex_secondary_window_minutes"] = *snapshot.SecondaryWindowMinutes + } + if snapshot.PrimaryOverSecondaryPercent != nil { + updates["codex_primary_over_secondary_percent"] = *snapshot.PrimaryOverSecondaryPercent + } + updates["codex_usage_updated_at"] = baseTime.Format(time.RFC3339) + + // 归一化到 5h/7d 规范字段 + if normalized := snapshot.Normalize(); normalized != nil { + if normalized.Used5hPercent != nil { + updates["codex_5h_used_percent"] = *normalized.Used5hPercent + } + if normalized.Reset5hSeconds != nil { + updates["codex_5h_reset_after_seconds"] = *normalized.Reset5hSeconds + } + if normalized.Window5hMinutes != nil { + 
updates["codex_5h_window_minutes"] = *normalized.Window5hMinutes + } + if normalized.Used7dPercent != nil { + updates["codex_7d_used_percent"] = *normalized.Used7dPercent + } + if normalized.Reset7dSeconds != nil { + updates["codex_7d_reset_after_seconds"] = *normalized.Reset7dSeconds + } + if normalized.Window7dMinutes != nil { + updates["codex_7d_window_minutes"] = *normalized.Window7dMinutes + } + if reset5hAt := codexResetAtRFC3339(baseTime, normalized.Reset5hSeconds); reset5hAt != nil { + updates["codex_5h_reset_at"] = *reset5hAt + } + if reset7dAt := codexResetAtRFC3339(baseTime, normalized.Reset7dSeconds); reset7dAt != nil { + updates["codex_7d_reset_at"] = *reset7dAt + } + } + + return updates +} + +// updateCodexUsageSnapshot saves the Codex usage snapshot to account's Extra field +func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, accountID int64, snapshot *OpenAICodexUsageSnapshot) { + if snapshot == nil { + return + } + if s == nil || s.accountRepo == nil { + return + } + + now := time.Now() + updates := buildCodexUsageExtraUpdates(snapshot, now) + if len(updates) == 0 { + return + } + if !s.getCodexSnapshotThrottle().Allow(accountID, now) { + return + } + + go func() { + updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates) + }() +} + +func (s *OpenAIGatewayService) UpdateCodexUsageSnapshotFromHeaders(ctx context.Context, accountID int64, headers http.Header) { + if accountID <= 0 || headers == nil { + return + } + if snapshot := ParseCodexRateLimitHeaders(headers); snapshot != nil { + s.updateCodexUsageSnapshot(ctx, accountID, snapshot) + } +} diff --git a/backend/internal/service/openai_session_hash.go b/backend/internal/service/openai_session_hash.go new file mode 100644 index 00000000000..e69c4cedd49 --- /dev/null +++ b/backend/internal/service/openai_session_hash.go @@ -0,0 +1,125 @@ +package service + +import ( + "fmt" 
+ "strings" + + "github.com/cespare/xxhash/v2" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" +) + +// isolateOpenAISessionID 将 apiKeyID 混入 session 标识符, +// 确保不同 API Key 的用户即使使用相同的原始 session_id/conversation_id, +// 到达上游的标识符也不同,防止跨用户会话碰撞。 +func isolateOpenAISessionID(apiKeyID int64, raw string) string { + raw = strings.TrimSpace(raw) + if raw == "" { + return "" + } + h := xxhash.New() + _, _ = fmt.Fprintf(h, "k%d:", apiKeyID) + _, _ = h.WriteString(raw) + return fmt.Sprintf("%016x", h.Sum64()) +} + +// ExtractSessionID extracts the raw session ID from headers or body without hashing. +// Used by ForwardAsAnthropic to pass as prompt_cache_key for upstream cache. +func (s *OpenAIGatewayService) ExtractSessionID(c *gin.Context, body []byte) string { + if c == nil { + return "" + } + sessionID := strings.TrimSpace(c.GetHeader("session_id")) + if sessionID == "" { + sessionID = strings.TrimSpace(c.GetHeader("conversation_id")) + } + if sessionID == "" && len(body) > 0 { + sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) + } + return sessionID +} + +func explicitOpenAISessionID(c *gin.Context, body []byte) string { + if c == nil { + return "" + } + + sessionID := strings.TrimSpace(c.GetHeader("session_id")) + if sessionID == "" { + sessionID = strings.TrimSpace(c.GetHeader("conversation_id")) + } + if sessionID == "" && len(body) > 0 { + sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) + } + return sessionID +} + +// GenerateExplicitSessionHash generates a sticky-session hash only from explicit +// client session signals. It intentionally skips content-derived fallback and is +// used by stateless endpoints such as /v1/images. 
+func (s *OpenAIGatewayService) GenerateExplicitSessionHash(c *gin.Context, body []byte) string { + sessionID := explicitOpenAISessionID(c, body) + if sessionID == "" { + return "" + } + + currentHash, legacyHash := deriveOpenAISessionHashes(sessionID) + attachOpenAILegacySessionHashToGin(c, legacyHash) + return currentHash +} + +// GenerateSessionHash generates a sticky-session hash for OpenAI requests. +// +// Priority: +// 1. Header: session_id +// 2. Header: conversation_id +// 3. Body: prompt_cache_key (opencode) +// 4. Body: content-based fallback (model + system + tools + first user message) +func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte) string { + if c == nil { + return "" + } + + sessionID := explicitOpenAISessionID(c, body) + if sessionID == "" && len(body) > 0 { + sessionID = deriveOpenAIContentSessionSeed(body) + } + if sessionID == "" { + return "" + } + + currentHash, legacyHash := deriveOpenAISessionHashes(sessionID) + attachOpenAILegacySessionHashToGin(c, legacyHash) + return currentHash +} + +// GenerateSessionHashWithFallback 先按常规信号生成会话哈希; +// 当未携带 session_id/conversation_id/prompt_cache_key 时,使用 fallbackSeed 生成稳定哈希。 +// 该方法用于 WS ingress,避免会话信号缺失时发生跨账号漂移。 +func (s *OpenAIGatewayService) GenerateSessionHashWithFallback(c *gin.Context, body []byte, fallbackSeed string) string { + sessionHash := s.GenerateSessionHash(c, body) + if sessionHash != "" { + return sessionHash + } + + seed := strings.TrimSpace(fallbackSeed) + if seed == "" { + return "" + } + + currentHash, legacyHash := deriveOpenAISessionHashes(seed) + attachOpenAILegacySessionHashToGin(c, legacyHash) + return currentHash +} + +func resolveOpenAIUpstreamOriginator(c *gin.Context, isOfficialClient bool) string { + if c != nil { + if originator := strings.TrimSpace(c.GetHeader("originator")); originator != "" { + return originator + } + } + if isOfficialClient { + return "codex_cli_rs" + } + return "opencode" +} diff --git 
a/backend/internal/service/openai_ws_connection.go b/backend/internal/service/openai_ws_connection.go new file mode 100644 index 00000000000..e6a1103b71e --- /dev/null +++ b/backend/internal/service/openai_ws_connection.go @@ -0,0 +1,159 @@ +package service + +import ( + "errors" + "fmt" + "net/url" + "strings" + "time" +) + +func (s *OpenAIGatewayService) getOpenAIWSConnPool() *openAIWSConnPool { + if s == nil { + return nil + } + s.openaiWSPoolOnce.Do(func() { + if s.openaiWSPool == nil { + s.openaiWSPool = newOpenAIWSConnPool(s.cfg) + } + }) + return s.openaiWSPool +} + +func (s *OpenAIGatewayService) getOpenAIWSPassthroughDialer() openAIWSClientDialer { + if s == nil { + return nil + } + s.openaiWSPassthroughDialerOnce.Do(func() { + if s.openaiWSPassthroughDialer == nil { + s.openaiWSPassthroughDialer = newDefaultOpenAIWSClientDialer() + } + }) + return s.openaiWSPassthroughDialer +} + +func (s *OpenAIGatewayService) SnapshotOpenAIWSPoolMetrics() OpenAIWSPoolMetricsSnapshot { + pool := s.getOpenAIWSConnPool() + if pool == nil { + return OpenAIWSPoolMetricsSnapshot{} + } + return pool.SnapshotMetrics() +} + +func (s *OpenAIGatewayService) SnapshotOpenAIWSPerformanceMetrics() OpenAIWSPerformanceMetricsSnapshot { + pool := s.getOpenAIWSConnPool() + snapshot := OpenAIWSPerformanceMetricsSnapshot{ + Retry: s.SnapshotOpenAIWSRetryMetrics(), + } + if pool == nil { + return snapshot + } + snapshot.Pool = pool.SnapshotMetrics() + snapshot.Transport = pool.SnapshotTransportMetrics() + return snapshot +} + +func (s *OpenAIGatewayService) getOpenAIWSStateStore() OpenAIWSStateStore { + if s == nil { + return nil + } + s.openaiWSStateStoreOnce.Do(func() { + if s.openaiWSStateStore == nil { + s.openaiWSStateStore = NewOpenAIWSStateStore(s.cache) + } + }) + return s.openaiWSStateStore +} + +func (s *OpenAIGatewayService) buildOpenAIResponsesWSURL(account *Account) (string, error) { + if account == nil { + return "", errors.New("account is nil") + } + var targetURL string + 
switch account.Type { + case AccountTypeOAuth: + targetURL = chatgptCodexURL + case AccountTypeAPIKey: + baseURL := account.GetOpenAIBaseURL() + if baseURL == "" { + targetURL = openaiPlatformAPIURL + } else { + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return "", err + } + targetURL = buildOpenAIResponsesURL(validatedURL) + } + default: + targetURL = openaiPlatformAPIURL + } + + parsed, err := url.Parse(strings.TrimSpace(targetURL)) + if err != nil { + return "", fmt.Errorf("invalid target url: %w", err) + } + switch strings.ToLower(parsed.Scheme) { + case "https": + parsed.Scheme = "wss" + case "http": + parsed.Scheme = "ws" + case "wss", "ws": + // 保持不变 + default: + return "", fmt.Errorf("unsupported scheme for ws: %s", parsed.Scheme) + } + return parsed.String(), nil +} + +func (s *OpenAIGatewayService) openAIWSFallbackCooldown() time.Duration { + if s == nil || s.cfg == nil { + return 30 * time.Second + } + seconds := s.cfg.Gateway.OpenAIWS.FallbackCooldownSeconds + if seconds <= 0 { + return 0 + } + return time.Duration(seconds) * time.Second +} + +func (s *OpenAIGatewayService) isOpenAIWSFallbackCooling(accountID int64) bool { + if s == nil || accountID <= 0 { + return false + } + cooldown := s.openAIWSFallbackCooldown() + if cooldown <= 0 { + return false + } + rawUntil, ok := s.openaiWSFallbackUntil.Load(accountID) + if !ok || rawUntil == nil { + return false + } + until, ok := rawUntil.(time.Time) + if !ok || until.IsZero() { + s.openaiWSFallbackUntil.Delete(accountID) + return false + } + if time.Now().Before(until) { + return true + } + s.openaiWSFallbackUntil.Delete(accountID) + return false +} + +func (s *OpenAIGatewayService) markOpenAIWSFallbackCooling(accountID int64, _ string) { + if s == nil || accountID <= 0 { + return + } + cooldown := s.openAIWSFallbackCooldown() + if cooldown <= 0 { + return + } + s.openaiWSFallbackUntil.Store(accountID, time.Now().Add(cooldown)) +} + +func (s *OpenAIGatewayService) 
clearOpenAIWSFallbackCooling(accountID int64) { + if s == nil || accountID <= 0 { + return + } + s.openaiWSFallbackUntil.Delete(accountID) +} diff --git a/backend/internal/service/openai_ws_errors.go b/backend/internal/service/openai_ws_errors.go new file mode 100644 index 00000000000..131989e816f --- /dev/null +++ b/backend/internal/service/openai_ws_errors.go @@ -0,0 +1,353 @@ +package service + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "strings" + + coderws "github.com/coder/websocket" +) + +// openAIWSFallbackError 表示可安全回退到 HTTP 的 WS 错误(尚未写下游)。 +type openAIWSFallbackError struct { + Reason string + Err error +} + +func wrapOpenAIWSFallback(reason string, err error) error { + return &openAIWSFallbackError{Reason: strings.TrimSpace(reason), Err: err} +} + +// OpenAIWSClientCloseError 表示应以指定 WebSocket close code 主动关闭客户端连接的错误。 +type OpenAIWSClientCloseError struct { + statusCode coderws.StatusCode + reason string + err error +} + +type openAIWSIngressTurnError struct { + stage string + cause error + wroteDownstream bool +} + +func wrapOpenAIWSIngressTurnError(stage string, cause error, wroteDownstream bool) error { + if cause == nil { + return nil + } + return &openAIWSIngressTurnError{ + stage: strings.TrimSpace(stage), + cause: cause, + wroteDownstream: wroteDownstream, + } +} + +// NewOpenAIWSClientCloseError 创建一个客户端 WS 关闭错误。 +func NewOpenAIWSClientCloseError(statusCode coderws.StatusCode, reason string, err error) error { + return &OpenAIWSClientCloseError{ + statusCode: statusCode, + reason: strings.TrimSpace(reason), + err: err, + } +} + +func summarizeOpenAIWSReadCloseError(err error) (status string, reason string) { + if err == nil { + return "-", "-" + } + statusCode := coderws.CloseStatus(err) + if statusCode == -1 { + return "-", "-" + } + closeStatus := fmt.Sprintf("%d(%s)", int(statusCode), statusCode.String()) + closeReason := "-" + var closeErr coderws.CloseError + if errors.As(err, &closeErr) { + reasonText := 
strings.TrimSpace(closeErr.Reason) + if reasonText != "" { + closeReason = normalizeOpenAIWSLogValue(reasonText) + } + } + return normalizeOpenAIWSLogValue(closeStatus), closeReason +} + +func unwrapOpenAIWSDialBaseError(err error) error { + if err == nil { + return nil + } + var dialErr *openAIWSDialError + if errors.As(err, &dialErr) && dialErr != nil && dialErr.Err != nil { + return dialErr.Err + } + return err +} + +func openAIWSDialRespHeaderForLog(err error, key string) string { + var dialErr *openAIWSDialError + if !errors.As(err, &dialErr) || dialErr == nil || dialErr.ResponseHeaders == nil { + return "-" + } + return truncateOpenAIWSLogValue(dialErr.ResponseHeaders.Get(key), openAIWSHeaderValueMaxLen) +} + +func classifyOpenAIWSDialError(err error) string { + if err == nil { + return "-" + } + baseErr := unwrapOpenAIWSDialBaseError(err) + if baseErr == nil { + return "-" + } + if errors.Is(baseErr, context.DeadlineExceeded) { + return "ctx_deadline_exceeded" + } + if errors.Is(baseErr, context.Canceled) { + return "ctx_canceled" + } + var netErr net.Error + if errors.As(baseErr, &netErr) && netErr.Timeout() { + return "net_timeout" + } + if status := coderws.CloseStatus(baseErr); status != -1 { + return normalizeOpenAIWSLogValue(fmt.Sprintf("ws_close_%d", int(status))) + } + message := strings.ToLower(strings.TrimSpace(baseErr.Error())) + switch { + case strings.Contains(message, "handshake not finished"): + return "handshake_not_finished" + case strings.Contains(message, "bad handshake"): + return "bad_handshake" + case strings.Contains(message, "connection refused"): + return "connection_refused" + case strings.Contains(message, "no such host"): + return "dns_not_found" + case strings.Contains(message, "tls"): + return "tls_error" + case strings.Contains(message, "i/o timeout"): + return "io_timeout" + case strings.Contains(message, "context deadline exceeded"): + return "ctx_deadline_exceeded" + default: + return "dial_error" + } +} + +func 
summarizeOpenAIWSDialError(err error) ( + statusCode int, + dialClass string, + closeStatus string, + closeReason string, + respServer string, + respVia string, + respCFRay string, + respRequestID string, +) { + dialClass = "-" + closeStatus = "-" + closeReason = "-" + respServer = "-" + respVia = "-" + respCFRay = "-" + respRequestID = "-" + if err == nil { + return + } + var dialErr *openAIWSDialError + if errors.As(err, &dialErr) && dialErr != nil { + statusCode = dialErr.StatusCode + respServer = openAIWSDialRespHeaderForLog(err, "server") + respVia = openAIWSDialRespHeaderForLog(err, "via") + respCFRay = openAIWSDialRespHeaderForLog(err, "cf-ray") + respRequestID = openAIWSDialRespHeaderForLog(err, "x-request-id") + } + dialClass = normalizeOpenAIWSLogValue(classifyOpenAIWSDialError(err)) + closeStatus, closeReason = summarizeOpenAIWSReadCloseError(unwrapOpenAIWSDialBaseError(err)) + return +} + +func isOpenAIWSClientDisconnectError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) { + return true + } + switch coderws.CloseStatus(err) { + case coderws.StatusNormalClosure, coderws.StatusGoingAway, coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure: + return true + } + message := strings.ToLower(strings.TrimSpace(err.Error())) + if message == "" { + return false + } + return strings.Contains(message, "failed to read frame header: eof") || + strings.Contains(message, "unexpected eof") || + strings.Contains(message, "use of closed network connection") || + strings.Contains(message, "connection reset by peer") || + strings.Contains(message, "broken pipe") || + strings.Contains(message, "an established connection was aborted") +} + +func classifyOpenAIWSReadFallbackReason(err error) string { + if err == nil { + return "read_event" + } + switch coderws.CloseStatus(err) { + case coderws.StatusPolicyViolation: + return "policy_violation" + case 
coderws.StatusMessageTooBig: + return "message_too_big" + default: + return "read_event" + } +} + +func classifyOpenAIWSAcquireError(err error) string { + if err == nil { + return "acquire_conn" + } + var dialErr *openAIWSDialError + if errors.As(err, &dialErr) { + switch dialErr.StatusCode { + case 426: + return "upgrade_required" + case 401, 403: + return "auth_failed" + case 429: + return "upstream_rate_limited" + } + if dialErr.StatusCode >= 500 { + return "upstream_5xx" + } + return "dial_failed" + } + if errors.Is(err, errOpenAIWSConnQueueFull) { + return "conn_queue_full" + } + if errors.Is(err, errOpenAIWSPreferredConnUnavailable) { + return "preferred_conn_unavailable" + } + if errors.Is(err, context.DeadlineExceeded) { + return "acquire_timeout" + } + return "acquire_conn" +} + +func isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw string) bool { + code := strings.ToLower(strings.TrimSpace(codeRaw)) + errType := strings.ToLower(strings.TrimSpace(errTypeRaw)) + msg := strings.ToLower(strings.TrimSpace(msgRaw)) + + if strings.Contains(errType, "rate_limit") || strings.Contains(errType, "usage_limit") { + return true + } + if strings.Contains(code, "rate_limit") || strings.Contains(code, "usage_limit") || strings.Contains(code, "insufficient_quota") { + return true + } + if strings.Contains(msg, "usage limit") && strings.Contains(msg, "reached") { + return true + } + if strings.Contains(msg, "rate limit") && (strings.Contains(msg, "reached") || strings.Contains(msg, "exceeded")) { + return true + } + return false +} + +func (s *OpenAIGatewayService) persistOpenAIWSRateLimitSignal(ctx context.Context, account *Account, headers http.Header, responseBody []byte, codeRaw, errTypeRaw, msgRaw string) { + if s == nil || s.rateLimitService == nil || account == nil || account.Platform != PlatformOpenAI { + return + } + if !isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) { + return + } + s.rateLimitService.HandleUpstreamError(ctx, account, 
http.StatusTooManyRequests, headers, responseBody) +} + +func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) { + code := strings.ToLower(strings.TrimSpace(codeRaw)) + errType := strings.ToLower(strings.TrimSpace(errTypeRaw)) + msg := strings.ToLower(strings.TrimSpace(msgRaw)) + + switch code { + case "upgrade_required": + return "upgrade_required", true + case "websocket_not_supported", "websocket_unsupported": + return "ws_unsupported", true + case "websocket_connection_limit_reached": + return "ws_connection_limit_reached", true + case "invalid_encrypted_content": + return "invalid_encrypted_content", true + case "previous_response_not_found": + return "previous_response_not_found", true + } + if isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) { + return "upstream_rate_limited", false + } + if strings.Contains(msg, "upgrade required") || strings.Contains(msg, "status 426") { + return "upgrade_required", true + } + if strings.Contains(errType, "upgrade") { + return "upgrade_required", true + } + if strings.Contains(msg, "websocket") && strings.Contains(msg, "unsupported") { + return "ws_unsupported", true + } + if strings.Contains(msg, "connection limit") && strings.Contains(msg, "websocket") { + return "ws_connection_limit_reached", true + } + if strings.Contains(msg, "invalid_encrypted_content") || + (strings.Contains(msg, "encrypted content") && strings.Contains(msg, "could not be verified")) { + return "invalid_encrypted_content", true + } + if strings.Contains(msg, "previous_response_not_found") || + (strings.Contains(msg, "previous response") && strings.Contains(msg, "not found")) { + return "previous_response_not_found", true + } + if strings.Contains(errType, "server_error") || strings.Contains(code, "server_error") { + return "upstream_error_event", true + } + return "event_error", false +} + +func classifyOpenAIWSErrorEvent(message []byte) (string, bool) { + if len(message) == 0 { + return "event_error", false 
+ } + return classifyOpenAIWSErrorEventFromRaw(parseOpenAIWSErrorEventFields(message)) +} + +func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int { + code := strings.ToLower(strings.TrimSpace(codeRaw)) + errType := strings.ToLower(strings.TrimSpace(errTypeRaw)) + switch { + case strings.Contains(errType, "invalid_request"), + strings.Contains(code, "invalid_request"), + strings.Contains(code, "bad_request"), + code == "invalid_encrypted_content", + code == "previous_response_not_found": + return http.StatusBadRequest + case strings.Contains(errType, "authentication"), + strings.Contains(code, "invalid_api_key"), + strings.Contains(code, "unauthorized"): + return http.StatusUnauthorized + case strings.Contains(errType, "permission"), + strings.Contains(code, "forbidden"): + return http.StatusForbidden + case isOpenAIWSRateLimitError(codeRaw, errTypeRaw, ""): + return http.StatusTooManyRequests + default: + return http.StatusBadGateway + } +} + +func openAIWSErrorHTTPStatus(message []byte) int { + if len(message) == 0 { + return http.StatusBadGateway + } + codeRaw, errTypeRaw, _ := parseOpenAIWSErrorEventFields(message) + return openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw) +} diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 77cf7d95a3f..8eb74c78a6a 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -6,9 +6,7 @@ import ( "encoding/json" "errors" "fmt" - "io" "math/rand" - "net" "net/http" "net/url" "sort" @@ -68,12 +66,6 @@ var openAIWSLogValueReplacer = strings.NewReplacer( var openAIWSIngressPreflightPingIdle = 20 * time.Second -// openAIWSFallbackError 表示可安全回退到 HTTP 的 WS 错误(尚未写下游)。 -type openAIWSFallbackError struct { - Reason string - Err error -} - func (e *openAIWSFallbackError) Error() string { if e == nil { return "" @@ -91,23 +83,6 @@ func (e *openAIWSFallbackError) Unwrap() error { return e.Err } -func 
wrapOpenAIWSFallback(reason string, err error) error { - return &openAIWSFallbackError{Reason: strings.TrimSpace(reason), Err: err} -} - -// OpenAIWSClientCloseError 表示应以指定 WebSocket close code 主动关闭客户端连接的错误。 -type OpenAIWSClientCloseError struct { - statusCode coderws.StatusCode - reason string - err error -} - -type openAIWSIngressTurnError struct { - stage string - cause error - wroteDownstream bool -} - func (e *openAIWSIngressTurnError) Error() string { if e == nil { return "" @@ -125,67 +100,6 @@ func (e *openAIWSIngressTurnError) Unwrap() error { return e.cause } -func wrapOpenAIWSIngressTurnError(stage string, cause error, wroteDownstream bool) error { - if cause == nil { - return nil - } - return &openAIWSIngressTurnError{ - stage: strings.TrimSpace(stage), - cause: cause, - wroteDownstream: wroteDownstream, - } -} - -func isOpenAIWSIngressTurnRetryable(err error) bool { - var turnErr *openAIWSIngressTurnError - if !errors.As(err, &turnErr) || turnErr == nil { - return false - } - if errors.Is(turnErr.cause, context.Canceled) || errors.Is(turnErr.cause, context.DeadlineExceeded) { - return false - } - if turnErr.wroteDownstream { - return false - } - switch turnErr.stage { - case "write_upstream", "read_upstream": - return true - default: - return false - } -} - -func openAIWSIngressTurnRetryReason(err error) string { - var turnErr *openAIWSIngressTurnError - if !errors.As(err, &turnErr) || turnErr == nil { - return "unknown" - } - if turnErr.stage == "" { - return "unknown" - } - return turnErr.stage -} - -func isOpenAIWSIngressPreviousResponseNotFound(err error) bool { - var turnErr *openAIWSIngressTurnError - if !errors.As(err, &turnErr) || turnErr == nil { - return false - } - if strings.TrimSpace(turnErr.stage) != openAIWSIngressStagePreviousResponseNotFound { - return false - } - return !turnErr.wroteDownstream -} - -// NewOpenAIWSClientCloseError 创建一个客户端 WS 关闭错误。 -func NewOpenAIWSClientCloseError(statusCode coderws.StatusCode, reason string, err 
error) error { - return &OpenAIWSClientCloseError{ - statusCode: statusCode, - reason: strings.TrimSpace(reason), - err: err, - } -} - func (e *OpenAIWSClientCloseError) Error() string { if e == nil { return "" @@ -235,17 +149,6 @@ func normalizeOpenAIWSLogValue(value string) string { return openAIWSLogValueReplacer.Replace(trimmed) } -func truncateOpenAIWSLogValue(value string, maxLen int) string { - normalized := normalizeOpenAIWSLogValue(value) - if normalized == "-" || maxLen <= 0 { - return normalized - } - if len(normalized) <= maxLen { - return normalized - } - return normalized[:maxLen] + "..." -} - func openAIWSHeaderValueForLog(headers http.Header, key string) string { if headers == nil { return "-" @@ -260,1444 +163,179 @@ func hasOpenAIWSHeader(headers http.Header, key string) bool { return strings.TrimSpace(headers.Get(key)) != "" } -type openAIWSSessionHeaderResolution struct { - SessionID string - ConversationID string - SessionSource string - ConversationSource string -} - -func resolveOpenAIWSSessionHeaders(c *gin.Context, promptCacheKey string) openAIWSSessionHeaderResolution { - resolution := openAIWSSessionHeaderResolution{ - SessionSource: "none", - ConversationSource: "none", - } - if c != nil && c.Request != nil { - if sessionID := strings.TrimSpace(c.Request.Header.Get("session_id")); sessionID != "" { - resolution.SessionID = sessionID - resolution.SessionSource = "header_session_id" - } - if conversationID := strings.TrimSpace(c.Request.Header.Get("conversation_id")); conversationID != "" { - resolution.ConversationID = conversationID - resolution.ConversationSource = "header_conversation_id" - if resolution.SessionID == "" { - resolution.SessionID = conversationID - resolution.SessionSource = "header_conversation_id" - } - } - } - - cacheKey := strings.TrimSpace(promptCacheKey) - if cacheKey != "" { - if resolution.SessionID == "" { - resolution.SessionID = cacheKey - resolution.SessionSource = "prompt_cache_key" - } - } - return 
resolution -} - -func shouldLogOpenAIWSEvent(idx int, eventType string) bool { - if idx <= openAIWSEventLogHeadLimit { - return true - } - if openAIWSEventLogEveryN > 0 && idx%openAIWSEventLogEveryN == 0 { - return true +func sortedKeys(m map[string]any) []string { + if len(m) == 0 { + return nil } - if eventType == "error" || isOpenAIWSTerminalEvent(eventType) { - return true + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) } - return false + sort.Strings(keys) + return keys } -func shouldLogOpenAIWSBufferedEvent(idx int) bool { - if idx <= openAIWSBufferLogHeadLimit { - return true - } - if openAIWSBufferLogEveryN > 0 && idx%openAIWSBufferLogEveryN == 0 { - return true - } - return false +type OpenAIWSPerformanceMetricsSnapshot struct { + Pool OpenAIWSPoolMetricsSnapshot `json:"pool"` + Retry OpenAIWSRetryMetricsSnapshot `json:"retry"` + Transport OpenAIWSTransportMetricsSnapshot `json:"transport"` } -func openAIWSEventMayContainModel(eventType string) bool { - switch eventType { - case "response.created", - "response.in_progress", - "response.completed", - "response.done", - "response.failed", - "response.incomplete", - "response.cancelled", - "response.canceled": - return true - default: - trimmed := strings.TrimSpace(eventType) - if trimmed == eventType { - return false - } - switch trimmed { - case "response.created", - "response.in_progress", - "response.completed", - "response.done", - "response.failed", - "response.incomplete", - "response.cancelled", - "response.canceled": - return true - default: - return false +func (s *OpenAIGatewayService) openAIWSResponseStickyTTL() time.Duration { + if s != nil && s.cfg != nil { + seconds := s.cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds + if seconds > 0 { + return time.Duration(seconds) * time.Second } } + return time.Hour } -func openAIWSEventMayContainToolCalls(eventType string) bool { - eventType = strings.TrimSpace(eventType) - if eventType == "" { - return false - } - if 
strings.Contains(eventType, "function_call") || strings.Contains(eventType, "tool_call") { - return true - } - switch eventType { - case "response.output_item.added", "response.output_item.done", "response.completed", "response.done": - return true - default: - return false - } -} - -func openAIWSEventShouldParseUsage(eventType string) bool { - return eventType == "response.completed" || strings.TrimSpace(eventType) == "response.completed" -} - -func parseOpenAIWSEventEnvelope(message []byte) (eventType string, responseID string, response gjson.Result) { - if len(message) == 0 { - return "", "", gjson.Result{} - } - values := gjson.GetManyBytes(message, "type", "response.id", "id", "response") - eventType = strings.TrimSpace(values[0].String()) - if id := strings.TrimSpace(values[1].String()); id != "" { - responseID = id - } else { - responseID = strings.TrimSpace(values[2].String()) - } - return eventType, responseID, values[3] -} - -func openAIWSMessageLikelyContainsToolCalls(message []byte) bool { - if len(message) == 0 { - return false - } - return bytes.Contains(message, []byte(`"tool_calls"`)) || - bytes.Contains(message, []byte(`"tool_call"`)) || - bytes.Contains(message, []byte(`"function_call"`)) -} - -func parseOpenAIWSResponseUsageFromCompletedEvent(message []byte, usage *OpenAIUsage) { - if usage == nil || len(message) == 0 { - return - } - values := gjson.GetManyBytes( - message, - "response.usage.input_tokens", - "response.usage.output_tokens", - "response.usage.input_tokens_details.cached_tokens", - ) - usage.InputTokens = int(values[0].Int()) - usage.OutputTokens = int(values[1].Int()) - usage.CacheReadInputTokens = int(values[2].Int()) -} - -func parseOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) { - if len(message) == 0 { - return "", "", "" - } - values := gjson.GetManyBytes(message, "error.code", "error.type", "error.message") - return strings.TrimSpace(values[0].String()), 
strings.TrimSpace(values[1].String()), strings.TrimSpace(values[2].String()) -} - -func summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMessageRaw string) (code string, errType string, errMessage string) { - code = truncateOpenAIWSLogValue(codeRaw, openAIWSLogValueMaxLen) - errType = truncateOpenAIWSLogValue(errTypeRaw, openAIWSLogValueMaxLen) - errMessage = truncateOpenAIWSLogValue(errMessageRaw, openAIWSLogValueMaxLen) - return code, errType, errMessage -} - -func summarizeOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) { - if len(message) == 0 { - return "-", "-", "-" - } - return summarizeOpenAIWSErrorEventFieldsFromRaw(parseOpenAIWSErrorEventFields(message)) -} - -func summarizeOpenAIWSPayloadKeySizes(payload map[string]any, topN int) string { - if len(payload) == 0 { - return "-" - } - type keySize struct { - Key string - Size int - } - sizes := make([]keySize, 0, len(payload)) - for key, value := range payload { - size := estimateOpenAIWSPayloadValueSize(value, openAIWSPayloadSizeEstimateDepth) - sizes = append(sizes, keySize{Key: key, Size: size}) - } - sort.Slice(sizes, func(i, j int) bool { - if sizes[i].Size == sizes[j].Size { - return sizes[i].Key < sizes[j].Key - } - return sizes[i].Size > sizes[j].Size - }) - - if topN <= 0 || topN > len(sizes) { - topN = len(sizes) - } - parts := make([]string, 0, topN) - for idx := 0; idx < topN; idx++ { - item := sizes[idx] - parts = append(parts, fmt.Sprintf("%s:%d", item.Key, item.Size)) +func (s *OpenAIGatewayService) openAIWSIngressPreviousResponseRecoveryEnabled() bool { + if s != nil && s.cfg != nil { + return s.cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled } - return strings.Join(parts, ",") + return true } -func estimateOpenAIWSPayloadValueSize(value any, depth int) int { - if depth <= 0 { - return -1 - } - switch v := value.(type) { - case nil: - return 0 - case string: - return len(v) - case []byte: - return len(v) - case bool: - 
return 1 - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - return 8 - case float32, float64: - return 8 - case map[string]any: - if len(v) == 0 { - return 2 - } - total := 2 - count := 0 - for key, item := range v { - count++ - if count > openAIWSPayloadSizeEstimateMaxItems { - return -1 - } - itemSize := estimateOpenAIWSPayloadValueSize(item, depth-1) - if itemSize < 0 { - return -1 - } - total += len(key) + itemSize + 3 - if total > openAIWSPayloadSizeEstimateMaxBytes { - return -1 - } - } - return total - case []any: - if len(v) == 0 { - return 2 - } - total := 2 - limit := len(v) - if limit > openAIWSPayloadSizeEstimateMaxItems { - return -1 - } - for i := 0; i < limit; i++ { - itemSize := estimateOpenAIWSPayloadValueSize(v[i], depth-1) - if itemSize < 0 { - return -1 - } - total += itemSize + 1 - if total > openAIWSPayloadSizeEstimateMaxBytes { - return -1 - } - } - return total - default: - raw, err := json.Marshal(v) - if err != nil { - return -1 - } - if len(raw) > openAIWSPayloadSizeEstimateMaxBytes { - return -1 - } - return len(raw) +func (s *OpenAIGatewayService) openAIWSReadTimeout() time.Duration { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ReadTimeoutSeconds > 0 { + return time.Duration(s.cfg.Gateway.OpenAIWS.ReadTimeoutSeconds) * time.Second } + return 15 * time.Minute } -func openAIWSPayloadString(payload map[string]any, key string) string { - if len(payload) == 0 { - return "" - } - raw, ok := payload[key] - if !ok { - return "" - } - switch v := raw.(type) { - case nil: - return "" - case string: - return strings.TrimSpace(v) - case []byte: - return strings.TrimSpace(string(v)) - default: - return "" +func (s *OpenAIGatewayService) openAIWSPassthroughIdleTimeout() time.Duration { + if timeout := s.openAIWSReadTimeout(); timeout > 0 { + return timeout } + return openAIWSPassthroughIdleTimeoutDefault } -func openAIWSPayloadStringFromRaw(payload []byte, key string) string { - if len(payload) == 0 || 
strings.TrimSpace(key) == "" { - return "" +func (s *OpenAIGatewayService) openAIWSWriteTimeout() time.Duration { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds > 0 { + return time.Duration(s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds) * time.Second } - return strings.TrimSpace(gjson.GetBytes(payload, key).String()) + return 2 * time.Minute } -func openAIWSPayloadBoolFromRaw(payload []byte, key string, defaultValue bool) bool { - if len(payload) == 0 || strings.TrimSpace(key) == "" { - return defaultValue - } - value := gjson.GetBytes(payload, key) - if !value.Exists() { - return defaultValue - } - if value.Type != gjson.True && value.Type != gjson.False { - return defaultValue +func (s *OpenAIGatewayService) openAIWSEventFlushBatchSize() int { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.EventFlushBatchSize > 0 { + return s.cfg.Gateway.OpenAIWS.EventFlushBatchSize } - return value.Bool() -} - -func openAIWSSessionHashesFromID(sessionID string) (string, string) { - return deriveOpenAISessionHashes(sessionID) + return openAIWSEventFlushBatchSizeDefault } -func extractOpenAIWSImageURL(value any) string { - switch v := value.(type) { - case string: - return strings.TrimSpace(v) - case map[string]any: - if raw, ok := v["url"].(string); ok { - return strings.TrimSpace(raw) +func (s *OpenAIGatewayService) openAIWSEventFlushInterval() time.Duration { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS >= 0 { + if s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS == 0 { + return 0 } + return time.Duration(s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS) * time.Millisecond } - return "" + return openAIWSEventFlushIntervalDefault } -func summarizeOpenAIWSInput(input any) string { - items, ok := input.([]any) - if !ok || len(items) == 0 { - return "-" - } - - itemCount := len(items) - textChars := 0 - imageDataURLs := 0 - imageDataURLChars := 0 - imageRemoteURLs := 0 - - handleContentItem := func(contentItem map[string]any) 
{ - contentType, _ := contentItem["type"].(string) - switch strings.TrimSpace(contentType) { - case "input_text", "output_text", "text": - if text, ok := contentItem["text"].(string); ok { - textChars += len(text) - } - case "input_image": - imageURL := extractOpenAIWSImageURL(contentItem["image_url"]) - if imageURL == "" { - return - } - if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") { - imageDataURLs++ - imageDataURLChars += len(imageURL) - return - } - imageRemoteURLs++ - } - } - - handleInputItem := func(inputItem map[string]any) { - if content, ok := inputItem["content"].([]any); ok { - for _, rawContent := range content { - contentItem, ok := rawContent.(map[string]any) - if !ok { - continue - } - handleContentItem(contentItem) - } - return - } - - itemType, _ := inputItem["type"].(string) - switch strings.TrimSpace(itemType) { - case "input_text", "output_text", "text": - if text, ok := inputItem["text"].(string); ok { - textChars += len(text) - } - case "input_image": - imageURL := extractOpenAIWSImageURL(inputItem["image_url"]) - if imageURL == "" { - return - } - if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") { - imageDataURLs++ - imageDataURLChars += len(imageURL) - return - } - imageRemoteURLs++ +func (s *OpenAIGatewayService) openAIWSPayloadLogSampleRate() float64 { + if s != nil && s.cfg != nil { + rate := s.cfg.Gateway.OpenAIWS.PayloadLogSampleRate + if rate < 0 { + return 0 } - } - - for _, rawItem := range items { - inputItem, ok := rawItem.(map[string]any) - if !ok { - continue + if rate > 1 { + return 1 } - handleInputItem(inputItem) - } - - return fmt.Sprintf( - "items=%d,text_chars=%d,image_data_urls=%d,image_data_url_chars=%d,image_remote_urls=%d", - itemCount, - textChars, - imageDataURLs, - imageDataURLChars, - imageRemoteURLs, - ) -} - -func dropOpenAIWSPayloadKey(payload map[string]any, key string, removed *[]string) { - if len(payload) == 0 || strings.TrimSpace(key) == "" { - return - } - if _, exists := 
payload[key]; !exists { - return + return rate } - delete(payload, key) - *removed = append(*removed, key) + return openAIWSPayloadLogSampleDefault } -// applyOpenAIWSRetryPayloadStrategy 在 WS 连续失败时仅移除无语义字段, -// 避免重试成功却改变原始请求语义。 -// 注意:prompt_cache_key 不应在重试中移除;它常用于会话稳定标识(session_id 兜底)。 -func applyOpenAIWSRetryPayloadStrategy(payload map[string]any, attempt int) (strategy string, removedKeys []string) { - if len(payload) == 0 { - return "empty", nil - } +func (s *OpenAIGatewayService) shouldLogOpenAIWSPayloadSchema(attempt int) bool { + // 首次尝试保留一条完整 payload_schema 便于排障。 if attempt <= 1 { - return "full", nil + return true } - - removed := make([]string, 0, 2) - if attempt >= 2 { - dropOpenAIWSPayloadKey(payload, "include", &removed) + rate := s.openAIWSPayloadLogSampleRate() + if rate <= 0 { + return false } - - if len(removed) == 0 { - return "full", nil + if rate >= 1 { + return true } - sort.Strings(removed) - return "trim_optional_fields", removed -} - -func logOpenAIWSModeInfo(format string, args ...any) { - logger.LegacyPrintf("service.openai_gateway", "[OpenAI WS Mode][openai_ws_mode=true] "+format, args...) + return rand.Float64() < rate } -func isOpenAIWSModeDebugEnabled() bool { +func (s *OpenAIGatewayService) shouldEmitOpenAIWSPayloadSchema(attempt int) bool { + if !s.shouldLogOpenAIWSPayloadSchema(attempt) { + return false + } return logger.L().Core().Enabled(zap.DebugLevel) } -func logOpenAIWSModeDebug(format string, args ...any) { - if !isOpenAIWSModeDebugEnabled() { - return +func (s *OpenAIGatewayService) openAIWSDialTimeout() time.Duration { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.DialTimeoutSeconds > 0 { + return time.Duration(s.cfg.Gateway.OpenAIWS.DialTimeoutSeconds) * time.Second } - logger.LegacyPrintf("service.openai_gateway", "[debug] [OpenAI WS Mode][openai_ws_mode=true] "+format, args...) 
+ return 10 * time.Second } -func logOpenAIWSBindResponseAccountWarn(groupID, accountID int64, responseID string, err error) { - if err == nil { - return +func (s *OpenAIGatewayService) openAIWSAcquireTimeout() time.Duration { + // Acquire 覆盖“连接复用命中/排队/新建连接”三个阶段。 + // 这里不再叠加 write_timeout,避免高并发排队时把 TTFT 长尾拉到分钟级。 + dial := s.openAIWSDialTimeout() + if dial <= 0 { + dial = 10 * time.Second } - logger.L().Warn( - "openai.ws_bind_response_account_failed", - zap.Int64("group_id", groupID), - zap.Int64("account_id", accountID), - zap.String("response_id", truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen)), - zap.Error(err), - ) -} - -func summarizeOpenAIWSReadCloseError(err error) (status string, reason string) { - if err == nil { - return "-", "-" - } - statusCode := coderws.CloseStatus(err) - if statusCode == -1 { - return "-", "-" - } - closeStatus := fmt.Sprintf("%d(%s)", int(statusCode), statusCode.String()) - closeReason := "-" - var closeErr coderws.CloseError - if errors.As(err, &closeErr) { - reasonText := strings.TrimSpace(closeErr.Reason) - if reasonText != "" { - closeReason = normalizeOpenAIWSLogValue(reasonText) - } - } - return normalizeOpenAIWSLogValue(closeStatus), closeReason -} - -func unwrapOpenAIWSDialBaseError(err error) error { - if err == nil { - return nil - } - var dialErr *openAIWSDialError - if errors.As(err, &dialErr) && dialErr != nil && dialErr.Err != nil { - return dialErr.Err - } - return err -} - -func openAIWSDialRespHeaderForLog(err error, key string) string { - var dialErr *openAIWSDialError - if !errors.As(err, &dialErr) || dialErr == nil || dialErr.ResponseHeaders == nil { - return "-" - } - return truncateOpenAIWSLogValue(dialErr.ResponseHeaders.Get(key), openAIWSHeaderValueMaxLen) -} - -func classifyOpenAIWSDialError(err error) string { - if err == nil { - return "-" - } - baseErr := unwrapOpenAIWSDialBaseError(err) - if baseErr == nil { - return "-" - } - if errors.Is(baseErr, context.DeadlineExceeded) { - return 
"ctx_deadline_exceeded" - } - if errors.Is(baseErr, context.Canceled) { - return "ctx_canceled" - } - var netErr net.Error - if errors.As(baseErr, &netErr) && netErr.Timeout() { - return "net_timeout" - } - if status := coderws.CloseStatus(baseErr); status != -1 { - return normalizeOpenAIWSLogValue(fmt.Sprintf("ws_close_%d", int(status))) - } - message := strings.ToLower(strings.TrimSpace(baseErr.Error())) - switch { - case strings.Contains(message, "handshake not finished"): - return "handshake_not_finished" - case strings.Contains(message, "bad handshake"): - return "bad_handshake" - case strings.Contains(message, "connection refused"): - return "connection_refused" - case strings.Contains(message, "no such host"): - return "dns_not_found" - case strings.Contains(message, "tls"): - return "tls_error" - case strings.Contains(message, "i/o timeout"): - return "io_timeout" - case strings.Contains(message, "context deadline exceeded"): - return "ctx_deadline_exceeded" - default: - return "dial_error" - } -} - -func summarizeOpenAIWSDialError(err error) ( - statusCode int, - dialClass string, - closeStatus string, - closeReason string, - respServer string, - respVia string, - respCFRay string, - respRequestID string, -) { - dialClass = "-" - closeStatus = "-" - closeReason = "-" - respServer = "-" - respVia = "-" - respCFRay = "-" - respRequestID = "-" - if err == nil { - return - } - var dialErr *openAIWSDialError - if errors.As(err, &dialErr) && dialErr != nil { - statusCode = dialErr.StatusCode - respServer = openAIWSDialRespHeaderForLog(err, "server") - respVia = openAIWSDialRespHeaderForLog(err, "via") - respCFRay = openAIWSDialRespHeaderForLog(err, "cf-ray") - respRequestID = openAIWSDialRespHeaderForLog(err, "x-request-id") - } - dialClass = normalizeOpenAIWSLogValue(classifyOpenAIWSDialError(err)) - closeStatus, closeReason = summarizeOpenAIWSReadCloseError(unwrapOpenAIWSDialBaseError(err)) - return -} - -func isOpenAIWSClientDisconnectError(err error) bool { 
- if err == nil { - return false - } - if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) { - return true - } - switch coderws.CloseStatus(err) { - case coderws.StatusNormalClosure, coderws.StatusGoingAway, coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure: - return true - } - message := strings.ToLower(strings.TrimSpace(err.Error())) - if message == "" { - return false - } - return strings.Contains(message, "failed to read frame header: eof") || - strings.Contains(message, "unexpected eof") || - strings.Contains(message, "use of closed network connection") || - strings.Contains(message, "connection reset by peer") || - strings.Contains(message, "broken pipe") || - strings.Contains(message, "an established connection was aborted") -} - -func classifyOpenAIWSReadFallbackReason(err error) string { - if err == nil { - return "read_event" - } - switch coderws.CloseStatus(err) { - case coderws.StatusPolicyViolation: - return "policy_violation" - case coderws.StatusMessageTooBig: - return "message_too_big" - default: - return "read_event" - } -} - -func sortedKeys(m map[string]any) []string { - if len(m) == 0 { - return nil - } - keys := make([]string, 0, len(m)) - for k := range m { - keys = append(keys, k) - } - sort.Strings(keys) - return keys -} - -func (s *OpenAIGatewayService) getOpenAIWSConnPool() *openAIWSConnPool { - if s == nil { - return nil - } - s.openaiWSPoolOnce.Do(func() { - if s.openaiWSPool == nil { - s.openaiWSPool = newOpenAIWSConnPool(s.cfg) - } - }) - return s.openaiWSPool -} - -func (s *OpenAIGatewayService) getOpenAIWSPassthroughDialer() openAIWSClientDialer { - if s == nil { - return nil - } - s.openaiWSPassthroughDialerOnce.Do(func() { - if s.openaiWSPassthroughDialer == nil { - s.openaiWSPassthroughDialer = newDefaultOpenAIWSClientDialer() - } - }) - return s.openaiWSPassthroughDialer -} - -func (s *OpenAIGatewayService) SnapshotOpenAIWSPoolMetrics() OpenAIWSPoolMetricsSnapshot { - pool := 
s.getOpenAIWSConnPool() - if pool == nil { - return OpenAIWSPoolMetricsSnapshot{} - } - return pool.SnapshotMetrics() -} - -type OpenAIWSPerformanceMetricsSnapshot struct { - Pool OpenAIWSPoolMetricsSnapshot `json:"pool"` - Retry OpenAIWSRetryMetricsSnapshot `json:"retry"` - Transport OpenAIWSTransportMetricsSnapshot `json:"transport"` -} - -func (s *OpenAIGatewayService) SnapshotOpenAIWSPerformanceMetrics() OpenAIWSPerformanceMetricsSnapshot { - pool := s.getOpenAIWSConnPool() - snapshot := OpenAIWSPerformanceMetricsSnapshot{ - Retry: s.SnapshotOpenAIWSRetryMetrics(), - } - if pool == nil { - return snapshot - } - snapshot.Pool = pool.SnapshotMetrics() - snapshot.Transport = pool.SnapshotTransportMetrics() - return snapshot -} - -func (s *OpenAIGatewayService) getOpenAIWSStateStore() OpenAIWSStateStore { - if s == nil { - return nil - } - s.openaiWSStateStoreOnce.Do(func() { - if s.openaiWSStateStore == nil { - s.openaiWSStateStore = NewOpenAIWSStateStore(s.cache) - } - }) - return s.openaiWSStateStore -} - -func (s *OpenAIGatewayService) openAIWSResponseStickyTTL() time.Duration { - if s != nil && s.cfg != nil { - seconds := s.cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds - if seconds > 0 { - return time.Duration(seconds) * time.Second - } - } - return time.Hour -} - -func (s *OpenAIGatewayService) openAIWSIngressPreviousResponseRecoveryEnabled() bool { - if s != nil && s.cfg != nil { - return s.cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled - } - return true -} - -func (s *OpenAIGatewayService) openAIWSReadTimeout() time.Duration { - if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ReadTimeoutSeconds > 0 { - return time.Duration(s.cfg.Gateway.OpenAIWS.ReadTimeoutSeconds) * time.Second - } - return 15 * time.Minute -} - -func (s *OpenAIGatewayService) openAIWSPassthroughIdleTimeout() time.Duration { - if timeout := s.openAIWSReadTimeout(); timeout > 0 { - return timeout - } - return openAIWSPassthroughIdleTimeoutDefault -} - -func (s 
*OpenAIGatewayService) openAIWSWriteTimeout() time.Duration { - if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds > 0 { - return time.Duration(s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds) * time.Second - } - return 2 * time.Minute -} - -func (s *OpenAIGatewayService) openAIWSEventFlushBatchSize() int { - if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.EventFlushBatchSize > 0 { - return s.cfg.Gateway.OpenAIWS.EventFlushBatchSize - } - return openAIWSEventFlushBatchSizeDefault -} - -func (s *OpenAIGatewayService) openAIWSEventFlushInterval() time.Duration { - if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS >= 0 { - if s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS == 0 { - return 0 - } - return time.Duration(s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS) * time.Millisecond - } - return openAIWSEventFlushIntervalDefault -} - -func (s *OpenAIGatewayService) openAIWSPayloadLogSampleRate() float64 { - if s != nil && s.cfg != nil { - rate := s.cfg.Gateway.OpenAIWS.PayloadLogSampleRate - if rate < 0 { - return 0 - } - if rate > 1 { - return 1 - } - return rate - } - return openAIWSPayloadLogSampleDefault -} - -func (s *OpenAIGatewayService) shouldLogOpenAIWSPayloadSchema(attempt int) bool { - // 首次尝试保留一条完整 payload_schema 便于排障。 - if attempt <= 1 { - return true - } - rate := s.openAIWSPayloadLogSampleRate() - if rate <= 0 { - return false - } - if rate >= 1 { - return true - } - return rand.Float64() < rate -} - -func (s *OpenAIGatewayService) shouldEmitOpenAIWSPayloadSchema(attempt int) bool { - if !s.shouldLogOpenAIWSPayloadSchema(attempt) { - return false - } - return logger.L().Core().Enabled(zap.DebugLevel) -} - -func (s *OpenAIGatewayService) openAIWSDialTimeout() time.Duration { - if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.DialTimeoutSeconds > 0 { - return time.Duration(s.cfg.Gateway.OpenAIWS.DialTimeoutSeconds) * time.Second - } - return 10 * time.Second -} - -func (s *OpenAIGatewayService) 
openAIWSAcquireTimeout() time.Duration { - // Acquire 覆盖“连接复用命中/排队/新建连接”三个阶段。 - // 这里不再叠加 write_timeout,避免高并发排队时把 TTFT 长尾拉到分钟级。 - dial := s.openAIWSDialTimeout() - if dial <= 0 { - dial = 10 * time.Second - } - return dial + 2*time.Second -} - -func (s *OpenAIGatewayService) buildOpenAIResponsesWSURL(account *Account) (string, error) { - if account == nil { - return "", errors.New("account is nil") - } - var targetURL string - switch account.Type { - case AccountTypeOAuth: - targetURL = chatgptCodexURL - case AccountTypeAPIKey: - baseURL := account.GetOpenAIBaseURL() - if baseURL == "" { - targetURL = openaiPlatformAPIURL - } else { - validatedURL, err := s.validateUpstreamBaseURL(baseURL) - if err != nil { - return "", err - } - targetURL = buildOpenAIResponsesURL(validatedURL) - } - default: - targetURL = openaiPlatformAPIURL - } - - parsed, err := url.Parse(strings.TrimSpace(targetURL)) - if err != nil { - return "", fmt.Errorf("invalid target url: %w", err) - } - switch strings.ToLower(parsed.Scheme) { - case "https": - parsed.Scheme = "wss" - case "http": - parsed.Scheme = "ws" - case "wss", "ws": - // 保持不变 - default: - return "", fmt.Errorf("unsupported scheme for ws: %s", parsed.Scheme) - } - return parsed.String(), nil -} - -func (s *OpenAIGatewayService) buildOpenAIWSHeaders( - c *gin.Context, - account *Account, - token string, - decision OpenAIWSProtocolDecision, - isCodexCLI bool, - turnState string, - turnMetadata string, - promptCacheKey string, -) (http.Header, openAIWSSessionHeaderResolution) { - headers := make(http.Header) - headers.Set("authorization", "Bearer "+token) - - sessionResolution := resolveOpenAIWSSessionHeaders(c, promptCacheKey) - if c != nil && c.Request != nil { - if v := strings.TrimSpace(c.Request.Header.Get("accept-language")); v != "" { - headers.Set("accept-language", v) - } - } - // OAuth 账号:将 apiKeyID 混入 session 标识符,防止跨用户会话碰撞。 - if account != nil && account.Type == AccountTypeOAuth { - apiKeyID := getAPIKeyIDFromContext(c) - 
if sessionResolution.SessionID != "" { - headers.Set("session_id", isolateOpenAISessionID(apiKeyID, sessionResolution.SessionID)) - } - if sessionResolution.ConversationID != "" { - headers.Set("conversation_id", isolateOpenAISessionID(apiKeyID, sessionResolution.ConversationID)) - } - } else { - if sessionResolution.SessionID != "" { - headers.Set("session_id", sessionResolution.SessionID) - } - if sessionResolution.ConversationID != "" { - headers.Set("conversation_id", sessionResolution.ConversationID) - } - } - if state := strings.TrimSpace(turnState); state != "" { - headers.Set(openAIWSTurnStateHeader, state) - } - if metadata := strings.TrimSpace(turnMetadata); metadata != "" { - headers.Set(openAIWSTurnMetadataHeader, metadata) - } - - if account != nil && account.Type == AccountTypeOAuth { - if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" { - headers.Set("chatgpt-account-id", chatgptAccountID) - } - headers.Set("originator", resolveOpenAIUpstreamOriginator(c, isCodexCLI)) - } - - betaValue := openAIWSBetaV2Value - if decision.Transport == OpenAIUpstreamTransportResponsesWebsocket { - betaValue = openAIWSBetaV1Value - } - headers.Set("OpenAI-Beta", betaValue) - - customUA := "" - if account != nil { - customUA = account.GetOpenAIUserAgent() - } - if strings.TrimSpace(customUA) != "" { - headers.Set("user-agent", customUA) - } else if c != nil { - if ua := strings.TrimSpace(c.GetHeader("User-Agent")); ua != "" { - headers.Set("user-agent", ua) - } - } - if s != nil && s.cfg != nil && s.cfg.Gateway.ForceCodexCLI { - headers.Set("user-agent", codexCLIUserAgent) - } - if account != nil && account.Type == AccountTypeOAuth && !openai.IsCodexCLIRequest(headers.Get("user-agent")) { - headers.Set("user-agent", codexCLIUserAgent) - } - - return headers, sessionResolution -} - -func (s *OpenAIGatewayService) buildOpenAIWSCreatePayload(reqBody map[string]any, account *Account) map[string]any { - // OpenAI WS Mode 协议:response.create 字段与 HTTP 
/responses 基本一致。 - // 保留 stream 字段(与 Codex CLI 一致),仅移除 background。 - payload := make(map[string]any, len(reqBody)+1) - for k, v := range reqBody { - payload[k] = v - } - - delete(payload, "background") - if _, exists := payload["stream"]; !exists { - payload["stream"] = true - } - payload["type"] = "response.create" - - // OAuth 默认保持 store=false,避免误依赖服务端历史。 - if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) { - payload["store"] = false - } - return payload -} - -func setOpenAIWSTurnMetadata(payload map[string]any, turnMetadata string) { - if len(payload) == 0 { - return - } - metadata := strings.TrimSpace(turnMetadata) - if metadata == "" { - return - } - - switch existing := payload["client_metadata"].(type) { - case map[string]any: - existing[openAIWSTurnMetadataHeader] = metadata - payload["client_metadata"] = existing - case map[string]string: - next := make(map[string]any, len(existing)+1) - for k, v := range existing { - next[k] = v - } - next[openAIWSTurnMetadataHeader] = metadata - payload["client_metadata"] = next - default: - payload["client_metadata"] = map[string]any{ - openAIWSTurnMetadataHeader: metadata, - } - } -} - -func (s *OpenAIGatewayService) isOpenAIWSStoreRecoveryAllowed(account *Account) bool { - if account != nil && account.IsOpenAIWSAllowStoreRecoveryEnabled() { - return true - } - if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.AllowStoreRecovery { - return true - } - return false -} - -func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequest(reqBody map[string]any, account *Account) bool { - if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) { - return true - } - if len(reqBody) == 0 { - return false - } - rawStore, ok := reqBody["store"] - if !ok { - return false - } - storeEnabled, ok := rawStore.(bool) - if !ok { - return false - } - return !storeEnabled -} - -func (s *OpenAIGatewayService) 
isOpenAIWSStoreDisabledInRequestRaw(reqBody []byte, account *Account) bool { - if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) { - return true - } - if len(reqBody) == 0 { - return false - } - storeValue := gjson.GetBytes(reqBody, "store") - if !storeValue.Exists() { - return false - } - if storeValue.Type != gjson.True && storeValue.Type != gjson.False { - return false - } - return !storeValue.Bool() -} - -func (s *OpenAIGatewayService) openAIWSStoreDisabledConnMode() string { - if s == nil || s.cfg == nil { - return openAIWSStoreDisabledConnModeStrict - } - mode := strings.ToLower(strings.TrimSpace(s.cfg.Gateway.OpenAIWS.StoreDisabledConnMode)) - switch mode { - case openAIWSStoreDisabledConnModeStrict, openAIWSStoreDisabledConnModeAdaptive, openAIWSStoreDisabledConnModeOff: - return mode - case "": - // 兼容旧配置:仅配置了布尔开关时按旧语义推导。 - if s.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn { - return openAIWSStoreDisabledConnModeStrict - } - return openAIWSStoreDisabledConnModeOff - default: - return openAIWSStoreDisabledConnModeStrict - } -} - -func shouldForceNewConnOnStoreDisabled(mode, lastFailureReason string) bool { - switch mode { - case openAIWSStoreDisabledConnModeOff: - return false - case openAIWSStoreDisabledConnModeAdaptive: - reason := strings.TrimPrefix(strings.TrimSpace(lastFailureReason), "prewarm_") - switch reason { - case "policy_violation", "message_too_big", "auth_failed", "write_request", "write": - return true - default: - return false - } - default: - return true - } -} - -func dropPreviousResponseIDFromRawPayload(payload []byte) ([]byte, bool, error) { - return dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, sjson.DeleteBytes) -} - -func dropPreviousResponseIDFromRawPayloadWithDeleteFn( - payload []byte, - deleteFn func([]byte, string) ([]byte, error), -) ([]byte, bool, error) { - if len(payload) == 0 { - return payload, false, nil - } - if !gjson.GetBytes(payload, 
"previous_response_id").Exists() { - return payload, false, nil - } - if deleteFn == nil { - deleteFn = sjson.DeleteBytes - } - - updated := payload - for i := 0; i < openAIWSMaxPrevResponseIDDeletePasses && - gjson.GetBytes(updated, "previous_response_id").Exists(); i++ { - next, err := deleteFn(updated, "previous_response_id") - if err != nil { - return payload, false, err - } - updated = next - } - return updated, !gjson.GetBytes(updated, "previous_response_id").Exists(), nil -} - -func setPreviousResponseIDToRawPayload(payload []byte, previousResponseID string) ([]byte, error) { - normalizedPrevID := strings.TrimSpace(previousResponseID) - if len(payload) == 0 || normalizedPrevID == "" { - return payload, nil - } - updated, err := sjson.SetBytes(payload, "previous_response_id", normalizedPrevID) - if err == nil { - return updated, nil - } - - var reqBody map[string]any - if unmarshalErr := json.Unmarshal(payload, &reqBody); unmarshalErr != nil { - return nil, err - } - reqBody["previous_response_id"] = normalizedPrevID - rebuilt, marshalErr := json.Marshal(reqBody) - if marshalErr != nil { - return nil, marshalErr - } - return rebuilt, nil -} - -func shouldInferIngressFunctionCallOutputPreviousResponseID( - storeDisabled bool, - turn int, - signals ToolContinuationSignals, - currentPreviousResponseID string, - expectedPreviousResponseID string, -) bool { - if !storeDisabled || turn <= 1 || !signals.HasFunctionCallOutput { - return false - } - if strings.TrimSpace(currentPreviousResponseID) != "" { - return false - } - if signals.HasFunctionCallOutputMissingCallID { - return false - } - // If the client already sent the actual tool-call context, treat this as - // a full replay / self-contained continuation payload rather than - // downgrading it into an inferred delta continuation. item_reference alone - // is not enough on the store=false WS path: it still needs a valid prior - // response anchor so upstream can resolve the referenced function_call. 
- if signals.HasToolCallContext { - return false - } - return strings.TrimSpace(expectedPreviousResponseID) != "" -} - -func alignStoreDisabledPreviousResponseID( - payload []byte, - expectedPreviousResponseID string, -) ([]byte, bool, error) { - if len(payload) == 0 { - return payload, false, nil - } - expected := strings.TrimSpace(expectedPreviousResponseID) - if expected == "" { - return payload, false, nil - } - current := openAIWSPayloadStringFromRaw(payload, "previous_response_id") - if current == "" || current == expected { - return payload, false, nil - } - - withoutPrev, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload) - if dropErr != nil { - return payload, false, dropErr - } - if !removed { - return payload, false, nil - } - updated, setErr := setPreviousResponseIDToRawPayload(withoutPrev, expected) - if setErr != nil { - return payload, false, setErr - } - return updated, true, nil -} - -func cloneOpenAIWSPayloadBytes(payload []byte) []byte { - if len(payload) == 0 { - return nil - } - cloned := make([]byte, len(payload)) - copy(cloned, payload) - return cloned -} - -func cloneOpenAIWSRawMessages(items []json.RawMessage) []json.RawMessage { - if items == nil { - return nil - } - cloned := make([]json.RawMessage, 0, len(items)) - for idx := range items { - cloned = append(cloned, json.RawMessage(cloneOpenAIWSPayloadBytes(items[idx]))) - } - return cloned -} - -func normalizeOpenAIWSJSONForCompare(raw []byte) ([]byte, error) { - trimmed := bytes.TrimSpace(raw) - if len(trimmed) == 0 { - return nil, errors.New("json is empty") - } - var decoded any - if err := json.Unmarshal(trimmed, &decoded); err != nil { - return nil, err - } - return json.Marshal(decoded) -} - -func normalizeOpenAIWSJSONForCompareOrRaw(raw []byte) []byte { - normalized, err := normalizeOpenAIWSJSONForCompare(raw) - if err != nil { - return bytes.TrimSpace(raw) - } - return normalized -} - -func normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload []byte) 
([]byte, error) { - if len(payload) == 0 { - return nil, errors.New("payload is empty") - } - var decoded map[string]any - if err := json.Unmarshal(payload, &decoded); err != nil { - return nil, err - } - delete(decoded, "input") - delete(decoded, "previous_response_id") - return json.Marshal(decoded) -} - -func openAIWSExtractNormalizedInputSequence(payload []byte) ([]json.RawMessage, bool, error) { - if len(payload) == 0 { - return nil, false, nil - } - inputValue := gjson.GetBytes(payload, "input") - if !inputValue.Exists() { - return nil, false, nil - } - if inputValue.Type == gjson.JSON { - raw := strings.TrimSpace(inputValue.Raw) - if strings.HasPrefix(raw, "[") { - var items []json.RawMessage - if err := json.Unmarshal([]byte(raw), &items); err != nil { - return nil, true, err - } - return items, true, nil - } - return []json.RawMessage{json.RawMessage(raw)}, true, nil - } - if inputValue.Type == gjson.String { - encoded, _ := json.Marshal(inputValue.String()) - return []json.RawMessage{encoded}, true, nil - } - return []json.RawMessage{json.RawMessage(inputValue.Raw)}, true, nil -} - -func openAIWSInputIsPrefixExtended(previousPayload, currentPayload []byte) (bool, error) { - previousItems, previousExists, prevErr := openAIWSExtractNormalizedInputSequence(previousPayload) - if prevErr != nil { - return false, prevErr - } - currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload) - if currentErr != nil { - return false, currentErr - } - if !previousExists && !currentExists { - return true, nil - } - if !previousExists { - return len(currentItems) == 0, nil - } - if !currentExists { - return len(previousItems) == 0, nil - } - if len(currentItems) < len(previousItems) { - return false, nil - } - - for idx := range previousItems { - previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(previousItems[idx]) - currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(currentItems[idx]) - if 
!bytes.Equal(previousNormalized, currentNormalized) { - return false, nil - } - } - return true, nil -} - -func openAIWSRawItemsHasPrefix(items []json.RawMessage, prefix []json.RawMessage) bool { - if len(prefix) == 0 { - return true - } - if len(items) < len(prefix) { - return false - } - for idx := range prefix { - previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(prefix[idx]) - currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(items[idx]) - if !bytes.Equal(previousNormalized, currentNormalized) { - return false - } - } - return true -} - -func openAIWSRawItemsHasFunctionCallOutput(items []json.RawMessage) bool { - for _, item := range items { - if gjson.GetBytes(item, "type").String() == "function_call_output" { - return true - } - } - return false -} - -func buildOpenAIWSReplayInputSequence( - previousFullInput []json.RawMessage, - previousFullInputExists bool, - currentPayload []byte, - hasPreviousResponseID bool, -) ([]json.RawMessage, bool, error) { - currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload) - if currentErr != nil { - return nil, false, currentErr - } - if !hasPreviousResponseID { - return cloneOpenAIWSRawMessages(currentItems), currentExists, nil - } - if !previousFullInputExists { - return cloneOpenAIWSRawMessages(currentItems), currentExists, nil - } - if !currentExists || len(currentItems) == 0 { - return cloneOpenAIWSRawMessages(previousFullInput), true, nil - } - if openAIWSRawItemsHasPrefix(currentItems, previousFullInput) { - return cloneOpenAIWSRawMessages(currentItems), true, nil - } - merged := make([]json.RawMessage, 0, len(previousFullInput)+len(currentItems)) - merged = append(merged, cloneOpenAIWSRawMessages(previousFullInput)...) - merged = append(merged, cloneOpenAIWSRawMessages(currentItems)...) 
- return merged, true, nil -} - -func setOpenAIWSPayloadInputSequence( - payload []byte, - fullInput []json.RawMessage, - fullInputExists bool, -) ([]byte, error) { - if !fullInputExists { - return payload, nil - } - // Preserve [] vs null semantics when input exists but is empty. - inputForMarshal := fullInput - if inputForMarshal == nil { - inputForMarshal = []json.RawMessage{} - } - inputRaw, marshalErr := json.Marshal(inputForMarshal) - if marshalErr != nil { - return nil, marshalErr - } - return sjson.SetRawBytes(payload, "input", inputRaw) -} - -func shouldKeepIngressPreviousResponseID( - previousPayload []byte, - currentPayload []byte, - lastTurnResponseID string, - hasFunctionCallOutput bool, -) (bool, string, error) { - if hasFunctionCallOutput { - return true, "has_function_call_output", nil - } - currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")) - if currentPreviousResponseID == "" { - return false, "missing_previous_response_id", nil - } - expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID) - if expectedPreviousResponseID == "" { - return false, "missing_last_turn_response_id", nil - } - if currentPreviousResponseID != expectedPreviousResponseID { - return false, "previous_response_id_mismatch", nil - } - if len(previousPayload) == 0 { - return false, "missing_previous_turn_payload", nil - } - - previousComparable, previousComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(previousPayload) - if previousComparableErr != nil { - return false, "non_input_compare_error", previousComparableErr - } - currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload) - if currentComparableErr != nil { - return false, "non_input_compare_error", currentComparableErr - } - if !bytes.Equal(previousComparable, currentComparable) { - return false, "non_input_changed", nil - } - return true, 
"strict_incremental_ok", nil -} - -type openAIWSIngressPreviousTurnStrictState struct { - nonInputComparable []byte + return dial + 2*time.Second } -func buildOpenAIWSIngressPreviousTurnStrictState(payload []byte) (*openAIWSIngressPreviousTurnStrictState, error) { - if len(payload) == 0 { - return nil, nil +func (s *OpenAIGatewayService) openAIWSStoreDisabledConnMode() string { + if s == nil || s.cfg == nil { + return openAIWSStoreDisabledConnModeStrict } - nonInputComparable, nonInputErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload) - if nonInputErr != nil { - return nil, nonInputErr + mode := strings.ToLower(strings.TrimSpace(s.cfg.Gateway.OpenAIWS.StoreDisabledConnMode)) + switch mode { + case openAIWSStoreDisabledConnModeStrict, openAIWSStoreDisabledConnModeAdaptive, openAIWSStoreDisabledConnModeOff: + return mode + case "": + // 兼容旧配置:仅配置了布尔开关时按旧语义推导。 + if s.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn { + return openAIWSStoreDisabledConnModeStrict + } + return openAIWSStoreDisabledConnModeOff + default: + return openAIWSStoreDisabledConnModeStrict } - return &openAIWSIngressPreviousTurnStrictState{ - nonInputComparable: nonInputComparable, - }, nil } -func shouldKeepIngressPreviousResponseIDWithStrictState( - previousState *openAIWSIngressPreviousTurnStrictState, - currentPayload []byte, - lastTurnResponseID string, - hasFunctionCallOutput bool, -) (bool, string, error) { - if hasFunctionCallOutput { - return true, "has_function_call_output", nil - } - currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")) - if currentPreviousResponseID == "" { - return false, "missing_previous_response_id", nil - } - expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID) - if expectedPreviousResponseID == "" { - return false, "missing_last_turn_response_id", nil +func dropPreviousResponseIDFromRawPayloadWithDeleteFn( + payload []byte, + deleteFn func([]byte, string) ([]byte, 
error), +) ([]byte, bool, error) { + if len(payload) == 0 { + return payload, false, nil } - if currentPreviousResponseID != expectedPreviousResponseID { - return false, "previous_response_id_mismatch", nil + if !gjson.GetBytes(payload, "previous_response_id").Exists() { + return payload, false, nil } - if previousState == nil { - return false, "missing_previous_turn_payload", nil + if deleteFn == nil { + deleteFn = sjson.DeleteBytes } - currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload) - if currentComparableErr != nil { - return false, "non_input_compare_error", currentComparableErr - } - if !bytes.Equal(previousState.nonInputComparable, currentComparable) { - return false, "non_input_changed", nil + updated := payload + for i := 0; i < openAIWSMaxPrevResponseIDDeletePasses && + gjson.GetBytes(updated, "previous_response_id").Exists(); i++ { + next, err := deleteFn(updated, "previous_response_id") + if err != nil { + return payload, false, err + } + updated = next } - return true, "strict_incremental_ok", nil + return updated, !gjson.GetBytes(updated, "previous_response_id").Exists(), nil } +// forwardOpenAIWSV2 通过 WebSocket v2 协议转发单次请求到上游。 +// 生命周期:构建 payload → 解析会话状态 → 获取连接 → 发送+中继响应 → 状态存储更新。 func (s *OpenAIGatewayService) forwardOpenAIWSV2( ctx context.Context, c *gin.Context, @@ -1717,6 +355,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( return nil, wrapOpenAIWSFallback("invalid_state", errors.New("service or account is nil")) } + // ─── 阶段1: 构建 WebSocket payload 并解析目标 URL ─── wsURL, err := s.buildOpenAIResponsesWSURL(account) if err != nil { return nil, wrapOpenAIWSFallback("build_ws_url", err) @@ -1788,6 +427,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( ) } + // ─── 阶段2: 会话状态解析与连接亲和性 ─── stateStore := s.getOpenAIWSStateStore() groupID := getOpenAIGroupIDFromContext(c) sessionHash := s.GenerateSessionHash(c, nil) @@ -1849,6 +489,7 @@ func (s *OpenAIGatewayService) 
forwardOpenAIWSV2( account.ProxyID != nil && account.Proxy != nil, ) + // ─── 阶段3: 获取上游 WebSocket 连接(从池中复用或新建)─── acquireCtx, acquireCancel := context.WithTimeout(ctx, s.openAIWSAcquireTimeout()) defer acquireCancel() @@ -1977,6 +618,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( return nil, err } + // ─── 阶段4: 发送请求并中继上游响应事件到客户端 ─── if err := lease.WriteJSONWithContextTimeout(ctx, payload, s.openAIWSWriteTimeout()); err != nil { lease.MarkBroken() logOpenAIWSModeInfo( @@ -2366,8 +1008,9 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( }, nil } -// ProxyResponsesWebSocketFromClient 处理客户端入站 WebSocket(OpenAI Responses WS Mode)并转发到上游。 -// 当前实现按“单请求 -> 终止事件 -> 下一请求”的顺序代理,适配 Codex CLI 的 turn 模式。 +// ProxyResponsesWebSocketFromClient 代理客户端入站 WebSocket 的完整生命周期。 +// 处理多轮对话:每轮读取客户端消息 → 转发到上游 → 中继响应 → 更新状态。 +// 流程:参数校验 → 协议/模式选择 → 多轮对话循环(turn loop)→ 清理。 func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( ctx context.Context, c *gin.Context, @@ -2393,9 +1036,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( return errors.New("token is empty") } - // 预取一次 OpenAI Fast Policy settings,绑定到 ctx,让该 WS session - // 内所有帧的 evaluateOpenAIFastPolicy 调用复用同一份快照,避免每帧 - // 进入 DB / settingRepo。Trade-off 见 withOpenAIFastPolicyContext 注释。 + // ─── 阶段1: 会话初始化(策略预取、协议选择、模式路由)─── + if s.settingService != nil { if settings, err := s.settingService.GetOpenAIFastPolicySettings(ctx); err == nil && settings != nil { ctx = withOpenAIFastPolicyContext(ctx, settings) @@ -2625,6 +1267,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( }, nil } + // ─── 阶段2: 解析首帧并建立会话状态 ─── firstPayload, err := parseClientPayload(firstClientMessage) if err != nil { return err @@ -3667,262 +2310,6 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( } } -func (s *OpenAIGatewayService) isOpenAIWSGeneratePrewarmEnabled() bool { - return s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.PrewarmGenerateEnabled -} - -// 
performOpenAIWSGeneratePrewarm 在 WSv2 下执行可选的 generate=false 预热。 -// 预热默认关闭,仅在配置开启后生效;失败时按可恢复错误回退到 HTTP。 -func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm( - ctx context.Context, - lease *openAIWSConnLease, - decision OpenAIWSProtocolDecision, - payload map[string]any, - previousResponseID string, - reqBody map[string]any, - account *Account, - stateStore OpenAIWSStateStore, - groupID int64, -) error { - if s == nil { - return nil - } - if lease == nil || account == nil { - logOpenAIWSModeInfo("prewarm_skip reason=invalid_state has_lease=%v has_account=%v", lease != nil, account != nil) - return nil - } - connID := strings.TrimSpace(lease.ConnID()) - if !s.isOpenAIWSGeneratePrewarmEnabled() { - return nil - } - if decision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { - logOpenAIWSModeInfo( - "prewarm_skip account_id=%d conn_id=%s reason=transport_not_v2 transport=%s", - account.ID, - connID, - normalizeOpenAIWSLogValue(string(decision.Transport)), - ) - return nil - } - if strings.TrimSpace(previousResponseID) != "" { - logOpenAIWSModeInfo( - "prewarm_skip account_id=%d conn_id=%s reason=has_previous_response_id previous_response_id=%s", - account.ID, - connID, - truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), - ) - return nil - } - if lease.IsPrewarmed() { - logOpenAIWSModeInfo("prewarm_skip account_id=%d conn_id=%s reason=already_prewarmed", account.ID, connID) - return nil - } - if NeedsToolContinuation(reqBody) { - logOpenAIWSModeInfo("prewarm_skip account_id=%d conn_id=%s reason=tool_continuation", account.ID, connID) - return nil - } - prewarmStart := time.Now() - logOpenAIWSModeInfo("prewarm_start account_id=%d conn_id=%s", account.ID, connID) - - prewarmPayload := make(map[string]any, len(payload)+1) - for k, v := range payload { - prewarmPayload[k] = v - } - prewarmPayload["generate"] = false - prewarmPayloadJSON := payloadAsJSONBytes(prewarmPayload) - - if err := lease.WriteJSONWithContextTimeout(ctx, 
prewarmPayload, s.openAIWSWriteTimeout()); err != nil { - lease.MarkBroken() - logOpenAIWSModeInfo( - "prewarm_write_fail account_id=%d conn_id=%s cause=%s", - account.ID, - connID, - truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), - ) - return wrapOpenAIWSFallback("prewarm_write", err) - } - logOpenAIWSModeInfo("prewarm_write_sent account_id=%d conn_id=%s payload_bytes=%d", account.ID, connID, len(prewarmPayloadJSON)) - - prewarmResponseID := "" - prewarmEventCount := 0 - prewarmTerminalCount := 0 - for { - message, readErr := lease.ReadMessageWithContextTimeout(ctx, s.openAIWSReadTimeout()) - if readErr != nil { - lease.MarkBroken() - closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr) - logOpenAIWSModeInfo( - "prewarm_read_fail account_id=%d conn_id=%s close_status=%s close_reason=%s cause=%s events=%d", - account.ID, - connID, - closeStatus, - closeReason, - truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen), - prewarmEventCount, - ) - return wrapOpenAIWSFallback("prewarm_"+classifyOpenAIWSReadFallbackReason(readErr), readErr) - } - - eventType, eventResponseID, _ := parseOpenAIWSEventEnvelope(message) - if eventType == "" { - continue - } - prewarmEventCount++ - if prewarmResponseID == "" && eventResponseID != "" { - prewarmResponseID = eventResponseID - } - if prewarmEventCount <= openAIWSPrewarmEventLogHead || eventType == "error" || isOpenAIWSTerminalEvent(eventType) { - logOpenAIWSModeInfo( - "prewarm_event account_id=%d conn_id=%s idx=%d type=%s bytes=%d", - account.ID, - connID, - prewarmEventCount, - truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen), - len(message), - ) - } - - if eventType == "error" { - errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message) - s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), message, errCodeRaw, errTypeRaw, errMsgRaw) - errMsg := strings.TrimSpace(errMsgRaw) - if errMsg == "" { - errMsg = "OpenAI websocket prewarm 
error" - } - fallbackReason, canFallback := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) - errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) - logOpenAIWSModeInfo( - "prewarm_error_event account_id=%d conn_id=%s idx=%d fallback_reason=%s can_fallback=%v err_code=%s err_type=%s err_message=%s", - account.ID, - connID, - prewarmEventCount, - truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen), - canFallback, - errCode, - errType, - errMessage, - ) - lease.MarkBroken() - if canFallback { - return wrapOpenAIWSFallback("prewarm_"+fallbackReason, errors.New(errMsg)) - } - return wrapOpenAIWSFallback("prewarm_error_event", errors.New(errMsg)) - } - - if isOpenAIWSTerminalEvent(eventType) { - prewarmTerminalCount++ - break - } - } - - lease.MarkPrewarmed() - if prewarmResponseID != "" && stateStore != nil { - ttl := s.openAIWSResponseStickyTTL() - logOpenAIWSBindResponseAccountWarn(groupID, account.ID, prewarmResponseID, stateStore.BindResponseAccount(ctx, groupID, prewarmResponseID, account.ID, ttl)) - stateStore.BindResponseConn(prewarmResponseID, lease.ConnID(), ttl) - } - logOpenAIWSModeInfo( - "prewarm_done account_id=%d conn_id=%s response_id=%s events=%d terminal_events=%d duration_ms=%d", - account.ID, - connID, - truncateOpenAIWSLogValue(prewarmResponseID, openAIWSIDValueMaxLen), - prewarmEventCount, - prewarmTerminalCount, - time.Since(prewarmStart).Milliseconds(), - ) - return nil -} - -func payloadAsJSON(payload map[string]any) string { - return string(payloadAsJSONBytes(payload)) -} - -func payloadAsJSONBytes(payload map[string]any) []byte { - if len(payload) == 0 { - return []byte("{}") - } - body, err := json.Marshal(payload) - if err != nil { - return []byte("{}") - } - return body -} - -func isOpenAIWSTerminalEvent(eventType string) bool { - switch strings.TrimSpace(eventType) { - case "response.completed", "response.done", "response.failed", 
"response.incomplete", "response.cancelled", "response.canceled": - return true - default: - return false - } -} - -func isOpenAIWSTokenEvent(eventType string) bool { - eventType = strings.TrimSpace(eventType) - if eventType == "" { - return false - } - switch eventType { - case "response.created", "response.in_progress", "response.output_item.added", "response.output_item.done": - return false - } - if strings.Contains(eventType, ".delta") { - return true - } - if strings.HasPrefix(eventType, "response.output_text") { - return true - } - if strings.HasPrefix(eventType, "response.output") { - return true - } - return eventType == "response.completed" || eventType == "response.done" -} - -func replaceOpenAIWSMessageModel(message []byte, fromModel, toModel string) []byte { - if len(message) == 0 { - return message - } - if strings.TrimSpace(fromModel) == "" || strings.TrimSpace(toModel) == "" || fromModel == toModel { - return message - } - if !bytes.Contains(message, []byte(`"model"`)) || !bytes.Contains(message, []byte(fromModel)) { - return message - } - modelValues := gjson.GetManyBytes(message, "model", "response.model") - replaceModel := modelValues[0].Exists() && modelValues[0].Str == fromModel - replaceResponseModel := modelValues[1].Exists() && modelValues[1].Str == fromModel - if !replaceModel && !replaceResponseModel { - return message - } - updated := message - if replaceModel { - if next, err := sjson.SetBytes(updated, "model", toModel); err == nil { - updated = next - } - } - if replaceResponseModel { - if next, err := sjson.SetBytes(updated, "response.model", toModel); err == nil { - updated = next - } - } - return updated -} - -func populateOpenAIUsageFromResponseJSON(body []byte, usage *OpenAIUsage) { - if usage == nil || len(body) == 0 { - return - } - values := gjson.GetManyBytes( - body, - "usage.input_tokens", - "usage.output_tokens", - "usage.input_tokens_details.cached_tokens", - ) - usage.InputTokens = int(values[0].Int()) - usage.OutputTokens 
= int(values[1].Int()) - usage.CacheReadInputTokens = int(values[2].Int()) -} - func getOpenAIGroupIDFromContext(c *gin.Context) int64 { if c == nil { return 0 @@ -4027,202 +2414,3 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID( } return nil, nil } - -func classifyOpenAIWSAcquireError(err error) string { - if err == nil { - return "acquire_conn" - } - var dialErr *openAIWSDialError - if errors.As(err, &dialErr) { - switch dialErr.StatusCode { - case 426: - return "upgrade_required" - case 401, 403: - return "auth_failed" - case 429: - return "upstream_rate_limited" - } - if dialErr.StatusCode >= 500 { - return "upstream_5xx" - } - return "dial_failed" - } - if errors.Is(err, errOpenAIWSConnQueueFull) { - return "conn_queue_full" - } - if errors.Is(err, errOpenAIWSPreferredConnUnavailable) { - return "preferred_conn_unavailable" - } - if errors.Is(err, context.DeadlineExceeded) { - return "acquire_timeout" - } - return "acquire_conn" -} - -func isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw string) bool { - code := strings.ToLower(strings.TrimSpace(codeRaw)) - errType := strings.ToLower(strings.TrimSpace(errTypeRaw)) - msg := strings.ToLower(strings.TrimSpace(msgRaw)) - - if strings.Contains(errType, "rate_limit") || strings.Contains(errType, "usage_limit") { - return true - } - if strings.Contains(code, "rate_limit") || strings.Contains(code, "usage_limit") || strings.Contains(code, "insufficient_quota") { - return true - } - if strings.Contains(msg, "usage limit") && strings.Contains(msg, "reached") { - return true - } - if strings.Contains(msg, "rate limit") && (strings.Contains(msg, "reached") || strings.Contains(msg, "exceeded")) { - return true - } - return false -} - -func (s *OpenAIGatewayService) persistOpenAIWSRateLimitSignal(ctx context.Context, account *Account, headers http.Header, responseBody []byte, codeRaw, errTypeRaw, msgRaw string) { - if s == nil || s.rateLimitService == nil || account == nil || account.Platform != 
PlatformOpenAI { - return - } - if !isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) { - return - } - s.rateLimitService.HandleUpstreamError(ctx, account, http.StatusTooManyRequests, headers, responseBody) -} - -func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) { - code := strings.ToLower(strings.TrimSpace(codeRaw)) - errType := strings.ToLower(strings.TrimSpace(errTypeRaw)) - msg := strings.ToLower(strings.TrimSpace(msgRaw)) - - switch code { - case "upgrade_required": - return "upgrade_required", true - case "websocket_not_supported", "websocket_unsupported": - return "ws_unsupported", true - case "websocket_connection_limit_reached": - return "ws_connection_limit_reached", true - case "invalid_encrypted_content": - return "invalid_encrypted_content", true - case "previous_response_not_found": - return "previous_response_not_found", true - } - if isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) { - return "upstream_rate_limited", false - } - if strings.Contains(msg, "upgrade required") || strings.Contains(msg, "status 426") { - return "upgrade_required", true - } - if strings.Contains(errType, "upgrade") { - return "upgrade_required", true - } - if strings.Contains(msg, "websocket") && strings.Contains(msg, "unsupported") { - return "ws_unsupported", true - } - if strings.Contains(msg, "connection limit") && strings.Contains(msg, "websocket") { - return "ws_connection_limit_reached", true - } - if strings.Contains(msg, "invalid_encrypted_content") || - (strings.Contains(msg, "encrypted content") && strings.Contains(msg, "could not be verified")) { - return "invalid_encrypted_content", true - } - if strings.Contains(msg, "previous_response_not_found") || - (strings.Contains(msg, "previous response") && strings.Contains(msg, "not found")) { - return "previous_response_not_found", true - } - if strings.Contains(errType, "server_error") || strings.Contains(code, "server_error") { - return "upstream_error_event", true 
- } - return "event_error", false -} - -func classifyOpenAIWSErrorEvent(message []byte) (string, bool) { - if len(message) == 0 { - return "event_error", false - } - return classifyOpenAIWSErrorEventFromRaw(parseOpenAIWSErrorEventFields(message)) -} - -func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int { - code := strings.ToLower(strings.TrimSpace(codeRaw)) - errType := strings.ToLower(strings.TrimSpace(errTypeRaw)) - switch { - case strings.Contains(errType, "invalid_request"), - strings.Contains(code, "invalid_request"), - strings.Contains(code, "bad_request"), - code == "invalid_encrypted_content", - code == "previous_response_not_found": - return http.StatusBadRequest - case strings.Contains(errType, "authentication"), - strings.Contains(code, "invalid_api_key"), - strings.Contains(code, "unauthorized"): - return http.StatusUnauthorized - case strings.Contains(errType, "permission"), - strings.Contains(code, "forbidden"): - return http.StatusForbidden - case isOpenAIWSRateLimitError(codeRaw, errTypeRaw, ""): - return http.StatusTooManyRequests - default: - return http.StatusBadGateway - } -} - -func openAIWSErrorHTTPStatus(message []byte) int { - if len(message) == 0 { - return http.StatusBadGateway - } - codeRaw, errTypeRaw, _ := parseOpenAIWSErrorEventFields(message) - return openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw) -} - -func (s *OpenAIGatewayService) openAIWSFallbackCooldown() time.Duration { - if s == nil || s.cfg == nil { - return 30 * time.Second - } - seconds := s.cfg.Gateway.OpenAIWS.FallbackCooldownSeconds - if seconds <= 0 { - return 0 - } - return time.Duration(seconds) * time.Second -} - -func (s *OpenAIGatewayService) isOpenAIWSFallbackCooling(accountID int64) bool { - if s == nil || accountID <= 0 { - return false - } - cooldown := s.openAIWSFallbackCooldown() - if cooldown <= 0 { - return false - } - rawUntil, ok := s.openaiWSFallbackUntil.Load(accountID) - if !ok || rawUntil == nil { - return false - } - until, ok := 
rawUntil.(time.Time) - if !ok || until.IsZero() { - s.openaiWSFallbackUntil.Delete(accountID) - return false - } - if time.Now().Before(until) { - return true - } - s.openaiWSFallbackUntil.Delete(accountID) - return false -} - -func (s *OpenAIGatewayService) markOpenAIWSFallbackCooling(accountID int64, _ string) { - if s == nil || accountID <= 0 { - return - } - cooldown := s.openAIWSFallbackCooldown() - if cooldown <= 0 { - return - } - s.openaiWSFallbackUntil.Store(accountID, time.Now().Add(cooldown)) -} - -func (s *OpenAIGatewayService) clearOpenAIWSFallbackCooling(accountID int64) { - if s == nil || accountID <= 0 { - return - } - s.openaiWSFallbackUntil.Delete(accountID) -} diff --git a/backend/internal/service/openai_ws_helpers.go b/backend/internal/service/openai_ws_helpers.go new file mode 100644 index 00000000000..29839472dfd --- /dev/null +++ b/backend/internal/service/openai_ws_helpers.go @@ -0,0 +1,699 @@ +package service + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "sort" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.uber.org/zap" +) + +func truncateOpenAIWSLogValue(value string, maxLen int) string { + normalized := normalizeOpenAIWSLogValue(value) + if normalized == "-" || maxLen <= 0 { + return normalized + } + if len(normalized) <= maxLen { + return normalized + } + return normalized[:maxLen] + "..." 
+} + +func shouldLogOpenAIWSEvent(idx int, eventType string) bool { + if idx <= openAIWSEventLogHeadLimit { + return true + } + if openAIWSEventLogEveryN > 0 && idx%openAIWSEventLogEveryN == 0 { + return true + } + if eventType == "error" || isOpenAIWSTerminalEvent(eventType) { + return true + } + return false +} + +func shouldLogOpenAIWSBufferedEvent(idx int) bool { + if idx <= openAIWSBufferLogHeadLimit { + return true + } + if openAIWSBufferLogEveryN > 0 && idx%openAIWSBufferLogEveryN == 0 { + return true + } + return false +} + +func openAIWSEventMayContainModel(eventType string) bool { + switch eventType { + case "response.created", + "response.in_progress", + "response.completed", + "response.done", + "response.failed", + "response.incomplete", + "response.cancelled", + "response.canceled": + return true + default: + trimmed := strings.TrimSpace(eventType) + if trimmed == eventType { + return false + } + switch trimmed { + case "response.created", + "response.in_progress", + "response.completed", + "response.done", + "response.failed", + "response.incomplete", + "response.cancelled", + "response.canceled": + return true + default: + return false + } + } +} + +func openAIWSEventMayContainToolCalls(eventType string) bool { + eventType = strings.TrimSpace(eventType) + if eventType == "" { + return false + } + if strings.Contains(eventType, "function_call") || strings.Contains(eventType, "tool_call") { + return true + } + switch eventType { + case "response.output_item.added", "response.output_item.done", "response.completed", "response.done": + return true + default: + return false + } +} + +func openAIWSEventShouldParseUsage(eventType string) bool { + return eventType == "response.completed" || strings.TrimSpace(eventType) == "response.completed" +} + +func parseOpenAIWSEventEnvelope(message []byte) (eventType string, responseID string, response gjson.Result) { + if len(message) == 0 { + return "", "", gjson.Result{} + } + values := gjson.GetManyBytes(message, 
"type", "response.id", "id", "response") + eventType = strings.TrimSpace(values[0].String()) + if id := strings.TrimSpace(values[1].String()); id != "" { + responseID = id + } else { + responseID = strings.TrimSpace(values[2].String()) + } + return eventType, responseID, values[3] +} + +func openAIWSMessageLikelyContainsToolCalls(message []byte) bool { + if len(message) == 0 { + return false + } + return bytes.Contains(message, []byte(`"tool_calls"`)) || + bytes.Contains(message, []byte(`"tool_call"`)) || + bytes.Contains(message, []byte(`"function_call"`)) +} + +func parseOpenAIWSResponseUsageFromCompletedEvent(message []byte, usage *OpenAIUsage) { + if usage == nil || len(message) == 0 { + return + } + values := gjson.GetManyBytes( + message, + "response.usage.input_tokens", + "response.usage.output_tokens", + "response.usage.input_tokens_details.cached_tokens", + ) + usage.InputTokens = int(values[0].Int()) + usage.OutputTokens = int(values[1].Int()) + usage.CacheReadInputTokens = int(values[2].Int()) +} + +func parseOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) { + if len(message) == 0 { + return "", "", "" + } + values := gjson.GetManyBytes(message, "error.code", "error.type", "error.message") + return strings.TrimSpace(values[0].String()), strings.TrimSpace(values[1].String()), strings.TrimSpace(values[2].String()) +} + +func summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMessageRaw string) (code string, errType string, errMessage string) { + code = truncateOpenAIWSLogValue(codeRaw, openAIWSLogValueMaxLen) + errType = truncateOpenAIWSLogValue(errTypeRaw, openAIWSLogValueMaxLen) + errMessage = truncateOpenAIWSLogValue(errMessageRaw, openAIWSLogValueMaxLen) + return code, errType, errMessage +} + +func summarizeOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) { + if len(message) == 0 { + return "-", "-", "-" + } + return 
summarizeOpenAIWSErrorEventFieldsFromRaw(parseOpenAIWSErrorEventFields(message)) +} + +func summarizeOpenAIWSPayloadKeySizes(payload map[string]any, topN int) string { + if len(payload) == 0 { + return "-" + } + type keySize struct { + Key string + Size int + } + sizes := make([]keySize, 0, len(payload)) + for key, value := range payload { + size := estimateOpenAIWSPayloadValueSize(value, openAIWSPayloadSizeEstimateDepth) + sizes = append(sizes, keySize{Key: key, Size: size}) + } + sort.Slice(sizes, func(i, j int) bool { + if sizes[i].Size == sizes[j].Size { + return sizes[i].Key < sizes[j].Key + } + return sizes[i].Size > sizes[j].Size + }) + + if topN <= 0 || topN > len(sizes) { + topN = len(sizes) + } + parts := make([]string, 0, topN) + for idx := 0; idx < topN; idx++ { + item := sizes[idx] + parts = append(parts, fmt.Sprintf("%s:%d", item.Key, item.Size)) + } + return strings.Join(parts, ",") +} + +func estimateOpenAIWSPayloadValueSize(value any, depth int) int { + if depth <= 0 { + return -1 + } + switch v := value.(type) { + case nil: + return 0 + case string: + return len(v) + case []byte: + return len(v) + case bool: + return 1 + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + return 8 + case float32, float64: + return 8 + case map[string]any: + if len(v) == 0 { + return 2 + } + total := 2 + count := 0 + for key, item := range v { + count++ + if count > openAIWSPayloadSizeEstimateMaxItems { + return -1 + } + itemSize := estimateOpenAIWSPayloadValueSize(item, depth-1) + if itemSize < 0 { + return -1 + } + total += len(key) + itemSize + 3 + if total > openAIWSPayloadSizeEstimateMaxBytes { + return -1 + } + } + return total + case []any: + if len(v) == 0 { + return 2 + } + total := 2 + limit := len(v) + if limit > openAIWSPayloadSizeEstimateMaxItems { + return -1 + } + for i := 0; i < limit; i++ { + itemSize := estimateOpenAIWSPayloadValueSize(v[i], depth-1) + if itemSize < 0 { + return -1 + } + total += itemSize + 1 + if total > 
openAIWSPayloadSizeEstimateMaxBytes { + return -1 + } + } + return total + default: + raw, err := json.Marshal(v) + if err != nil { + return -1 + } + if len(raw) > openAIWSPayloadSizeEstimateMaxBytes { + return -1 + } + return len(raw) + } +} + +func openAIWSPayloadString(payload map[string]any, key string) string { + if len(payload) == 0 { + return "" + } + raw, ok := payload[key] + if !ok { + return "" + } + switch v := raw.(type) { + case nil: + return "" + case string: + return strings.TrimSpace(v) + case []byte: + return strings.TrimSpace(string(v)) + default: + return "" + } +} + +func openAIWSPayloadStringFromRaw(payload []byte, key string) string { + if len(payload) == 0 || strings.TrimSpace(key) == "" { + return "" + } + return strings.TrimSpace(gjson.GetBytes(payload, key).String()) +} + +func openAIWSPayloadBoolFromRaw(payload []byte, key string, defaultValue bool) bool { + if len(payload) == 0 || strings.TrimSpace(key) == "" { + return defaultValue + } + value := gjson.GetBytes(payload, key) + if !value.Exists() { + return defaultValue + } + if value.Type != gjson.True && value.Type != gjson.False { + return defaultValue + } + return value.Bool() +} + +func openAIWSSessionHashesFromID(sessionID string) (string, string) { + return deriveOpenAISessionHashes(sessionID) +} + +func extractOpenAIWSImageURL(value any) string { + switch v := value.(type) { + case string: + return strings.TrimSpace(v) + case map[string]any: + if raw, ok := v["url"].(string); ok { + return strings.TrimSpace(raw) + } + } + return "" +} + +func summarizeOpenAIWSInput(input any) string { + items, ok := input.([]any) + if !ok || len(items) == 0 { + return "-" + } + + itemCount := len(items) + textChars := 0 + imageDataURLs := 0 + imageDataURLChars := 0 + imageRemoteURLs := 0 + + handleContentItem := func(contentItem map[string]any) { + contentType, _ := contentItem["type"].(string) + switch strings.TrimSpace(contentType) { + case "input_text", "output_text", "text": + if text, ok := 
contentItem["text"].(string); ok { + textChars += len(text) + } + case "input_image": + imageURL := extractOpenAIWSImageURL(contentItem["image_url"]) + if imageURL == "" { + return + } + if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") { + imageDataURLs++ + imageDataURLChars += len(imageURL) + return + } + imageRemoteURLs++ + } + } + + handleInputItem := func(inputItem map[string]any) { + if content, ok := inputItem["content"].([]any); ok { + for _, rawContent := range content { + contentItem, ok := rawContent.(map[string]any) + if !ok { + continue + } + handleContentItem(contentItem) + } + return + } + + itemType, _ := inputItem["type"].(string) + switch strings.TrimSpace(itemType) { + case "input_text", "output_text", "text": + if text, ok := inputItem["text"].(string); ok { + textChars += len(text) + } + case "input_image": + imageURL := extractOpenAIWSImageURL(inputItem["image_url"]) + if imageURL == "" { + return + } + if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") { + imageDataURLs++ + imageDataURLChars += len(imageURL) + return + } + imageRemoteURLs++ + } + } + + for _, rawItem := range items { + inputItem, ok := rawItem.(map[string]any) + if !ok { + continue + } + handleInputItem(inputItem) + } + + return fmt.Sprintf( + "items=%d,text_chars=%d,image_data_urls=%d,image_data_url_chars=%d,image_remote_urls=%d", + itemCount, + textChars, + imageDataURLs, + imageDataURLChars, + imageRemoteURLs, + ) +} + +func dropOpenAIWSPayloadKey(payload map[string]any, key string, removed *[]string) { + if len(payload) == 0 || strings.TrimSpace(key) == "" { + return + } + if _, exists := payload[key]; !exists { + return + } + delete(payload, key) + *removed = append(*removed, key) +} + +func logOpenAIWSModeInfo(format string, args ...any) { + logger.LegacyPrintf("service.openai_gateway", "[OpenAI WS Mode][openai_ws_mode=true] "+format, args...) 
+} + +func isOpenAIWSModeDebugEnabled() bool { + return logger.L().Core().Enabled(zap.DebugLevel) +} + +func logOpenAIWSModeDebug(format string, args ...any) { + if !isOpenAIWSModeDebugEnabled() { + return + } + logger.LegacyPrintf("service.openai_gateway", "[debug] [OpenAI WS Mode][openai_ws_mode=true] "+format, args...) +} + +func logOpenAIWSBindResponseAccountWarn(groupID, accountID int64, responseID string, err error) { + if err == nil { + return + } + logger.L().Warn( + "openai.ws_bind_response_account_failed", + zap.Int64("group_id", groupID), + zap.Int64("account_id", accountID), + zap.String("response_id", truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen)), + zap.Error(err), + ) +} + +func (s *OpenAIGatewayService) isOpenAIWSGeneratePrewarmEnabled() bool { + return s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.PrewarmGenerateEnabled +} + +// performOpenAIWSGeneratePrewarm 在 WSv2 下执行可选的 generate=false 预热。 +// 预热默认关闭,仅在配置开启后生效;失败时按可恢复错误回退到 HTTP。 +func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm( + ctx context.Context, + lease *openAIWSConnLease, + decision OpenAIWSProtocolDecision, + payload map[string]any, + previousResponseID string, + reqBody map[string]any, + account *Account, + stateStore OpenAIWSStateStore, + groupID int64, +) error { + if s == nil { + return nil + } + if lease == nil || account == nil { + logOpenAIWSModeInfo("prewarm_skip reason=invalid_state has_lease=%v has_account=%v", lease != nil, account != nil) + return nil + } + connID := strings.TrimSpace(lease.ConnID()) + if !s.isOpenAIWSGeneratePrewarmEnabled() { + return nil + } + if decision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { + logOpenAIWSModeInfo( + "prewarm_skip account_id=%d conn_id=%s reason=transport_not_v2 transport=%s", + account.ID, + connID, + normalizeOpenAIWSLogValue(string(decision.Transport)), + ) + return nil + } + if strings.TrimSpace(previousResponseID) != "" { + logOpenAIWSModeInfo( + "prewarm_skip account_id=%d 
conn_id=%s reason=has_previous_response_id previous_response_id=%s", + account.ID, + connID, + truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), + ) + return nil + } + if lease.IsPrewarmed() { + logOpenAIWSModeInfo("prewarm_skip account_id=%d conn_id=%s reason=already_prewarmed", account.ID, connID) + return nil + } + if NeedsToolContinuation(reqBody) { + logOpenAIWSModeInfo("prewarm_skip account_id=%d conn_id=%s reason=tool_continuation", account.ID, connID) + return nil + } + prewarmStart := time.Now() + logOpenAIWSModeInfo("prewarm_start account_id=%d conn_id=%s", account.ID, connID) + + prewarmPayload := make(map[string]any, len(payload)+1) + for k, v := range payload { + prewarmPayload[k] = v + } + prewarmPayload["generate"] = false + prewarmPayloadJSON := payloadAsJSONBytes(prewarmPayload) + + if err := lease.WriteJSONWithContextTimeout(ctx, prewarmPayload, s.openAIWSWriteTimeout()); err != nil { + lease.MarkBroken() + logOpenAIWSModeInfo( + "prewarm_write_fail account_id=%d conn_id=%s cause=%s", + account.ID, + connID, + truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), + ) + return wrapOpenAIWSFallback("prewarm_write", err) + } + logOpenAIWSModeInfo("prewarm_write_sent account_id=%d conn_id=%s payload_bytes=%d", account.ID, connID, len(prewarmPayloadJSON)) + + prewarmResponseID := "" + prewarmEventCount := 0 + prewarmTerminalCount := 0 + for { + message, readErr := lease.ReadMessageWithContextTimeout(ctx, s.openAIWSReadTimeout()) + if readErr != nil { + lease.MarkBroken() + closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr) + logOpenAIWSModeInfo( + "prewarm_read_fail account_id=%d conn_id=%s close_status=%s close_reason=%s cause=%s events=%d", + account.ID, + connID, + closeStatus, + closeReason, + truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen), + prewarmEventCount, + ) + return wrapOpenAIWSFallback("prewarm_"+classifyOpenAIWSReadFallbackReason(readErr), readErr) + } + + eventType, 
eventResponseID, _ := parseOpenAIWSEventEnvelope(message) + if eventType == "" { + continue + } + prewarmEventCount++ + if prewarmResponseID == "" && eventResponseID != "" { + prewarmResponseID = eventResponseID + } + if prewarmEventCount <= openAIWSPrewarmEventLogHead || eventType == "error" || isOpenAIWSTerminalEvent(eventType) { + logOpenAIWSModeInfo( + "prewarm_event account_id=%d conn_id=%s idx=%d type=%s bytes=%d", + account.ID, + connID, + prewarmEventCount, + truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen), + len(message), + ) + } + + if eventType == "error" { + errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message) + s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), message, errCodeRaw, errTypeRaw, errMsgRaw) + errMsg := strings.TrimSpace(errMsgRaw) + if errMsg == "" { + errMsg = "OpenAI websocket prewarm error" + } + fallbackReason, canFallback := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + logOpenAIWSModeInfo( + "prewarm_error_event account_id=%d conn_id=%s idx=%d fallback_reason=%s can_fallback=%v err_code=%s err_type=%s err_message=%s", + account.ID, + connID, + prewarmEventCount, + truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen), + canFallback, + errCode, + errType, + errMessage, + ) + lease.MarkBroken() + if canFallback { + return wrapOpenAIWSFallback("prewarm_"+fallbackReason, errors.New(errMsg)) + } + return wrapOpenAIWSFallback("prewarm_error_event", errors.New(errMsg)) + } + + if isOpenAIWSTerminalEvent(eventType) { + prewarmTerminalCount++ + break + } + } + + lease.MarkPrewarmed() + if prewarmResponseID != "" && stateStore != nil { + ttl := s.openAIWSResponseStickyTTL() + logOpenAIWSBindResponseAccountWarn(groupID, account.ID, prewarmResponseID, stateStore.BindResponseAccount(ctx, groupID, prewarmResponseID, account.ID, ttl)) + 
stateStore.BindResponseConn(prewarmResponseID, lease.ConnID(), ttl) + } + logOpenAIWSModeInfo( + "prewarm_done account_id=%d conn_id=%s response_id=%s events=%d terminal_events=%d duration_ms=%d", + account.ID, + connID, + truncateOpenAIWSLogValue(prewarmResponseID, openAIWSIDValueMaxLen), + prewarmEventCount, + prewarmTerminalCount, + time.Since(prewarmStart).Milliseconds(), + ) + return nil +} + +func payloadAsJSON(payload map[string]any) string { + return string(payloadAsJSONBytes(payload)) +} + +func payloadAsJSONBytes(payload map[string]any) []byte { + if len(payload) == 0 { + return []byte("{}") + } + body, err := json.Marshal(payload) + if err != nil { + return []byte("{}") + } + return body +} + +func isOpenAIWSTerminalEvent(eventType string) bool { + switch strings.TrimSpace(eventType) { + case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled": + return true + default: + return false + } +} + +func isOpenAIWSTokenEvent(eventType string) bool { + eventType = strings.TrimSpace(eventType) + if eventType == "" { + return false + } + switch eventType { + case "response.created", "response.in_progress", "response.output_item.added", "response.output_item.done": + return false + } + if strings.Contains(eventType, ".delta") { + return true + } + if strings.HasPrefix(eventType, "response.output_text") { + return true + } + if strings.HasPrefix(eventType, "response.output") { + return true + } + return eventType == "response.completed" || eventType == "response.done" +} + +func replaceOpenAIWSMessageModel(message []byte, fromModel, toModel string) []byte { + if len(message) == 0 { + return message + } + if strings.TrimSpace(fromModel) == "" || strings.TrimSpace(toModel) == "" || fromModel == toModel { + return message + } + if !bytes.Contains(message, []byte(`"model"`)) || !bytes.Contains(message, []byte(fromModel)) { + return message + } + modelValues := gjson.GetManyBytes(message, "model", 
"response.model") + replaceModel := modelValues[0].Exists() && modelValues[0].Str == fromModel + replaceResponseModel := modelValues[1].Exists() && modelValues[1].Str == fromModel + if !replaceModel && !replaceResponseModel { + return message + } + updated := message + if replaceModel { + if next, err := sjson.SetBytes(updated, "model", toModel); err == nil { + updated = next + } + } + if replaceResponseModel { + if next, err := sjson.SetBytes(updated, "response.model", toModel); err == nil { + updated = next + } + } + return updated +} + +func populateOpenAIUsageFromResponseJSON(body []byte, usage *OpenAIUsage) { + if usage == nil || len(body) == 0 { + return + } + values := gjson.GetManyBytes( + body, + "usage.input_tokens", + "usage.output_tokens", + "usage.input_tokens_details.cached_tokens", + ) + usage.InputTokens = int(values[0].Int()) + usage.OutputTokens = int(values[1].Int()) + usage.CacheReadInputTokens = int(values[2].Int()) +} diff --git a/backend/internal/service/openai_ws_retry.go b/backend/internal/service/openai_ws_retry.go new file mode 100644 index 00000000000..ca8639e778e --- /dev/null +++ b/backend/internal/service/openai_ws_retry.go @@ -0,0 +1,72 @@ +package service + +import ( + "context" + "errors" + "sort" + "strings" +) + +func isOpenAIWSIngressTurnRetryable(err error) bool { + var turnErr *openAIWSIngressTurnError + if !errors.As(err, &turnErr) || turnErr == nil { + return false + } + if errors.Is(turnErr.cause, context.Canceled) || errors.Is(turnErr.cause, context.DeadlineExceeded) { + return false + } + if turnErr.wroteDownstream { + return false + } + switch turnErr.stage { + case "write_upstream", "read_upstream": + return true + default: + return false + } +} + +func openAIWSIngressTurnRetryReason(err error) string { + var turnErr *openAIWSIngressTurnError + if !errors.As(err, &turnErr) || turnErr == nil { + return "unknown" + } + if turnErr.stage == "" { + return "unknown" + } + return turnErr.stage +} + +func 
isOpenAIWSIngressPreviousResponseNotFound(err error) bool { + var turnErr *openAIWSIngressTurnError + if !errors.As(err, &turnErr) || turnErr == nil { + return false + } + if strings.TrimSpace(turnErr.stage) != openAIWSIngressStagePreviousResponseNotFound { + return false + } + return !turnErr.wroteDownstream +} + +// applyOpenAIWSRetryPayloadStrategy 在 WS 连续失败时仅移除无语义字段, +// 避免重试成功却改变原始请求语义。 +// 注意:prompt_cache_key 不应在重试中移除;它常用于会话稳定标识(session_id 兜底)。 +func applyOpenAIWSRetryPayloadStrategy(payload map[string]any, attempt int) (strategy string, removedKeys []string) { + if len(payload) == 0 { + return "empty", nil + } + if attempt <= 1 { + return "full", nil + } + + removed := make([]string, 0, 2) + if attempt >= 2 { + dropOpenAIWSPayloadKey(payload, "include", &removed) + } + + if len(removed) == 0 { + return "full", nil + } + sort.Strings(removed) + return "trim_optional_fields", removed +} diff --git a/backend/internal/service/openai_ws_retry_extra_test.go b/backend/internal/service/openai_ws_retry_extra_test.go new file mode 100644 index 00000000000..3b7d08826bb --- /dev/null +++ b/backend/internal/service/openai_ws_retry_extra_test.go @@ -0,0 +1,107 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsOpenAIWSIngressTurnRetryable_NilError(t *testing.T) { + assert.False(t, isOpenAIWSIngressTurnRetryable(nil)) +} + +func TestIsOpenAIWSIngressTurnRetryable_NonTurnError(t *testing.T) { + assert.False(t, isOpenAIWSIngressTurnRetryable(errors.New("random error"))) +} + +func TestIsOpenAIWSIngressTurnRetryable_ContextCanceled(t *testing.T) { + err := wrapOpenAIWSIngressTurnError("read_upstream", context.Canceled, false) + assert.False(t, isOpenAIWSIngressTurnRetryable(err)) +} + +func TestIsOpenAIWSIngressTurnRetryable_ContextDeadline(t *testing.T) { + err := wrapOpenAIWSIngressTurnError("read_upstream", context.DeadlineExceeded, false) + assert.False(t, 
isOpenAIWSIngressTurnRetryable(err)) +} + +func TestIsOpenAIWSIngressTurnRetryable_WroteDownstream(t *testing.T) { + err := wrapOpenAIWSIngressTurnError("read_upstream", errors.New("timeout"), true) + assert.False(t, isOpenAIWSIngressTurnRetryable(err)) +} + +func TestIsOpenAIWSIngressTurnRetryable_ReadUpstreamRetryable(t *testing.T) { + err := wrapOpenAIWSIngressTurnError("read_upstream", errors.New("connection reset"), false) + assert.True(t, isOpenAIWSIngressTurnRetryable(err)) +} + +func TestIsOpenAIWSIngressTurnRetryable_WriteUpstreamRetryable(t *testing.T) { + err := wrapOpenAIWSIngressTurnError("write_upstream", errors.New("broken pipe"), false) + assert.True(t, isOpenAIWSIngressTurnRetryable(err)) +} + +func TestOpenAIWSIngressTurnRetryReason_NilError(t *testing.T) { + assert.Equal(t, "unknown", openAIWSIngressTurnRetryReason(nil)) +} + +func TestOpenAIWSIngressTurnRetryReason_NonTurnError(t *testing.T) { + assert.Equal(t, "unknown", openAIWSIngressTurnRetryReason(errors.New("random"))) +} + +func TestOpenAIWSIngressTurnRetryReason_ExtractsStage(t *testing.T) { + err := wrapOpenAIWSIngressTurnError("read_upstream", errors.New("timeout"), false) + reason := openAIWSIngressTurnRetryReason(err) + assert.Equal(t, "read_upstream", reason) +} + +func TestIsOpenAIWSIngressPreviousResponseNotFound_True(t *testing.T) { + err := wrapOpenAIWSIngressTurnError("previous_response_not_found", errors.New("not found"), false) + assert.True(t, isOpenAIWSIngressPreviousResponseNotFound(err)) +} + +func TestIsOpenAIWSIngressPreviousResponseNotFound_WroteDownstream(t *testing.T) { + err := wrapOpenAIWSIngressTurnError("previous_response_not_found", errors.New("not found"), true) + assert.False(t, isOpenAIWSIngressPreviousResponseNotFound(err)) +} + +func TestIsOpenAIWSIngressPreviousResponseNotFound_DifferentStage(t *testing.T) { + err := wrapOpenAIWSIngressTurnError("read_upstream", errors.New("not found"), false) + assert.False(t, 
isOpenAIWSIngressPreviousResponseNotFound(err)) +} + +func TestApplyOpenAIWSRetryPayloadStrategy_NilPayload(t *testing.T) { + strategy, removed := applyOpenAIWSRetryPayloadStrategy(nil, 1) + assert.Equal(t, "empty", strategy) + assert.Empty(t, removed) +} + +func TestApplyOpenAIWSRetryPayloadStrategy_FirstAttempt(t *testing.T) { + payload := map[string]any{ + "model": "gpt-5", + "input": "hello", + "include": []string{"reasoning"}, + } + strategy, removed := applyOpenAIWSRetryPayloadStrategy(payload, 1) + assert.Equal(t, "full", strategy) + assert.Empty(t, removed) + // payload should be unchanged + assert.Contains(t, payload, "include") +} + +func TestApplyOpenAIWSRetryPayloadStrategy_SecondAttempt(t *testing.T) { + payload := map[string]any{ + "model": "gpt-5", + "input": "hello", + "include": []string{"reasoning"}, + "prompt_cache_key": "pcache_123", + } + strategy, removed := applyOpenAIWSRetryPayloadStrategy(payload, 2) + assert.NotEqual(t, "full", strategy) + // prompt_cache_key should be preserved + assert.Contains(t, payload, "prompt_cache_key") + // include should be removed + assert.Contains(t, removed, "include") +} diff --git a/backend/internal/service/openai_ws_session_state.go b/backend/internal/service/openai_ws_session_state.go new file mode 100644 index 00000000000..36bcb0eec8e --- /dev/null +++ b/backend/internal/service/openai_ws_session_state.go @@ -0,0 +1,595 @@ +package service + +import ( + "bytes" + "encoding/json" + "errors" + "net/http" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +type openAIWSSessionHeaderResolution struct { + SessionID string + ConversationID string + SessionSource string + ConversationSource string +} + +func resolveOpenAIWSSessionHeaders(c *gin.Context, promptCacheKey string) openAIWSSessionHeaderResolution { + resolution := openAIWSSessionHeaderResolution{ + SessionSource: "none", + ConversationSource: 
"none", + } + if c != nil && c.Request != nil { + if sessionID := strings.TrimSpace(c.Request.Header.Get("session_id")); sessionID != "" { + resolution.SessionID = sessionID + resolution.SessionSource = "header_session_id" + } + if conversationID := strings.TrimSpace(c.Request.Header.Get("conversation_id")); conversationID != "" { + resolution.ConversationID = conversationID + resolution.ConversationSource = "header_conversation_id" + if resolution.SessionID == "" { + resolution.SessionID = conversationID + resolution.SessionSource = "header_conversation_id" + } + } + } + + cacheKey := strings.TrimSpace(promptCacheKey) + if cacheKey != "" { + if resolution.SessionID == "" { + resolution.SessionID = cacheKey + resolution.SessionSource = "prompt_cache_key" + } + } + return resolution +} + +func (s *OpenAIGatewayService) buildOpenAIWSHeaders( + c *gin.Context, + account *Account, + token string, + decision OpenAIWSProtocolDecision, + isCodexCLI bool, + turnState string, + turnMetadata string, + promptCacheKey string, +) (http.Header, openAIWSSessionHeaderResolution) { + headers := make(http.Header) + headers.Set("authorization", "Bearer "+token) + + sessionResolution := resolveOpenAIWSSessionHeaders(c, promptCacheKey) + if c != nil && c.Request != nil { + if v := strings.TrimSpace(c.Request.Header.Get("accept-language")); v != "" { + headers.Set("accept-language", v) + } + } + // OAuth 账号:将 apiKeyID 混入 session 标识符,防止跨用户会话碰撞。 + if account != nil && account.Type == AccountTypeOAuth { + apiKeyID := getAPIKeyIDFromContext(c) + if sessionResolution.SessionID != "" { + headers.Set("session_id", isolateOpenAISessionID(apiKeyID, sessionResolution.SessionID)) + } + if sessionResolution.ConversationID != "" { + headers.Set("conversation_id", isolateOpenAISessionID(apiKeyID, sessionResolution.ConversationID)) + } + } else { + if sessionResolution.SessionID != "" { + headers.Set("session_id", sessionResolution.SessionID) + } + if sessionResolution.ConversationID != "" { + 
			// Tail of an upstream-header builder; its beginning lies outside
			// this chunk — do not assume anything about the enclosing branches.
			headers.Set("conversation_id", sessionResolution.ConversationID)
		}
	}
	if state := strings.TrimSpace(turnState); state != "" {
		headers.Set(openAIWSTurnStateHeader, state)
	}
	if metadata := strings.TrimSpace(turnMetadata); metadata != "" {
		headers.Set(openAIWSTurnMetadataHeader, metadata)
	}

	// OAuth accounts additionally advertise the ChatGPT account and originator.
	if account != nil && account.Type == AccountTypeOAuth {
		if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" {
			headers.Set("chatgpt-account-id", chatgptAccountID)
		}
		headers.Set("originator", resolveOpenAIUpstreamOriginator(c, isCodexCLI))
	}

	// Responses-over-websocket transport uses the v1 beta header; the default
	// WS mode uses v2.
	betaValue := openAIWSBetaV2Value
	if decision.Transport == OpenAIUpstreamTransportResponsesWebsocket {
		betaValue = openAIWSBetaV1Value
	}
	headers.Set("OpenAI-Beta", betaValue)

	// User-Agent precedence: account override > client-supplied > forced
	// Codex CLI (config flag or OAuth requirement, applied last so they win).
	customUA := ""
	if account != nil {
		customUA = account.GetOpenAIUserAgent()
	}
	if strings.TrimSpace(customUA) != "" {
		headers.Set("user-agent", customUA)
	} else if c != nil {
		if ua := strings.TrimSpace(c.GetHeader("User-Agent")); ua != "" {
			headers.Set("user-agent", ua)
		}
	}
	if s != nil && s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
		headers.Set("user-agent", codexCLIUserAgent)
	}
	// OAuth upstream must always see a Codex-CLI-shaped user agent.
	if account != nil && account.Type == AccountTypeOAuth && !openai.IsCodexCLIRequest(headers.Get("user-agent")) {
		headers.Set("user-agent", codexCLIUserAgent)
	}

	return headers, sessionResolution
}

// buildOpenAIWSCreatePayload converts an HTTP /responses request body into a
// WS-mode "response.create" payload. The input map is not mutated.
func (s *OpenAIGatewayService) buildOpenAIWSCreatePayload(reqBody map[string]any, account *Account) map[string]any {
	// OpenAI WS-mode protocol: response.create fields largely match the HTTP
	// /responses body. Keep the stream field (matching Codex CLI); only drop
	// background.
	payload := make(map[string]any, len(reqBody)+1)
	for k, v := range reqBody {
		payload[k] = v
	}

	delete(payload, "background")
	if _, exists := payload["stream"]; !exists {
		payload["stream"] = true
	}
	payload["type"] = "response.create"

	// OAuth defaults to store=false to avoid accidental reliance on
	// server-side history, unless store-recovery is explicitly allowed.
	if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) {
		payload["store"] = false
	}
	return payload
}

// setOpenAIWSTurnMetadata merges the (trimmed) turn-metadata value into the
// payload's client_metadata map under the openAIWSTurnMetadataHeader key,
// preserving any entries already present. No-op on empty payload/metadata.
func setOpenAIWSTurnMetadata(payload map[string]any, turnMetadata string) {
	if len(payload) == 0 {
		return
	}
	metadata := strings.TrimSpace(turnMetadata)
	if metadata == "" {
		return
	}

	switch existing := payload["client_metadata"].(type) {
	case map[string]any:
		existing[openAIWSTurnMetadataHeader] = metadata
		payload["client_metadata"] = existing
	case map[string]string:
		// Widen a string-valued map to map[string]any before inserting.
		next := make(map[string]any, len(existing)+1)
		for k, v := range existing {
			next[k] = v
		}
		next[openAIWSTurnMetadataHeader] = metadata
		payload["client_metadata"] = next
	default:
		// Missing or unrecognized shape: replace with a fresh map.
		payload["client_metadata"] = map[string]any{
			openAIWSTurnMetadataHeader: metadata,
		}
	}
}

// isOpenAIWSStoreRecoveryAllowed reports whether store-based recovery is
// permitted, either per-account or via the gateway-wide config flag.
func (s *OpenAIGatewayService) isOpenAIWSStoreRecoveryAllowed(account *Account) bool {
	if account != nil && account.IsOpenAIWSAllowStoreRecoveryEnabled() {
		return true
	}
	if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.AllowStoreRecovery {
		return true
	}
	return false
}

// isOpenAIWSStoreDisabledInRequest reports whether this turn runs with
// store=false: forced for OAuth accounts (without recovery allowance), or an
// explicit boolean store=false in the decoded request body.
func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequest(reqBody map[string]any, account *Account) bool {
	if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) {
		return true
	}
	if len(reqBody) == 0 {
		return false
	}
	rawStore, ok := reqBody["store"]
	if !ok {
		return false
	}
	storeEnabled, ok := rawStore.(bool)
	if !ok {
		// Non-boolean store values are ignored rather than treated as disabled.
		return false
	}
	return !storeEnabled
}

// isOpenAIWSStoreDisabledInRequestRaw is the raw-bytes counterpart of
// isOpenAIWSStoreDisabledInRequest, using gjson to avoid a full decode.
func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequestRaw(reqBody []byte, account *Account) bool {
	if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) {
		return true
	}
	if len(reqBody) == 0 {
		return false
	}
	storeValue := gjson.GetBytes(reqBody, "store")
	if !storeValue.Exists() {
		return false
	}
	if storeValue.Type != gjson.True && storeValue.Type != gjson.False {
		// Only honest JSON booleans count.
		return false
	}
	return !storeValue.Bool()
}

// shouldForceNewConnOnStoreDisabled decides whether a store=false turn must
// open a fresh upstream connection, based on the configured mode and the last
// recorded failure reason. "prewarm_" prefixes are stripped before matching.
func shouldForceNewConnOnStoreDisabled(mode, lastFailureReason string) bool {
	switch mode {
	case openAIWSStoreDisabledConnModeOff:
		return false
	case openAIWSStoreDisabledConnModeAdaptive:
		reason := strings.TrimPrefix(strings.TrimSpace(lastFailureReason), "prewarm_")
		switch reason {
		case "policy_violation", "message_too_big", "auth_failed", "write_request", "write":
			return true
		default:
			return false
		}
	default:
		// Unknown or empty mode behaves as "always force a new connection".
		return true
	}
}

// dropPreviousResponseIDFromRawPayload removes previous_response_id from a raw
// JSON payload; the actual work is delegated so tests can inject a delete fn.
func dropPreviousResponseIDFromRawPayload(payload []byte) ([]byte, bool, error) {
	return dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, sjson.DeleteBytes)
}

// setPreviousResponseIDToRawPayload writes previous_response_id into a raw
// JSON payload. Falls back to a decode/encode round trip when sjson fails;
// if that decode also fails, the original sjson error is returned.
func setPreviousResponseIDToRawPayload(payload []byte, previousResponseID string) ([]byte, error) {
	normalizedPrevID := strings.TrimSpace(previousResponseID)
	if len(payload) == 0 || normalizedPrevID == "" {
		return payload, nil
	}
	updated, err := sjson.SetBytes(payload, "previous_response_id", normalizedPrevID)
	if err == nil {
		return updated, nil
	}

	var reqBody map[string]any
	if unmarshalErr := json.Unmarshal(payload, &reqBody); unmarshalErr != nil {
		return nil, err
	}
	reqBody["previous_response_id"] = normalizedPrevID
	rebuilt, marshalErr := json.Marshal(reqBody)
	if marshalErr != nil {
		return nil, marshalErr
	}
	return rebuilt, nil
}

// shouldInferIngressFunctionCallOutputPreviousResponseID reports whether an
// ingress turn carrying function_call_output (but no explicit
// previous_response_id) should inherit the expected prior response anchor.
// Only applies on store=false paths from the second turn onward.
func shouldInferIngressFunctionCallOutputPreviousResponseID(
	storeDisabled bool,
	turn int,
	signals ToolContinuationSignals,
	currentPreviousResponseID string,
	expectedPreviousResponseID string,
) bool {
	if !storeDisabled || turn <= 1 || !signals.HasFunctionCallOutput {
		return false
	}
	if strings.TrimSpace(currentPreviousResponseID) != "" {
		return false
	}
	if signals.HasFunctionCallOutputMissingCallID {
		return false
	}
	// If the client already sent the actual tool-call context, treat this as
	// a full replay / self-contained continuation payload rather than
	// downgrading it into an inferred delta continuation. item_reference alone
	// is not enough on the store=false WS path: it still needs a valid prior
	// response anchor so upstream can resolve the referenced function_call.
	if signals.HasToolCallContext {
		return false
	}
	return strings.TrimSpace(expectedPreviousResponseID) != ""
}

// alignStoreDisabledPreviousResponseID rewrites previous_response_id to the
// expected value when the payload carries a different, non-empty one. Returns
// (payload, changed, err); the original payload is returned unchanged when no
// rewrite is needed or possible.
func alignStoreDisabledPreviousResponseID(
	payload []byte,
	expectedPreviousResponseID string,
) ([]byte, bool, error) {
	if len(payload) == 0 {
		return payload, false, nil
	}
	expected := strings.TrimSpace(expectedPreviousResponseID)
	if expected == "" {
		return payload, false, nil
	}
	current := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
	if current == "" || current == expected {
		return payload, false, nil
	}

	// Drop-then-set rather than overwrite so the helper's removal semantics
	// (and its reported "removed" flag) stay authoritative.
	withoutPrev, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload)
	if dropErr != nil {
		return payload, false, dropErr
	}
	if !removed {
		return payload, false, nil
	}
	updated, setErr := setPreviousResponseIDToRawPayload(withoutPrev, expected)
	if setErr != nil {
		return payload, false, setErr
	}
	return updated, true, nil
}

// cloneOpenAIWSPayloadBytes returns an independent copy of payload
// (nil for empty input).
func cloneOpenAIWSPayloadBytes(payload []byte) []byte {
	if len(payload) == 0 {
		return nil
	}
	cloned := make([]byte, len(payload))
	copy(cloned, payload)
	return cloned
}

// cloneOpenAIWSRawMessages deep-copies a slice of raw JSON messages,
// preserving nil-ness of the input slice.
func cloneOpenAIWSRawMessages(items []json.RawMessage) []json.RawMessage {
	if items == nil {
		return nil
	}
	cloned := make([]json.RawMessage, 0, len(items))
	for idx := range items {
		cloned = append(cloned, json.RawMessage(cloneOpenAIWSPayloadBytes(items[idx])))
	}
	return cloned
}

// normalizeOpenAIWSJSONForCompare re-encodes JSON into a canonical form
// (whitespace-free, sorted object keys via encoding/json) for byte comparison.
func normalizeOpenAIWSJSONForCompare(raw []byte) ([]byte, error) {
	trimmed := bytes.TrimSpace(raw)
	if len(trimmed) == 0 {
		return nil, errors.New("json is empty")
	}
	var decoded any
	if err := json.Unmarshal(trimmed, &decoded); err != nil {
		return nil, err
	}
	return json.Marshal(decoded)
}

// normalizeOpenAIWSJSONForCompareOrRaw is the lenient variant: invalid JSON
// falls back to the whitespace-trimmed raw bytes.
func normalizeOpenAIWSJSONForCompareOrRaw(raw []byte) []byte {
	normalized, err := normalizeOpenAIWSJSONForCompare(raw)
	if err != nil {
		return bytes.TrimSpace(raw)
	}
	return normalized
}

// normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID canonicalizes a
// payload with the input and previous_response_id fields removed, so that the
// "everything else" portion of two turns can be compared.
func normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload []byte) ([]byte, error) {
	if len(payload) == 0 {
		return nil, errors.New("payload is empty")
	}
	var decoded map[string]any
	if err := json.Unmarshal(payload, &decoded); err != nil {
		return nil, err
	}
	delete(decoded, "input")
	delete(decoded, "previous_response_id")
	return json.Marshal(decoded)
}

// openAIWSExtractNormalizedInputSequence extracts the payload's "input" field
// as a flat sequence of raw JSON items. Returns (items, fieldExists, err):
// arrays yield their elements, a single object/string/scalar yields one item.
func openAIWSExtractNormalizedInputSequence(payload []byte) ([]json.RawMessage, bool, error) {
	if len(payload) == 0 {
		return nil, false, nil
	}
	inputValue := gjson.GetBytes(payload, "input")
	if !inputValue.Exists() {
		return nil, false, nil
	}
	if inputValue.Type == gjson.JSON {
		raw := strings.TrimSpace(inputValue.Raw)
		if strings.HasPrefix(raw, "[") {
			var items []json.RawMessage
			if err := json.Unmarshal([]byte(raw), &items); err != nil {
				return nil, true, err
			}
			return items, true, nil
		}
		// A single object input is treated as a one-item sequence.
		return []json.RawMessage{json.RawMessage(raw)}, true, nil
	}
	if inputValue.Type == gjson.String {
		// Re-encode so the item is valid standalone JSON.
		encoded, _ := json.Marshal(inputValue.String())
		return []json.RawMessage{encoded}, true, nil
	}
	return []json.RawMessage{json.RawMessage(inputValue.Raw)}, true, nil
}

// openAIWSInputIsPrefixExtended reports whether currentPayload's input
// sequence starts with previousPayload's input sequence (item-wise, after
// canonicalization) — i.e. the turn only appends new items. Presence/absence
// of the input field is compared strictly, not treated as an empty sequence.
func openAIWSInputIsPrefixExtended(previousPayload, currentPayload []byte) (bool, error) {
	previousItems, previousExists, prevErr := openAIWSExtractNormalizedInputSequence(previousPayload)
	if prevErr != nil {
		return false, prevErr
	}
	currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload)
	if currentErr != nil {
		return false, currentErr
	}
	if !previousExists && !currentExists {
		return true, nil
	}
	if !previousExists {
		return len(currentItems) == 0, nil
	}
	if !currentExists {
		return len(previousItems) == 0, nil
	}
	if len(currentItems) < len(previousItems) {
		return false, nil
	}

	for idx := range previousItems {
		previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(previousItems[idx])
		currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(currentItems[idx])
		if !bytes.Equal(previousNormalized, currentNormalized) {
			return false, nil
		}
	}
	return true, nil
}

// openAIWSRawItemsHasPrefix reports whether items starts with prefix,
// comparing canonicalized JSON item-by-item. An empty prefix always matches.
func openAIWSRawItemsHasPrefix(items []json.RawMessage, prefix []json.RawMessage) bool {
	if len(prefix) == 0 {
		return true
	}
	if len(items) < len(prefix) {
		return false
	}
	for idx := range prefix {
		previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(prefix[idx])
		currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(items[idx])
		if !bytes.Equal(previousNormalized, currentNormalized) {
			return false
		}
	}
	return true
}

// openAIWSRawItemsHasFunctionCallOutput reports whether any item has
// type == "function_call_output".
func openAIWSRawItemsHasFunctionCallOutput(items []json.RawMessage) bool {
	for _, item := range items {
		if gjson.GetBytes(item, "type").String() == "function_call_output" {
			return true
		}
	}
	return false
}

// buildOpenAIWSReplayInputSequence reconstructs the full input sequence for a
// replay turn. Without a previous_response_id (or without recorded previous
// input) the current items stand alone; otherwise the previous full input is
// reused, extended, or prepended as needed. Returns (items, exists, err).
func buildOpenAIWSReplayInputSequence(
	previousFullInput []json.RawMessage,
	previousFullInputExists bool,
	currentPayload []byte,
	hasPreviousResponseID bool,
) ([]json.RawMessage, bool, error) {
	currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload)
	if currentErr != nil {
		return nil, false, currentErr
	}
	if !hasPreviousResponseID {
		return cloneOpenAIWSRawMessages(currentItems), currentExists, nil
	}
	if !previousFullInputExists {
		return cloneOpenAIWSRawMessages(currentItems), currentExists, nil
	}
	if !currentExists || len(currentItems) == 0 {
		// No new items this turn: replay the previous full input as-is.
		return cloneOpenAIWSRawMessages(previousFullInput), true, nil
	}
	if openAIWSRawItemsHasPrefix(currentItems, previousFullInput) {
		// Current payload already contains the full history.
		return cloneOpenAIWSRawMessages(currentItems), true, nil
	}
	// Delta-style payload: prepend the recorded history.
	merged := make([]json.RawMessage, 0, len(previousFullInput)+len(currentItems))
	merged = append(merged, cloneOpenAIWSRawMessages(previousFullInput)...)
	merged = append(merged, cloneOpenAIWSRawMessages(currentItems)...)
	return merged, true, nil
}

// setOpenAIWSPayloadInputSequence writes fullInput back into the payload's
// "input" field. When fullInputExists is false the payload is untouched.
func setOpenAIWSPayloadInputSequence(
	payload []byte,
	fullInput []json.RawMessage,
	fullInputExists bool,
) ([]byte, error) {
	if !fullInputExists {
		return payload, nil
	}
	// Preserve [] vs null semantics when input exists but is empty.
	inputForMarshal := fullInput
	if inputForMarshal == nil {
		inputForMarshal = []json.RawMessage{}
	}
	inputRaw, marshalErr := json.Marshal(inputForMarshal)
	if marshalErr != nil {
		return nil, marshalErr
	}
	return sjson.SetRawBytes(payload, "input", inputRaw)
}

// shouldKeepIngressPreviousResponseID decides whether an ingress turn may
// keep its previous_response_id: either it carries function_call_output, or
// it is a strict incremental continuation (same anchor as the last turn and
// identical non-input fields). Returns (keep, reasonTag, err).
func shouldKeepIngressPreviousResponseID(
	previousPayload []byte,
	currentPayload []byte,
	lastTurnResponseID string,
	hasFunctionCallOutput bool,
) (bool, string, error) {
	if hasFunctionCallOutput {
		return true, "has_function_call_output", nil
	}
	currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id"))
	if currentPreviousResponseID == "" {
		return false, "missing_previous_response_id", nil
	}
	expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID)
	if expectedPreviousResponseID == "" {
		return false, "missing_last_turn_response_id", nil
	}
	if currentPreviousResponseID != expectedPreviousResponseID {
		return false, "previous_response_id_mismatch", nil
	}
	if len(previousPayload) == 0 {
		return false, "missing_previous_turn_payload", nil
	}

	previousComparable, previousComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(previousPayload)
	if previousComparableErr != nil {
		return false, "non_input_compare_error", previousComparableErr
	}
	currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload)
	if currentComparableErr != nil {
		return false, "non_input_compare_error", currentComparableErr
	}
	if !bytes.Equal(previousComparable, currentComparable) {
		return false, "non_input_changed", nil
	}
	return true, "strict_incremental_ok", nil
}

// openAIWSIngressPreviousTurnStrictState caches the canonicalized non-input
// portion of the previous turn so repeated comparisons skip re-normalizing it.
type openAIWSIngressPreviousTurnStrictState struct {
	nonInputComparable []byte
}

// buildOpenAIWSIngressPreviousTurnStrictState precomputes the strict-compare
// state for a previous-turn payload (nil state for an empty payload).
func buildOpenAIWSIngressPreviousTurnStrictState(payload []byte) (*openAIWSIngressPreviousTurnStrictState, error) {
	if len(payload) == 0 {
		return nil, nil
	}
	nonInputComparable, nonInputErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload)
	if nonInputErr != nil {
		return nil, nonInputErr
	}
	return &openAIWSIngressPreviousTurnStrictState{
		nonInputComparable: nonInputComparable,
	}, nil
}

// shouldKeepIngressPreviousResponseIDWithStrictState mirrors
// shouldKeepIngressPreviousResponseID but takes the precomputed previous-turn
// state, avoiding re-normalization of the previous payload on every call.
func shouldKeepIngressPreviousResponseIDWithStrictState(
	previousState *openAIWSIngressPreviousTurnStrictState,
	currentPayload []byte,
	lastTurnResponseID string,
	hasFunctionCallOutput bool,
) (bool, string, error) {
	if hasFunctionCallOutput {
		return true, "has_function_call_output", nil
	}
	currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id"))
	if currentPreviousResponseID == "" {
		return false, "missing_previous_response_id", nil
	}
	expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID)
	if expectedPreviousResponseID == "" {
		return false, "missing_last_turn_response_id", nil
	}
	if currentPreviousResponseID != expectedPreviousResponseID {
		return false, "previous_response_id_mismatch", nil
	}
	if previousState == nil {
		return false, "missing_previous_turn_payload", nil
	}

	currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload)
	if currentComparableErr != nil {
		return false, "non_input_compare_error", currentComparableErr
	}
	if !bytes.Equal(previousState.nonInputComparable, currentComparable) {
		return false, "non_input_changed", nil
	}
	return true, "strict_incremental_ok", nil
}

// --- diff boundary: new file backend/internal/service/openai_ws_session_state_test.go ---
// --- diff metadata: new file backend/internal/service/openai_ws_session_state_test.go (+197 lines) ---

// Unit tests for the OpenAI WS session-state helpers: store-disabled
// connection policy, payload cloning, JSON canonicalization, input-prefix
// checks, and ingress previous_response_id inference.
//go:build unit

package service

import (
	"encoding/json"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestShouldForceNewConnOnStoreDisabled_Off(t *testing.T) {
	assert.False(t, shouldForceNewConnOnStoreDisabled("off", ""))
	assert.False(t, shouldForceNewConnOnStoreDisabled("off", "prewarm_timeout"))
}

func TestShouldForceNewConnOnStoreDisabled_Default(t *testing.T) {
	// Empty/unknown modes fall through to the "always" behavior.
	assert.True(t, shouldForceNewConnOnStoreDisabled("", ""))
	assert.True(t, shouldForceNewConnOnStoreDisabled("always", ""))
}

func TestShouldForceNewConnOnStoreDisabled_Adaptive(t *testing.T) {
	assert.True(t, shouldForceNewConnOnStoreDisabled("adaptive", "policy_violation"))
	assert.True(t, shouldForceNewConnOnStoreDisabled("adaptive", "message_too_big"))
	assert.True(t, shouldForceNewConnOnStoreDisabled("adaptive", "auth_failed"))
	assert.True(t, shouldForceNewConnOnStoreDisabled("adaptive", "write_request"))
	assert.True(t, shouldForceNewConnOnStoreDisabled("adaptive", "write"))
	assert.False(t, shouldForceNewConnOnStoreDisabled("adaptive", ""))
	assert.False(t, shouldForceNewConnOnStoreDisabled("adaptive", "other_reason"))
}

func TestShouldForceNewConnOnStoreDisabled_PrewarmPrefix(t *testing.T) {
	// The "prewarm_" prefix is stripped before matching; the remainders here
	// are not in the force list, so no new connection is forced.
	assert.False(t, shouldForceNewConnOnStoreDisabled("adaptive", "prewarm_timeout"))
	assert.False(t, shouldForceNewConnOnStoreDisabled("adaptive", "prewarm_error"))
}

func TestCloneOpenAIWSPayloadBytes_Nil(t *testing.T) {
	assert.Nil(t, cloneOpenAIWSPayloadBytes(nil))
}

func TestCloneOpenAIWSPayloadBytes_Empty(t *testing.T) {
	assert.Nil(t, cloneOpenAIWSPayloadBytes([]byte{}))
}

func TestCloneOpenAIWSPayloadBytes_CopiesData(t *testing.T) {
	original := []byte(`{"key":"value"}`)
	clone := cloneOpenAIWSPayloadBytes(original)
	assert.Equal(t, original, clone)
	// Mutating clone should not affect original
	clone[0] = 'x'
	assert.NotEqual(t, original, clone)
}

func TestCloneOpenAIWSRawMessages_Nil(t *testing.T) {
	assert.Nil(t, cloneOpenAIWSRawMessages(nil))
}

func TestCloneOpenAIWSRawMessages_ClonesEach(t *testing.T) {
	items := []json.RawMessage{
		json.RawMessage(`{"a":1}`),
		json.RawMessage(`{"b":2}`),
	}
	clone := cloneOpenAIWSRawMessages(items)
	require.Len(t, clone, 2)
	assert.Equal(t, items[0], clone[0])
	// Mutating clone should not affect original
	clone[0][0] = 'x'
	assert.NotEqual(t, items[0], clone[0])
}

func TestNormalizeOpenAIWSJSONForCompare_Valid(t *testing.T) {
	input := []byte(` {"b":2,"a":1} `)
	result, err := normalizeOpenAIWSJSONForCompare(input)
	require.NoError(t, err)
	// Should produce canonical JSON (sorted keys via marshal/unmarshal)
	assert.Contains(t, string(result), `"a"`)
	assert.Contains(t, string(result), `"b"`)
}

func TestNormalizeOpenAIWSJSONForCompare_EmptyInput(t *testing.T) {
	_, err := normalizeOpenAIWSJSONForCompare(nil)
	assert.Error(t, err)

	_, err = normalizeOpenAIWSJSONForCompare([]byte(" "))
	assert.Error(t, err)
}

func TestNormalizeOpenAIWSJSONForCompare_InvalidJSON(t *testing.T) {
	_, err := normalizeOpenAIWSJSONForCompare([]byte(`not json`))
	assert.Error(t, err)
}

func TestNormalizeOpenAIWSJSONForCompareOrRaw_FallbackOnError(t *testing.T) {
	input := []byte(`not json`)
	result := normalizeOpenAIWSJSONForCompareOrRaw(input)
	assert.Equal(t, []byte("not json"), result)
}

func TestNormalizeOpenAIWSJSONForCompareOrRaw_NormalizesValid(t *testing.T) {
	input := []byte(` {"key": "value"} `)
	result := normalizeOpenAIWSJSONForCompareOrRaw(input)
	assert.NotEqual(t, input, result)
	assert.Contains(t, string(result), `"key"`)
}

func TestOpenAIWSRawItemsHasPrefix_EmptyPrefix(t *testing.T) {
	items := []json.RawMessage{json.RawMessage(`{"a":1}`)}
	assert.True(t, openAIWSRawItemsHasPrefix(items, nil))
	assert.True(t, openAIWSRawItemsHasPrefix(items, []json.RawMessage{}))
}

func TestOpenAIWSRawItemsHasPrefix_PrefixLongerThanItems(t *testing.T) {
	items := []json.RawMessage{json.RawMessage(`{"a":1}`)}
	prefix := []json.RawMessage{json.RawMessage(`{"a":1}`), json.RawMessage(`{"b":2}`)}
	assert.False(t, openAIWSRawItemsHasPrefix(items, prefix))
}

func TestOpenAIWSRawItemsHasPrefix_MatchingPrefix(t *testing.T) {
	items := []json.RawMessage{
		json.RawMessage(`{"a":1}`),
		json.RawMessage(`{"b":2}`),
		json.RawMessage(`{"c":3}`),
	}
	prefix := []json.RawMessage{
		json.RawMessage(`{"a":1}`),
		json.RawMessage(`{"b":2}`),
	}
	assert.True(t, openAIWSRawItemsHasPrefix(items, prefix))
}

func TestOpenAIWSRawItemsHasFunctionCallOutput_None(t *testing.T) {
	items := []json.RawMessage{
		json.RawMessage(`{"type":"input_text","text":"hello"}`),
	}
	assert.False(t, openAIWSRawItemsHasFunctionCallOutput(items))
}

func TestOpenAIWSRawItemsHasFunctionCallOutput_HasOne(t *testing.T) {
	items := []json.RawMessage{
		json.RawMessage(`{"type":"input_text","text":"hello"}`),
		json.RawMessage(`{"type":"function_call_output","call_id":"call_123","output":"result"}`),
	}
	assert.True(t, openAIWSRawItemsHasFunctionCallOutput(items))
}

func TestShouldInferIngressFunctionCallOutputPreviousResponseID_Extended(t *testing.T) {
	tests := []struct {
		name                  string
		storeDisabled         bool
		turn                  int
		hasFunctionCallOutput bool
		hasToolCallContext    bool
		currentPrevID         string
		expectedPrevID        string
		want                  bool
	}{
		{"all conditions met", true, 2, true, false, "", "resp_123", true},
		{"store not disabled", false, 2, true, false, "", "resp_123", false},
		{"first turn", true, 1, true, false, "", "resp_123", false},
		{"no function call output", true, 2, false, false, "", "resp_123", false},
		{"has tool call context", true, 2, true, true, "", "resp_123", false},
		{"already has previous_response_id", true, 2, true, false, "resp_existing", "resp_123", false},
		{"empty expected ID", true, 2, true, false, "", "", false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			signals := ToolContinuationSignals{
				HasFunctionCallOutput: tt.hasFunctionCallOutput,
				HasToolCallContext:    tt.hasToolCallContext,
			}
			got := shouldInferIngressFunctionCallOutputPreviousResponseID(
				tt.storeDisabled, tt.turn, signals, tt.currentPrevID, tt.expectedPrevID,
			)
			assert.Equal(t, tt.want, got)
		})
	}
}

func TestSetOpenAIWSTurnMetadata_NilPayload(t *testing.T) {
	// Should not panic
	setOpenAIWSTurnMetadata(nil, `{"key":"value"}`)
}

func TestSetOpenAIWSTurnMetadata_EmptyMetadata(t *testing.T) {
	payload := map[string]any{"model": "test"}
	setOpenAIWSTurnMetadata(payload, "")
	_, exists := payload["client_metadata"]
	assert.False(t, exists)
}

func TestSetOpenAIWSTurnMetadata_ValidMetadata(t *testing.T) {
	payload := map[string]any{"model": "test"}
	setOpenAIWSTurnMetadata(payload, `{"session_id":"sess_123"}`)
	_, exists := payload["client_metadata"]
	assert.True(t, exists)
}

// --- diff boundary: backend/internal/service/ops_settings.go @@ -4,9 +4,10 @@
// Hunk moves the project-local logger import out of the stdlib group into its
// own blank-line-separated group (goimports convention):
//   import (
//       "context"
//       "encoding/json"
//       "errors"
//       "strings"
//       "time"
//
//       "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
//   )

// --- diff boundary: new file backend/internal/service/setting_batch.go (+1111 lines) ---

package service

import (
	"context"
	"encoding/json"
	"fmt"
	"sort"
	"strconv"
	"strings"

	"github.com/Wei-Shaw/sub2api/internal/config"
	"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
	infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)

// GetAllSettings
// GetAllSettings returns all system settings, loading the raw key/value rows
// from the repository and parsing them into a typed SystemSettings.
func (s *SettingService) GetAllSettings(ctx context.Context) (*SystemSettings, error) {
	settings, err := s.settingRepo.GetAll(ctx)
	if err != nil {
		return nil, fmt.Errorf("get all settings: %w", err)
	}

	return s.parseSettings(settings), nil
}

// GetFrontendURL returns the frontend base URL: the database value wins when
// non-empty, otherwise the config-file value is used as a fallback.
func (s *SettingService) GetFrontendURL(ctx context.Context) string {
	val, err := s.settingRepo.GetValue(ctx, SettingKeyFrontendURL)
	if err == nil && strings.TrimSpace(val) != "" {
		return strings.TrimSpace(val)
	}
	return s.cfg.Server.FrontendURL
}

// GetPublicSettings returns the settings exposed to unauthenticated clients.
// Values are read in one batched repository call; several flags fall back to
// config-file defaults when no database row exists.
func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings, error) {
	keys := []string{
		SettingKeyRegistrationEnabled,
		SettingKeyEmailVerifyEnabled,
		SettingKeyForceEmailOnThirdPartySignup,
		SettingKeyRegistrationEmailSuffixWhitelist,
		SettingKeyPromoCodeEnabled,
		SettingKeyPasswordResetEnabled,
		SettingKeyInvitationCodeEnabled,
		SettingKeyTotpEnabled,
		SettingKeyLoginAgreementEnabled,
		SettingKeyLoginAgreementMode,
		SettingKeyLoginAgreementUpdatedAt,
		SettingKeyLoginAgreementDocuments,
		SettingKeyTurnstileEnabled,
		SettingKeyTurnstileSiteKey,
		SettingKeySiteName,
		SettingKeySiteLogo,
		SettingKeySiteSubtitle,
		SettingKeyAPIBaseURL,
		SettingKeyContactInfo,
		SettingKeyDocURL,
		SettingKeyHomeContent,
		SettingKeyHideCcsImportButton,
		SettingKeyPurchaseSubscriptionEnabled,
		SettingKeyPurchaseSubscriptionURL,
		SettingKeyTableDefaultPageSize,
		SettingKeyTablePageSizeOptions,
		SettingKeyCustomMenuItems,
		SettingKeyCustomEndpoints,
		SettingKeyLinuxDoConnectEnabled,
		SettingKeyWeChatConnectEnabled,
		SettingKeyWeChatConnectAppID,
		SettingKeyWeChatConnectAppSecret,
		SettingKeyWeChatConnectOpenAppID,
		SettingKeyWeChatConnectOpenAppSecret,
		SettingKeyWeChatConnectMPAppID,
		SettingKeyWeChatConnectMPAppSecret,
		SettingKeyWeChatConnectMobileAppID,
		SettingKeyWeChatConnectMobileAppSecret,
		SettingKeyWeChatConnectOpenEnabled,
		SettingKeyWeChatConnectMPEnabled,
		SettingKeyWeChatConnectMobileEnabled,
		SettingKeyWeChatConnectMode,
		SettingKeyWeChatConnectScopes,
		SettingKeyWeChatConnectRedirectURL,
		SettingKeyWeChatConnectFrontendRedirectURL,
		SettingKeyBackendModeEnabled,
		SettingPaymentEnabled,
		SettingKeyOIDCConnectEnabled,
		SettingKeyOIDCConnectProviderName,
		SettingKeyGitHubOAuthEnabled,
		SettingKeyGitHubOAuthClientID,
		SettingKeyGitHubOAuthClientSecret,
		SettingKeyGoogleOAuthEnabled,
		SettingKeyGoogleOAuthClientID,
		SettingKeyGoogleOAuthClientSecret,
		SettingKeyBalanceLowNotifyEnabled,
		SettingKeyBalanceLowNotifyThreshold,
		SettingKeyBalanceLowNotifyRechargeURL,
		SettingKeyAccountQuotaNotifyEnabled,
		SettingKeyChannelMonitorEnabled,
		SettingKeyChannelMonitorDefaultIntervalSeconds,
		SettingKeyAvailableChannelsEnabled,
		SettingKeyAffiliateEnabled,
		SettingKeyRiskControlEnabled,
	}

	settings, err := s.settingRepo.GetMultiple(ctx, keys)
	if err != nil {
		return nil, fmt.Errorf("get public settings: %w", err)
	}

	// DB row (if present) overrides the config-file default for these flags.
	linuxDoEnabled := false
	if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok {
		linuxDoEnabled = raw == "true"
	} else {
		linuxDoEnabled = s.cfg != nil && s.cfg.LinuxDo.Enabled
	}
	oidcEnabled := false
	if raw, ok := settings[SettingKeyOIDCConnectEnabled]; ok {
		oidcEnabled = raw == "true"
	} else {
		oidcEnabled = s.cfg != nil && s.cfg.OIDC.Enabled
	}
	oidcProviderName := strings.TrimSpace(settings[SettingKeyOIDCConnectProviderName])
	if oidcProviderName == "" && s.cfg != nil {
		oidcProviderName = strings.TrimSpace(s.cfg.OIDC.ProviderName)
	}
	if oidcProviderName == "" {
		oidcProviderName = "OIDC"
	}
	gitHubEnabled := s.emailOAuthPublicEnabled(settings, "github")
	googleEnabled := s.emailOAuthPublicEnabled(settings, "google")
	weChatEnabled, weChatOpenEnabled, weChatMPEnabled, weChatMobileEnabled := s.weChatOAuthCapabilitiesFromSettings(settings)

	// Password reset requires email verification to be enabled
	emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
	passwordResetEnabled := emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true"
	registrationEmailSuffixWhitelist := ParseRegistrationEmailSuffixWhitelist(
		settings[SettingKeyRegistrationEmailSuffixWhitelist],
	)
	tableDefaultPageSize, tablePageSizeOptions := parseTablePreferences(
		settings[SettingKeyTableDefaultPageSize],
		settings[SettingKeyTablePageSizeOptions],
	)
	loginAgreementDocuments := parseLoginAgreementDocuments(settings[SettingKeyLoginAgreementDocuments])
	loginAgreementUpdatedAt := strings.TrimSpace(settings[SettingKeyLoginAgreementUpdatedAt])
	if loginAgreementUpdatedAt == "" {
		loginAgreementUpdatedAt = defaultLoginAgreementDate
	}

	// Threshold stays 0 when the stored value is missing, invalid, or negative.
	var balanceLowNotifyThreshold float64
	if v, err := strconv.ParseFloat(settings[SettingKeyBalanceLowNotifyThreshold], 64); err == nil && v >= 0 {
		balanceLowNotifyThreshold = v
	}

	return &PublicSettings{
		RegistrationEnabled:              settings[SettingKeyRegistrationEnabled] == "true",
		EmailVerifyEnabled:               emailVerifyEnabled,
		ForceEmailOnThirdPartySignup:     settings[SettingKeyForceEmailOnThirdPartySignup] == "true",
		RegistrationEmailSuffixWhitelist: registrationEmailSuffixWhitelist,
		PromoCodeEnabled:                 settings[SettingKeyPromoCodeEnabled] != "false", // enabled by default
		PasswordResetEnabled:             passwordResetEnabled,
		InvitationCodeEnabled:            settings[SettingKeyInvitationCodeEnabled] == "true",
		TotpEnabled:                      settings[SettingKeyTotpEnabled] == "true",
		LoginAgreementEnabled:            settings[SettingKeyLoginAgreementEnabled] == "true" && len(loginAgreementDocuments) > 0,
		LoginAgreementMode:               normalizeLoginAgreementMode(settings[SettingKeyLoginAgreementMode]),
		LoginAgreementUpdatedAt:          loginAgreementUpdatedAt,
		LoginAgreementRevision:           buildLoginAgreementRevision(loginAgreementUpdatedAt, loginAgreementDocuments),
		LoginAgreementDocuments:          loginAgreementDocuments,
		TurnstileEnabled:                 settings[SettingKeyTurnstileEnabled] == "true",
		TurnstileSiteKey:                 settings[SettingKeyTurnstileSiteKey],
		SiteName:                         s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
		SiteLogo:                         settings[SettingKeySiteLogo],
		SiteSubtitle:                     s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
		APIBaseURL:                       settings[SettingKeyAPIBaseURL],
		ContactInfo:                      settings[SettingKeyContactInfo],
		DocURL:                           settings[SettingKeyDocURL],
		HomeContent:                      settings[SettingKeyHomeContent],
		HideCcsImportButton:              settings[SettingKeyHideCcsImportButton] == "true",
		PurchaseSubscriptionEnabled:      settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
		PurchaseSubscriptionURL:          strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
		TableDefaultPageSize:             tableDefaultPageSize,
		TablePageSizeOptions:             tablePageSizeOptions,
		CustomMenuItems:                  settings[SettingKeyCustomMenuItems],
		CustomEndpoints:                  settings[SettingKeyCustomEndpoints],
		LinuxDoOAuthEnabled:              linuxDoEnabled,
		WeChatOAuthEnabled:               weChatEnabled,
		WeChatOAuthOpenEnabled:           weChatOpenEnabled,
		WeChatOAuthMPEnabled:             weChatMPEnabled,
		WeChatOAuthMobileEnabled:         weChatMobileEnabled,
		BackendModeEnabled:               settings[SettingKeyBackendModeEnabled] == "true",
		PaymentEnabled:                   settings[SettingPaymentEnabled] == "true",
		OIDCOAuthEnabled:                 oidcEnabled,
		OIDCOAuthProviderName:            oidcProviderName,
		GitHubOAuthEnabled:               gitHubEnabled,
		GoogleOAuthEnabled:               googleEnabled,
		BalanceLowNotifyEnabled:          settings[SettingKeyBalanceLowNotifyEnabled] == "true",
		AccountQuotaNotifyEnabled:        settings[SettingKeyAccountQuotaNotifyEnabled] == "true",
		BalanceLowNotifyThreshold:        balanceLowNotifyThreshold,
		BalanceLowNotifyRechargeURL:      settings[SettingKeyBalanceLowNotifyRechargeURL],

		ChannelMonitorEnabled:                !isFalseSettingValue(settings[SettingKeyChannelMonitorEnabled]),
		ChannelMonitorDefaultIntervalSeconds: parseChannelMonitorInterval(settings[SettingKeyChannelMonitorDefaultIntervalSeconds]),

		AvailableChannelsEnabled: settings[SettingKeyAvailableChannelsEnabled] == "true",

		AffiliateEnabled: settings[SettingKeyAffiliateEnabled] == "true",

		RiskControlEnabled: settings[SettingKeyRiskControlEnabled] == "true",
	}, nil
}

// UpdateSettings validates, normalizes, and persists system settings, then
// refreshes the in-memory cache on success.
func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error {
	updates, err := s.buildSystemSettingsUpdates(ctx, settings)
	if err != nil {
		return err
	}

	err = s.settingRepo.SetMultiple(ctx, updates)
	if err == nil {
		s.refreshCachedSettings(settings)
	}
	return err
}

// UpdateSettingsWithAuthSourceDefaults persists system settings and auth-source defaults in a single write.
func (s *SettingService) UpdateSettingsWithAuthSourceDefaults(ctx context.Context, settings *SystemSettings, authDefaults *AuthSourceDefaultSettings) error {
	updates, err := s.buildSystemSettingsUpdates(ctx, settings)
	if err != nil {
		return err
	}

	authSourceUpdates, err := s.buildAuthSourceDefaultUpdates(ctx, authDefaults)
	if err != nil {
		return err
	}
	for key, value := range authSourceUpdates {
		updates[key] = value
	}

	err = s.settingRepo.SetMultiple(ctx, updates)
	if err == nil {
		s.refreshCachedSettings(settings)
	}
	return err
}

// buildSystemSettingsUpdates validates/normalizes settings (mutating the
// struct in place) and builds the key->value update map.
// NOTE(review): the function continues beyond this chunk.
func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, settings *SystemSettings) (map[string]string, error) {
	if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil {
		return nil, err
	}
	normalizedWhitelist, err := NormalizeRegistrationEmailSuffixWhitelist(settings.RegistrationEmailSuffixWhitelist)
	if err != nil {
		return nil, infraerrors.BadRequest("INVALID_REGISTRATION_EMAIL_SUFFIX_WHITELIST", err.Error())
	}
	if normalizedWhitelist == nil {
		normalizedWhitelist = []string{}
	}
	settings.RegistrationEmailSuffixWhitelist = normalizedWhitelist
	alipaySource, err := normalizeVisibleMethodSettingSource("alipay", settings.PaymentVisibleMethodAlipaySource,
settings.PaymentVisibleMethodAlipayEnabled) + if err != nil { + return nil, err + } + wxpaySource, err := normalizeVisibleMethodSettingSource("wxpay", settings.PaymentVisibleMethodWxpaySource, settings.PaymentVisibleMethodWxpayEnabled) + if err != nil { + return nil, err + } + settings.PaymentVisibleMethodAlipaySource = alipaySource + settings.PaymentVisibleMethodWxpaySource = wxpaySource + settings.WeChatConnectAppID = strings.TrimSpace(settings.WeChatConnectAppID) + settings.WeChatConnectAppSecret = strings.TrimSpace(settings.WeChatConnectAppSecret) + settings.WeChatConnectOpenAppID = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectOpenAppID, settings.WeChatConnectAppID)) + settings.WeChatConnectOpenAppSecret = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectOpenAppSecret, settings.WeChatConnectAppSecret)) + settings.WeChatConnectMPAppID = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectMPAppID, settings.WeChatConnectAppID)) + settings.WeChatConnectMPAppSecret = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectMPAppSecret, settings.WeChatConnectAppSecret)) + settings.WeChatConnectMobileAppID = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectMobileAppID, settings.WeChatConnectAppID)) + settings.WeChatConnectMobileAppSecret = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectMobileAppSecret, settings.WeChatConnectAppSecret)) + settings.WeChatConnectMode = normalizeWeChatConnectStoredMode( + settings.WeChatConnectOpenEnabled, + settings.WeChatConnectMPEnabled, + settings.WeChatConnectMobileEnabled, + settings.WeChatConnectMode, + ) + settings.WeChatConnectScopes = normalizeWeChatConnectScopeSetting(settings.WeChatConnectScopes, settings.WeChatConnectMode) + settings.WeChatConnectRedirectURL = strings.TrimSpace(settings.WeChatConnectRedirectURL) + settings.WeChatConnectFrontendRedirectURL = strings.TrimSpace(settings.WeChatConnectFrontendRedirectURL) + if settings.WeChatConnectFrontendRedirectURL == "" { + 
settings.WeChatConnectFrontendRedirectURL = defaultWeChatConnectFrontend + } + settings.GitHubOAuthRedirectURL = strings.TrimSpace(settings.GitHubOAuthRedirectURL) + settings.GitHubOAuthFrontendRedirectURL = strings.TrimSpace(settings.GitHubOAuthFrontendRedirectURL) + if settings.GitHubOAuthFrontendRedirectURL == "" { + settings.GitHubOAuthFrontendRedirectURL = defaultGitHubOAuthFrontend + } + settings.GoogleOAuthRedirectURL = strings.TrimSpace(settings.GoogleOAuthRedirectURL) + settings.GoogleOAuthFrontendRedirectURL = strings.TrimSpace(settings.GoogleOAuthFrontendRedirectURL) + if settings.GoogleOAuthFrontendRedirectURL == "" { + settings.GoogleOAuthFrontendRedirectURL = defaultGoogleOAuthFrontend + } + + updates := make(map[string]string) + + // 注册设置 + updates[SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled) + updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled) + registrationEmailSuffixWhitelistJSON, err := json.Marshal(settings.RegistrationEmailSuffixWhitelist) + if err != nil { + return nil, fmt.Errorf("marshal registration email suffix whitelist: %w", err) + } + updates[SettingKeyRegistrationEmailSuffixWhitelist] = string(registrationEmailSuffixWhitelistJSON) + updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled) + updates[SettingKeyPasswordResetEnabled] = strconv.FormatBool(settings.PasswordResetEnabled) + updates[SettingKeyFrontendURL] = settings.FrontendURL + updates[SettingKeyInvitationCodeEnabled] = strconv.FormatBool(settings.InvitationCodeEnabled) + updates[SettingKeyTotpEnabled] = strconv.FormatBool(settings.TotpEnabled) + settings.LoginAgreementMode = normalizeLoginAgreementMode(settings.LoginAgreementMode) + settings.LoginAgreementUpdatedAt = strings.TrimSpace(settings.LoginAgreementUpdatedAt) + if settings.LoginAgreementUpdatedAt == "" { + settings.LoginAgreementUpdatedAt = defaultLoginAgreementDate + } + loginAgreementDocumentsJSON, err := 
marshalLoginAgreementDocuments(settings.LoginAgreementDocuments) + if err != nil { + return nil, err + } + updates[SettingKeyLoginAgreementEnabled] = strconv.FormatBool(settings.LoginAgreementEnabled) + updates[SettingKeyLoginAgreementMode] = settings.LoginAgreementMode + updates[SettingKeyLoginAgreementUpdatedAt] = settings.LoginAgreementUpdatedAt + updates[SettingKeyLoginAgreementDocuments] = loginAgreementDocumentsJSON + + // 邮件服务设置(只有非空才更新密码) + updates[SettingKeySMTPHost] = settings.SMTPHost + updates[SettingKeySMTPPort] = strconv.Itoa(settings.SMTPPort) + updates[SettingKeySMTPUsername] = settings.SMTPUsername + if settings.SMTPPassword != "" { + updates[SettingKeySMTPPassword] = settings.SMTPPassword + } + updates[SettingKeySMTPFrom] = settings.SMTPFrom + updates[SettingKeySMTPFromName] = settings.SMTPFromName + updates[SettingKeySMTPUseTLS] = strconv.FormatBool(settings.SMTPUseTLS) + + // Cloudflare Turnstile 设置(只有非空才更新密钥) + updates[SettingKeyTurnstileEnabled] = strconv.FormatBool(settings.TurnstileEnabled) + updates[SettingKeyTurnstileSiteKey] = settings.TurnstileSiteKey + if settings.TurnstileSecretKey != "" { + updates[SettingKeyTurnstileSecretKey] = settings.TurnstileSecretKey + } + + // LinuxDo Connect OAuth 登录 + updates[SettingKeyLinuxDoConnectEnabled] = strconv.FormatBool(settings.LinuxDoConnectEnabled) + updates[SettingKeyLinuxDoConnectClientID] = settings.LinuxDoConnectClientID + updates[SettingKeyLinuxDoConnectRedirectURL] = settings.LinuxDoConnectRedirectURL + if settings.LinuxDoConnectClientSecret != "" { + updates[SettingKeyLinuxDoConnectClientSecret] = settings.LinuxDoConnectClientSecret + } + + // Generic OIDC OAuth 登录 + updates[SettingKeyOIDCConnectEnabled] = strconv.FormatBool(settings.OIDCConnectEnabled) + updates[SettingKeyOIDCConnectProviderName] = settings.OIDCConnectProviderName + updates[SettingKeyOIDCConnectClientID] = settings.OIDCConnectClientID + updates[SettingKeyOIDCConnectIssuerURL] = settings.OIDCConnectIssuerURL + 
updates[SettingKeyOIDCConnectDiscoveryURL] = settings.OIDCConnectDiscoveryURL + updates[SettingKeyOIDCConnectAuthorizeURL] = settings.OIDCConnectAuthorizeURL + updates[SettingKeyOIDCConnectTokenURL] = settings.OIDCConnectTokenURL + updates[SettingKeyOIDCConnectUserInfoURL] = settings.OIDCConnectUserInfoURL + updates[SettingKeyOIDCConnectJWKSURL] = settings.OIDCConnectJWKSURL + updates[SettingKeyOIDCConnectScopes] = settings.OIDCConnectScopes + updates[SettingKeyOIDCConnectRedirectURL] = settings.OIDCConnectRedirectURL + updates[SettingKeyOIDCConnectFrontendRedirectURL] = settings.OIDCConnectFrontendRedirectURL + updates[SettingKeyOIDCConnectTokenAuthMethod] = settings.OIDCConnectTokenAuthMethod + updates[SettingKeyOIDCConnectUsePKCE] = strconv.FormatBool(settings.OIDCConnectUsePKCE) + updates[SettingKeyOIDCConnectValidateIDToken] = strconv.FormatBool(settings.OIDCConnectValidateIDToken) + updates[SettingKeyOIDCConnectAllowedSigningAlgs] = settings.OIDCConnectAllowedSigningAlgs + updates[SettingKeyOIDCConnectClockSkewSeconds] = strconv.Itoa(settings.OIDCConnectClockSkewSeconds) + updates[SettingKeyOIDCConnectRequireEmailVerified] = strconv.FormatBool(settings.OIDCConnectRequireEmailVerified) + updates[SettingKeyOIDCConnectUserInfoEmailPath] = settings.OIDCConnectUserInfoEmailPath + updates[SettingKeyOIDCConnectUserInfoIDPath] = settings.OIDCConnectUserInfoIDPath + updates[SettingKeyOIDCConnectUserInfoUsernamePath] = settings.OIDCConnectUserInfoUsernamePath + if settings.OIDCConnectClientSecret != "" { + updates[SettingKeyOIDCConnectClientSecret] = settings.OIDCConnectClientSecret + } + + // GitHub / Google 邮箱快捷登录 + updates[SettingKeyGitHubOAuthEnabled] = strconv.FormatBool(settings.GitHubOAuthEnabled) + updates[SettingKeyGitHubOAuthClientID] = strings.TrimSpace(settings.GitHubOAuthClientID) + updates[SettingKeyGitHubOAuthRedirectURL] = settings.GitHubOAuthRedirectURL + updates[SettingKeyGitHubOAuthFrontendRedirectURL] = settings.GitHubOAuthFrontendRedirectURL + if 
settings.GitHubOAuthClientSecret != "" { + updates[SettingKeyGitHubOAuthClientSecret] = strings.TrimSpace(settings.GitHubOAuthClientSecret) + } + updates[SettingKeyGoogleOAuthEnabled] = strconv.FormatBool(settings.GoogleOAuthEnabled) + updates[SettingKeyGoogleOAuthClientID] = strings.TrimSpace(settings.GoogleOAuthClientID) + updates[SettingKeyGoogleOAuthRedirectURL] = settings.GoogleOAuthRedirectURL + updates[SettingKeyGoogleOAuthFrontendRedirectURL] = settings.GoogleOAuthFrontendRedirectURL + if settings.GoogleOAuthClientSecret != "" { + updates[SettingKeyGoogleOAuthClientSecret] = strings.TrimSpace(settings.GoogleOAuthClientSecret) + } + + // WeChat Connect OAuth 登录 + updates[SettingKeyWeChatConnectEnabled] = strconv.FormatBool(settings.WeChatConnectEnabled) + updates[SettingKeyWeChatConnectAppID] = settings.WeChatConnectAppID + updates[SettingKeyWeChatConnectOpenAppID] = settings.WeChatConnectOpenAppID + updates[SettingKeyWeChatConnectMPAppID] = settings.WeChatConnectMPAppID + updates[SettingKeyWeChatConnectMobileAppID] = settings.WeChatConnectMobileAppID + updates[SettingKeyWeChatConnectOpenEnabled] = strconv.FormatBool(settings.WeChatConnectOpenEnabled) + updates[SettingKeyWeChatConnectMPEnabled] = strconv.FormatBool(settings.WeChatConnectMPEnabled) + updates[SettingKeyWeChatConnectMobileEnabled] = strconv.FormatBool(settings.WeChatConnectMobileEnabled) + updates[SettingKeyWeChatConnectMode] = settings.WeChatConnectMode + updates[SettingKeyWeChatConnectScopes] = settings.WeChatConnectScopes + updates[SettingKeyWeChatConnectRedirectURL] = settings.WeChatConnectRedirectURL + updates[SettingKeyWeChatConnectFrontendRedirectURL] = settings.WeChatConnectFrontendRedirectURL + if settings.WeChatConnectAppSecret != "" { + updates[SettingKeyWeChatConnectAppSecret] = settings.WeChatConnectAppSecret + } + if settings.WeChatConnectOpenAppSecret != "" { + updates[SettingKeyWeChatConnectOpenAppSecret] = settings.WeChatConnectOpenAppSecret + } + if 
settings.WeChatConnectMPAppSecret != "" { + updates[SettingKeyWeChatConnectMPAppSecret] = settings.WeChatConnectMPAppSecret + } + if settings.WeChatConnectMobileAppSecret != "" { + updates[SettingKeyWeChatConnectMobileAppSecret] = settings.WeChatConnectMobileAppSecret + } + + // OEM设置 + updates[SettingKeySiteName] = settings.SiteName + updates[SettingKeySiteLogo] = settings.SiteLogo + updates[SettingKeySiteSubtitle] = settings.SiteSubtitle + updates[SettingKeyAPIBaseURL] = settings.APIBaseURL + updates[SettingKeyContactInfo] = settings.ContactInfo + updates[SettingKeyDocURL] = settings.DocURL + updates[SettingKeyHomeContent] = settings.HomeContent + updates[SettingKeyHideCcsImportButton] = strconv.FormatBool(settings.HideCcsImportButton) + updates[SettingKeyPurchaseSubscriptionEnabled] = strconv.FormatBool(settings.PurchaseSubscriptionEnabled) + updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL) + tableDefaultPageSize, tablePageSizeOptions := normalizeTablePreferences( + settings.TableDefaultPageSize, + settings.TablePageSizeOptions, + ) + updates[SettingKeyTableDefaultPageSize] = strconv.Itoa(tableDefaultPageSize) + tablePageSizeOptionsJSON, err := json.Marshal(tablePageSizeOptions) + if err != nil { + return nil, fmt.Errorf("marshal table page size options: %w", err) + } + updates[SettingKeyTablePageSizeOptions] = string(tablePageSizeOptionsJSON) + updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems + updates[SettingKeyCustomEndpoints] = settings.CustomEndpoints + + // 默认配置 + updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency) + updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64) + settings.AffiliateRebateRate = clampAffiliateRebateRate(settings.AffiliateRebateRate) + updates[SettingKeyAffiliateRebateRate] = strconv.FormatFloat(settings.AffiliateRebateRate, 'f', 8, 64) + if settings.AffiliateRebateFreezeHours < 0 { + 
settings.AffiliateRebateFreezeHours = AffiliateRebateFreezeHoursDefault + } + if settings.AffiliateRebateFreezeHours > AffiliateRebateFreezeHoursMax { + settings.AffiliateRebateFreezeHours = AffiliateRebateFreezeHoursMax + } + updates[SettingKeyAffiliateRebateFreezeHours] = strconv.Itoa(settings.AffiliateRebateFreezeHours) + if settings.AffiliateRebateDurationDays < 0 { + settings.AffiliateRebateDurationDays = AffiliateRebateDurationDaysDefault + } + if settings.AffiliateRebateDurationDays > AffiliateRebateDurationDaysMax { + settings.AffiliateRebateDurationDays = AffiliateRebateDurationDaysMax + } + updates[SettingKeyAffiliateRebateDurationDays] = strconv.Itoa(settings.AffiliateRebateDurationDays) + if settings.AffiliateRebatePerInviteeCap < 0 { + settings.AffiliateRebatePerInviteeCap = AffiliateRebatePerInviteeCapDefault + } + updates[SettingKeyAffiliateRebatePerInviteeCap] = strconv.FormatFloat(settings.AffiliateRebatePerInviteeCap, 'f', 8, 64) + updates[SettingKeyDefaultUserRPMLimit] = strconv.Itoa(settings.DefaultUserRPMLimit) + defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions) + if err != nil { + return nil, fmt.Errorf("marshal default subscriptions: %w", err) + } + updates[SettingKeyDefaultSubscriptions] = string(defaultSubsJSON) + + // Model fallback configuration + updates[SettingKeyEnableModelFallback] = strconv.FormatBool(settings.EnableModelFallback) + updates[SettingKeyFallbackModelAnthropic] = settings.FallbackModelAnthropic + updates[SettingKeyFallbackModelOpenAI] = settings.FallbackModelOpenAI + updates[SettingKeyFallbackModelGemini] = settings.FallbackModelGemini + updates[SettingKeyFallbackModelAntigravity] = settings.FallbackModelAntigravity + + // Identity patch configuration (Claude -> Gemini) + updates[SettingKeyEnableIdentityPatch] = strconv.FormatBool(settings.EnableIdentityPatch) + updates[SettingKeyIdentityPatchPrompt] = settings.IdentityPatchPrompt + + // Ops monitoring (vNext) + updates[SettingKeyOpsMonitoringEnabled] = 
strconv.FormatBool(settings.OpsMonitoringEnabled) + updates[SettingKeyOpsRealtimeMonitoringEnabled] = strconv.FormatBool(settings.OpsRealtimeMonitoringEnabled) + updates[SettingKeyOpsQueryModeDefault] = string(ParseOpsQueryMode(settings.OpsQueryModeDefault)) + if settings.OpsMetricsIntervalSeconds > 0 { + updates[SettingKeyOpsMetricsIntervalSeconds] = strconv.Itoa(settings.OpsMetricsIntervalSeconds) + } + + // Channel monitor feature switch + updates[SettingKeyChannelMonitorEnabled] = strconv.FormatBool(settings.ChannelMonitorEnabled) + if v := clampChannelMonitorInterval(settings.ChannelMonitorDefaultIntervalSeconds); v > 0 { + updates[SettingKeyChannelMonitorDefaultIntervalSeconds] = strconv.Itoa(v) + } + + // Available channels feature switch + updates[SettingKeyAvailableChannelsEnabled] = strconv.FormatBool(settings.AvailableChannelsEnabled) + + // Affiliate (邀请返利) feature switch + updates[SettingKeyAffiliateEnabled] = strconv.FormatBool(settings.AffiliateEnabled) + + // 风控中心功能开关 + updates[SettingKeyRiskControlEnabled] = strconv.FormatBool(settings.RiskControlEnabled) + + // Claude Code version check + updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion + updates[SettingKeyMaxClaudeCodeVersion] = settings.MaxClaudeCodeVersion + + // 分组隔离 + updates[SettingKeyAllowUngroupedKeyScheduling] = strconv.FormatBool(settings.AllowUngroupedKeyScheduling) + + // Backend Mode + updates[SettingKeyBackendModeEnabled] = strconv.FormatBool(settings.BackendModeEnabled) + + // Gateway forwarding behavior + updates[SettingKeyEnableFingerprintUnification] = strconv.FormatBool(settings.EnableFingerprintUnification) + updates[SettingKeyEnableMetadataPassthrough] = strconv.FormatBool(settings.EnableMetadataPassthrough) + updates[SettingKeyEnableCCHSigning] = strconv.FormatBool(settings.EnableCCHSigning) + updates[SettingKeyEnableAnthropicCacheTTL1hInjection] = strconv.FormatBool(settings.EnableAnthropicCacheTTL1hInjection) + 
	// (tail of the system-settings update builder; its header is above this chunk)
	// Gateway forwarding / payment-visibility toggles.
	updates[SettingKeyRewriteMessageCacheControl] = strconv.FormatBool(settings.RewriteMessageCacheControl)
	updates[SettingKeyAntigravityUserAgentVersion] = antigravity.NormalizeUserAgentVersion(settings.AntigravityUserAgentVersion)
	updates[SettingPaymentVisibleMethodAlipaySource] = settings.PaymentVisibleMethodAlipaySource
	updates[SettingPaymentVisibleMethodWxpaySource] = settings.PaymentVisibleMethodWxpaySource
	updates[SettingPaymentVisibleMethodAlipayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodAlipayEnabled)
	updates[SettingPaymentVisibleMethodWxpayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodWxpayEnabled)
	updates[openAIAdvancedSchedulerSettingKey] = strconv.FormatBool(settings.OpenAIAdvancedSchedulerEnabled)

	// Balance low notification
	updates[SettingKeyBalanceLowNotifyEnabled] = strconv.FormatBool(settings.BalanceLowNotifyEnabled)
	updates[SettingKeyBalanceLowNotifyThreshold] = strconv.FormatFloat(settings.BalanceLowNotifyThreshold, 'f', 8, 64)
	updates[SettingKeyBalanceLowNotifyRechargeURL] = settings.BalanceLowNotifyRechargeURL
	updates[SettingKeyAccountQuotaNotifyEnabled] = strconv.FormatBool(settings.AccountQuotaNotifyEnabled)
	updates[SettingKeyAccountQuotaNotifyEmails] = MarshalNotifyEmails(settings.AccountQuotaNotifyEmails)

	return updates, nil
}

// buildAuthSourceDefaultUpdates flattens the per-auth-source default grant
// settings (email / LinuxDo / OIDC / WeChat / GitHub / Google) into a
// key->value map for persistence. Every provider's subscription list is
// validated against existing groups before anything is written, so one bad
// entry rejects the whole update. A nil settings pointer yields (nil, nil):
// nothing to update.
func (s *SettingService) buildAuthSourceDefaultUpdates(ctx context.Context, settings *AuthSourceDefaultSettings) (map[string]string, error) {
	if settings == nil {
		return nil, nil
	}

	// Validate all providers' subscription group references up front.
	for _, subscriptions := range [][]DefaultSubscriptionSetting{
		settings.Email.Subscriptions,
		settings.LinuxDo.Subscriptions,
		settings.OIDC.Subscriptions,
		settings.WeChat.Subscriptions,
		settings.GitHub.Subscriptions,
		settings.Google.Subscriptions,
	} {
		if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil {
			return nil, err
		}
	}

	// Capacity 31: presumably 6 providers x 5 keys + 1 global flag — confirm
	// against writeProviderDefaultGrantUpdates.
	updates := make(map[string]string, 31)
	writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email)
	writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo)
	writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC)
	writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat)
	writeProviderDefaultGrantUpdates(updates, gitHubAuthSourceDefaultKeys, settings.GitHub)
	writeProviderDefaultGrantUpdates(updates, googleAuthSourceDefaultKeys, settings.Google)
	updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup)
	return updates, nil
}

// parseSettings parses the raw key/value settings map into a SystemSettings
// struct, applying built-in defaults and config-file/env fallbacks.
func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings {
	emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
	loginAgreementDocuments := parseLoginAgreementDocuments(settings[SettingKeyLoginAgreementDocuments])
	loginAgreementUpdatedAt := strings.TrimSpace(settings[SettingKeyLoginAgreementUpdatedAt])
	if loginAgreementUpdatedAt == "" {
		loginAgreementUpdatedAt = defaultLoginAgreementDate
	}
	result := &SystemSettings{
		RegistrationEnabled:              settings[SettingKeyRegistrationEnabled] == "true",
		EmailVerifyEnabled:               emailVerifyEnabled,
		RegistrationEmailSuffixWhitelist: ParseRegistrationEmailSuffixWhitelist(settings[SettingKeyRegistrationEmailSuffixWhitelist]),
		PromoCodeEnabled:                 settings[SettingKeyPromoCodeEnabled] != "false", // enabled by default
		PasswordResetEnabled:             emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true",
		FrontendURL:                      settings[SettingKeyFrontendURL],
		InvitationCodeEnabled:            settings[SettingKeyInvitationCodeEnabled] == "true",
		TotpEnabled:                      settings[SettingKeyTotpEnabled] == "true",
		LoginAgreementEnabled:            settings[SettingKeyLoginAgreementEnabled] == "true",
		LoginAgreementMode:               normalizeLoginAgreementMode(settings[SettingKeyLoginAgreementMode]),
		LoginAgreementUpdatedAt:          loginAgreementUpdatedAt,
		LoginAgreementDocuments:
		loginAgreementDocuments,
		SMTPHost:                     settings[SettingKeySMTPHost],
		SMTPUsername:                 settings[SettingKeySMTPUsername],
		SMTPFrom:                     settings[SettingKeySMTPFrom],
		SMTPFromName:                 settings[SettingKeySMTPFromName],
		SMTPUseTLS:                   settings[SettingKeySMTPUseTLS] == "true",
		SMTPPasswordConfigured:       settings[SettingKeySMTPPassword] != "",
		TurnstileEnabled:             settings[SettingKeyTurnstileEnabled] == "true",
		TurnstileSiteKey:             settings[SettingKeyTurnstileSiteKey],
		TurnstileSecretKeyConfigured: settings[SettingKeyTurnstileSecretKey] != "",
		SiteName:                     s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
		SiteLogo:                     settings[SettingKeySiteLogo],
		SiteSubtitle:                 s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
		APIBaseURL:                   settings[SettingKeyAPIBaseURL],
		ContactInfo:                  settings[SettingKeyContactInfo],
		DocURL:                       settings[SettingKeyDocURL],
		HomeContent:                  settings[SettingKeyHomeContent],
		HideCcsImportButton:          settings[SettingKeyHideCcsImportButton] == "true",
		PurchaseSubscriptionEnabled:  settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
		PurchaseSubscriptionURL:      strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
		CustomMenuItems:              settings[SettingKeyCustomMenuItems],
		CustomEndpoints:              settings[SettingKeyCustomEndpoints],
		BackendModeEnabled:           settings[SettingKeyBackendModeEnabled] == "true",
	}
	result.TableDefaultPageSize, result.TablePageSizeOptions = parseTablePreferences(
		settings[SettingKeyTableDefaultPageSize],
		settings[SettingKeyTablePageSizeOptions],
	)

	// Parse integer-typed settings; fall back to built-in/config defaults
	// when the stored value is missing or unparsable.
	if port, err := strconv.Atoi(settings[SettingKeySMTPPort]); err == nil {
		result.SMTPPort = port
	} else {
		result.SMTPPort = 587
	}

	if concurrency, err := strconv.Atoi(settings[SettingKeyDefaultConcurrency]); err == nil {
		result.DefaultConcurrency = concurrency
	} else {
		result.DefaultConcurrency = s.cfg.Default.UserConcurrency
	}

	if rpm, err := strconv.Atoi(settings[SettingKeyDefaultUserRPMLimit]); err == nil && rpm >= 0 {
		result.DefaultUserRPMLimit = rpm
	}

	// Parse float-typed settings.
	if balance, err := strconv.ParseFloat(settings[SettingKeyDefaultBalance], 64); err == nil {
		result.DefaultBalance = balance
	} else {
		result.DefaultBalance = s.cfg.Default.UserBalance
	}
	if rebateRate, err := strconv.ParseFloat(settings[SettingKeyAffiliateRebateRate], 64); err == nil {
		result.AffiliateRebateRate = clampAffiliateRebateRate(rebateRate)
	} else {
		result.AffiliateRebateRate = AffiliateRebateRateDefault
	}
	if freezeHours, err := strconv.Atoi(settings[SettingKeyAffiliateRebateFreezeHours]); err == nil && freezeHours >= 0 {
		if freezeHours > AffiliateRebateFreezeHoursMax {
			freezeHours = AffiliateRebateFreezeHoursMax
		}
		result.AffiliateRebateFreezeHours = freezeHours
	}
	if durationDays, err := strconv.Atoi(settings[SettingKeyAffiliateRebateDurationDays]); err == nil && durationDays >= 0 {
		if durationDays > AffiliateRebateDurationDaysMax {
			durationDays = AffiliateRebateDurationDaysMax
		}
		result.AffiliateRebateDurationDays = durationDays
	}
	if perInviteeCap, err := strconv.ParseFloat(settings[SettingKeyAffiliateRebatePerInviteeCap], 64); err == nil && perInviteeCap >= 0 {
		result.AffiliateRebatePerInviteeCap = perInviteeCap
	}
	result.DefaultSubscriptions = parseDefaultSubscriptions(settings[SettingKeyDefaultSubscriptions])

	// Sensitive values are returned as-is so they can be used when testing
	// connections.
	result.SMTPPassword = settings[SettingKeySMTPPassword]
	result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey]

	// LinuxDo Connect settings:
	// - stay compatible with config.yaml/env (so old deployments that never
	//   migrated to DB-backed settings are not silently disabled)
	// - can be overridden and persisted via the admin "system settings" (stored in DB)
	linuxDoBase := config.LinuxDoConnectConfig{}
	if s.cfg != nil {
		linuxDoBase = s.cfg.LinuxDo
	}

	if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok {
		result.LinuxDoConnectEnabled = raw == "true"
	} else {
		result.LinuxDoConnectEnabled = linuxDoBase.Enabled
	}

	if v, ok := settings[SettingKeyLinuxDoConnectClientID]; ok && strings.TrimSpace(v) != "" {
		result.LinuxDoConnectClientID = strings.TrimSpace(v)
	} else {
		result.LinuxDoConnectClientID = linuxDoBase.ClientID
	}

	if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" {
		result.LinuxDoConnectRedirectURL = strings.TrimSpace(v)
	} else {
		result.LinuxDoConnectRedirectURL = linuxDoBase.RedirectURL
	}

	result.LinuxDoConnectClientSecret = strings.TrimSpace(settings[SettingKeyLinuxDoConnectClientSecret])
	if result.LinuxDoConnectClientSecret == "" {
		result.LinuxDoConnectClientSecret = strings.TrimSpace(linuxDoBase.ClientSecret)
	}
	result.LinuxDoConnectClientSecretConfigured = result.LinuxDoConnectClientSecret != ""

	// Generic OIDC settings:
	// - compatible with config.yaml/env
	// - can be overridden and persisted via admin system settings (stored in DB)
	oidcBase := config.OIDCConnectConfig{}
	if s.cfg != nil {
		oidcBase = s.cfg.OIDC
	}

	if raw, ok := settings[SettingKeyOIDCConnectEnabled]; ok {
		result.OIDCConnectEnabled = raw == "true"
	} else {
		result.OIDCConnectEnabled = oidcBase.Enabled
	}

	if v, ok := settings[SettingKeyOIDCConnectProviderName]; ok && strings.TrimSpace(v) != "" {
		result.OIDCConnectProviderName = strings.TrimSpace(v)
	} else {
		result.OIDCConnectProviderName = strings.TrimSpace(oidcBase.ProviderName)
	}
	if result.OIDCConnectProviderName == "" {
		result.OIDCConnectProviderName = "OIDC"
	}

	if v, ok := settings[SettingKeyOIDCConnectClientID]; ok && strings.TrimSpace(v) != "" {
		result.OIDCConnectClientID = strings.TrimSpace(v)
	} else {
		result.OIDCConnectClientID = strings.TrimSpace(oidcBase.ClientID)
	}
	if v, ok := settings[SettingKeyOIDCConnectIssuerURL]; ok && strings.TrimSpace(v) != "" {
		result.OIDCConnectIssuerURL = strings.TrimSpace(v)
	} else {
		result.OIDCConnectIssuerURL = strings.TrimSpace(oidcBase.IssuerURL)
	}
	if v, ok := settings[SettingKeyOIDCConnectDiscoveryURL]; ok && strings.TrimSpace(v) != "" {
		result.OIDCConnectDiscoveryURL = strings.TrimSpace(v)
	} else {
		result.OIDCConnectDiscoveryURL = strings.TrimSpace(oidcBase.DiscoveryURL)
	}
	if v, ok := settings[SettingKeyOIDCConnectAuthorizeURL]; ok && strings.TrimSpace(v) != "" {
		result.OIDCConnectAuthorizeURL = strings.TrimSpace(v)
	} else {
		result.OIDCConnectAuthorizeURL = strings.TrimSpace(oidcBase.AuthorizeURL)
	}
	if v, ok := settings[SettingKeyOIDCConnectTokenURL]; ok && strings.TrimSpace(v) != "" {
		result.OIDCConnectTokenURL = strings.TrimSpace(v)
	} else {
		result.OIDCConnectTokenURL = strings.TrimSpace(oidcBase.TokenURL)
	}
	if v, ok := settings[SettingKeyOIDCConnectUserInfoURL]; ok && strings.TrimSpace(v) != "" {
		result.OIDCConnectUserInfoURL = strings.TrimSpace(v)
	} else {
		result.OIDCConnectUserInfoURL = strings.TrimSpace(oidcBase.UserInfoURL)
	}
	if v, ok := settings[SettingKeyOIDCConnectJWKSURL]; ok && strings.TrimSpace(v) != "" {
		result.OIDCConnectJWKSURL = strings.TrimSpace(v)
	} else {
		result.OIDCConnectJWKSURL = strings.TrimSpace(oidcBase.JWKSURL)
	}
	if v, ok := settings[SettingKeyOIDCConnectScopes]; ok && strings.TrimSpace(v) != "" {
		result.OIDCConnectScopes = strings.TrimSpace(v)
	} else {
		result.OIDCConnectScopes = strings.TrimSpace(oidcBase.Scopes)
	}
	if v, ok := settings[SettingKeyOIDCConnectRedirectURL]; ok && strings.TrimSpace(v) != "" {
		result.OIDCConnectRedirectURL = strings.TrimSpace(v)
	} else {
		result.OIDCConnectRedirectURL = strings.TrimSpace(oidcBase.RedirectURL)
	}
	if v, ok := settings[SettingKeyOIDCConnectFrontendRedirectURL]; ok && strings.TrimSpace(v) != "" {
		result.OIDCConnectFrontendRedirectURL = strings.TrimSpace(v)
	} else {
		result.OIDCConnectFrontendRedirectURL = strings.TrimSpace(oidcBase.FrontendRedirectURL)
	}
	if v, ok := settings[SettingKeyOIDCConnectTokenAuthMethod]; ok && strings.TrimSpace(v) != "" {
		result.OIDCConnectTokenAuthMethod = strings.ToLower(strings.TrimSpace(v))
	} else {
		result.OIDCConnectTokenAuthMethod = strings.ToLower(strings.TrimSpace(oidcBase.TokenAuthMethod))
	}
	if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok {
		result.OIDCConnectUsePKCE = raw == "true"
	} else {
		result.OIDCConnectUsePKCE = oidcUsePKCECompatibilityDefault(oidcBase)
	}
	if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok {
		result.OIDCConnectValidateIDToken = raw == "true"
	} else {
		result.OIDCConnectValidateIDToken = oidcValidateIDTokenCompatibilityDefault(oidcBase)
	}
	if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" {
		result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(v)
	} else {
		result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(oidcBase.AllowedSigningAlgs)
	}
	// Clock skew: DB value wins when parsable; otherwise fall back to config,
	// and finally to 120 seconds when neither source sets it.
	clockSkewSet := false
	if raw, ok := settings[SettingKeyOIDCConnectClockSkewSeconds]; ok && strings.TrimSpace(raw) != "" {
		if parsed, err := strconv.Atoi(strings.TrimSpace(raw)); err == nil {
			result.OIDCConnectClockSkewSeconds = parsed
			clockSkewSet = true
		}
	}
	if !clockSkewSet {
		result.OIDCConnectClockSkewSeconds = oidcBase.ClockSkewSeconds
	}
	if !clockSkewSet && result.OIDCConnectClockSkewSeconds == 0 {
		result.OIDCConnectClockSkewSeconds = 120
	}
	if raw, ok := settings[SettingKeyOIDCConnectRequireEmailVerified]; ok {
		result.OIDCConnectRequireEmailVerified = raw == "true"
	} else {
		result.OIDCConnectRequireEmailVerified = oidcBase.RequireEmailVerified
	}
	if v, ok := settings[SettingKeyOIDCConnectUserInfoEmailPath]; ok {
		result.OIDCConnectUserInfoEmailPath = strings.TrimSpace(v)
	} else {
		result.OIDCConnectUserInfoEmailPath = strings.TrimSpace(oidcBase.UserInfoEmailPath)
	}
	if v, ok := settings[SettingKeyOIDCConnectUserInfoIDPath]; ok {
		result.OIDCConnectUserInfoIDPath = strings.TrimSpace(v)
	} else {
		result.OIDCConnectUserInfoIDPath = strings.TrimSpace(oidcBase.UserInfoIDPath)
	}
	if v, ok := settings[SettingKeyOIDCConnectUserInfoUsernamePath]; ok {
		result.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(v)
	} else {
		result.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(oidcBase.UserInfoUsernamePath)
	}
	result.OIDCConnectClientSecret = strings.TrimSpace(settings[SettingKeyOIDCConnectClientSecret])
	if result.OIDCConnectClientSecret == "" {
		result.OIDCConnectClientSecret = strings.TrimSpace(oidcBase.ClientSecret)
	}
	result.OIDCConnectClientSecretConfigured = result.OIDCConnectClientSecret != ""

	// GitHub / Google e-mail quick login (effective config merges DB + base).
	gitHubEffective := s.effectiveEmailOAuthConfig(settings, "github")
	result.GitHubOAuthEnabled = gitHubEffective.Enabled
	result.GitHubOAuthClientID = strings.TrimSpace(gitHubEffective.ClientID)
	result.GitHubOAuthClientSecret = strings.TrimSpace(gitHubEffective.ClientSecret)
	result.GitHubOAuthClientSecretConfigured = result.GitHubOAuthClientSecret != ""
	result.GitHubOAuthRedirectURL = strings.TrimSpace(gitHubEffective.RedirectURL)
	result.GitHubOAuthFrontendRedirectURL = strings.TrimSpace(gitHubEffective.FrontendRedirectURL)

	googleEffective := s.effectiveEmailOAuthConfig(settings, "google")
	result.GoogleOAuthEnabled = googleEffective.Enabled
	result.GoogleOAuthClientID = strings.TrimSpace(googleEffective.ClientID)
	result.GoogleOAuthClientSecret = strings.TrimSpace(googleEffective.ClientSecret)
	result.GoogleOAuthClientSecretConfigured = result.GoogleOAuthClientSecret != ""
	result.GoogleOAuthRedirectURL = strings.TrimSpace(googleEffective.RedirectURL)
	result.GoogleOAuthFrontendRedirectURL = strings.TrimSpace(googleEffective.FrontendRedirectURL)

	// WeChat Connect settings:
	// - prefer DB-backed system settings
	// - fall back to config/env when missing, for upgrade compatibility
	weChatEffective := s.effectiveWeChatConnectOAuthConfig(settings)
	result.WeChatConnectEnabled = weChatEffective.Enabled
	result.WeChatConnectAppID = weChatEffective.LegacyAppID
	result.WeChatConnectAppSecret = weChatEffective.LegacyAppSecret
	result.WeChatConnectAppSecretConfigured = weChatEffective.LegacyAppSecret != ""
	result.WeChatConnectOpenAppID = weChatEffective.OpenAppID
	result.WeChatConnectOpenAppSecret = weChatEffective.OpenAppSecret
	result.WeChatConnectOpenAppSecretConfigured = weChatEffective.OpenAppSecret != ""
	result.WeChatConnectMPAppID = weChatEffective.MPAppID
	result.WeChatConnectMPAppSecret = weChatEffective.MPAppSecret
	result.WeChatConnectMPAppSecretConfigured = weChatEffective.MPAppSecret != ""
	result.WeChatConnectMobileAppID = weChatEffective.MobileAppID
	result.WeChatConnectMobileAppSecret = weChatEffective.MobileAppSecret
	result.WeChatConnectMobileAppSecretConfigured = weChatEffective.MobileAppSecret != ""
	result.WeChatConnectOpenEnabled = weChatEffective.OpenEnabled
	result.WeChatConnectMPEnabled = weChatEffective.MPEnabled
	result.WeChatConnectMobileEnabled = weChatEffective.MobileEnabled
	result.WeChatConnectMode = weChatEffective.Mode
	result.WeChatConnectScopes = weChatEffective.Scopes
	result.WeChatConnectRedirectURL = weChatEffective.RedirectURL
	result.WeChatConnectFrontendRedirectURL = weChatEffective.FrontendRedirectURL

	// Model fallback settings
	result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true"
	result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022")
	result.FallbackModelOpenAI = s.getStringOrDefault(settings, SettingKeyFallbackModelOpenAI, "gpt-4o")
	result.FallbackModelGemini = s.getStringOrDefault(settings, SettingKeyFallbackModelGemini, "gemini-2.5-pro")
	result.FallbackModelAntigravity = s.getStringOrDefault(settings, SettingKeyFallbackModelAntigravity, "gemini-2.5-pro")

	// Identity patch settings (default: enabled, to preserve existing behavior)
	if v, ok := settings[SettingKeyEnableIdentityPatch]; ok && v != "" {
		result.EnableIdentityPatch = v == "true"
	} else {
		result.EnableIdentityPatch = true
	}
	result.IdentityPatchPrompt = settings[SettingKeyIdentityPatchPrompt]

	// Ops monitoring settings (default: enabled, fail-open)
	result.OpsMonitoringEnabled = !isFalseSettingValue(settings[SettingKeyOpsMonitoringEnabled])
	result.OpsRealtimeMonitoringEnabled = !isFalseSettingValue(settings[SettingKeyOpsRealtimeMonitoringEnabled])
	result.OpsQueryModeDefault = string(ParseOpsQueryMode(settings[SettingKeyOpsQueryModeDefault]))
	// Interval is clamped to [60, 3600] seconds; 60 is the default.
	result.OpsMetricsIntervalSeconds = 60
	if raw := strings.TrimSpace(settings[SettingKeyOpsMetricsIntervalSeconds]); raw != "" {
		if v, err := strconv.Atoi(raw); err == nil {
			if v < 60 {
				v = 60
			}
			if v > 3600 {
				v = 3600
			}
			result.OpsMetricsIntervalSeconds = v
		}
	}

	// Channel monitor feature (default: enabled, 60s)
	result.ChannelMonitorEnabled = !isFalseSettingValue(settings[SettingKeyChannelMonitorEnabled])
	result.ChannelMonitorDefaultIntervalSeconds = parseChannelMonitorInterval(
		settings[SettingKeyChannelMonitorDefaultIntervalSeconds],
	)

	// Available channels feature (default: disabled; strict true)
	result.AvailableChannelsEnabled = settings[SettingKeyAvailableChannelsEnabled] == "true"

	// Affiliate (invite rebate) feature (default: disabled; strict true)
	result.AffiliateEnabled = settings[SettingKeyAffiliateEnabled] == "true"

	// Risk-control center feature (default: off; only the literal "true" enables it)
	result.RiskControlEnabled = settings[SettingKeyRiskControlEnabled] == "true"

	// Claude Code version check
	result.MinClaudeCodeVersion = settings[SettingKeyMinClaudeCodeVersion]
	result.MaxClaudeCodeVersion = settings[SettingKeyMaxClaudeCodeVersion]

	// Group isolation
	result.AllowUngroupedKeyScheduling = settings[SettingKeyAllowUngroupedKeyScheduling] == "true"

	// Gateway forwarding behavior (defaults: fingerprint=true, metadata_passthrough=false, cch_signing=false)
	if v, ok := settings[SettingKeyEnableFingerprintUnification]; ok && v != "" {
		result.EnableFingerprintUnification = v == "true"
	} else {
		result.EnableFingerprintUnification = true // default: enabled (current behavior)
	}
	result.EnableMetadataPassthrough = settings[SettingKeyEnableMetadataPassthrough] == "true"
	result.EnableCCHSigning = settings[SettingKeyEnableCCHSigning] == "true"
	result.EnableAnthropicCacheTTL1hInjection = settings[SettingKeyEnableAnthropicCacheTTL1hInjection] == "true"
	if v, ok := settings[SettingKeyRewriteMessageCacheControl]; ok && v != "" {
		result.RewriteMessageCacheControl = v == "true"
	} else {
		result.RewriteMessageCacheControl = s.defaultRewriteMessageCacheControl()
	}
	result.AntigravityUserAgentVersion = antigravity.NormalizeUserAgentVersion(settings[SettingKeyAntigravityUserAgentVersion])

	// Web search emulation: quick enabled check from the JSON config
	if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" {
		var wsCfg WebSearchEmulationConfig
		if err := json.Unmarshal([]byte(raw), &wsCfg); err == nil {
			result.WebSearchEmulationEnabled = wsCfg.Enabled && len(wsCfg.Providers) > 0
		}
	}
	result.PaymentVisibleMethodAlipaySource = NormalizeVisibleMethodSource("alipay", settings[SettingPaymentVisibleMethodAlipaySource])
	result.PaymentVisibleMethodWxpaySource = NormalizeVisibleMethodSource("wxpay", settings[SettingPaymentVisibleMethodWxpaySource])
	result.PaymentVisibleMethodAlipayEnabled = settings[SettingPaymentVisibleMethodAlipayEnabled] == "true"
	result.PaymentVisibleMethodWxpayEnabled = settings[SettingPaymentVisibleMethodWxpayEnabled] == "true"
	result.OpenAIAdvancedSchedulerEnabled = settings[openAIAdvancedSchedulerSettingKey] == "true"

	// Balance low notification
	result.BalanceLowNotifyEnabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true"
	if v, err := strconv.ParseFloat(settings[SettingKeyBalanceLowNotifyThreshold], 64); err == nil && v >= 0 {
		result.BalanceLowNotifyThreshold = v
	}
	result.BalanceLowNotifyRechargeURL = settings[SettingKeyBalanceLowNotifyRechargeURL]

	// Account quota notification
	result.AccountQuotaNotifyEnabled = settings[SettingKeyAccountQuotaNotifyEnabled] == "true"
	if raw :=
strings.TrimSpace(settings[SettingKeyAccountQuotaNotifyEmails]); raw != "" { + result.AccountQuotaNotifyEmails = ParseNotifyEmails(raw) + } + if result.AccountQuotaNotifyEmails == nil { + result.AccountQuotaNotifyEmails = []NotifyEmailEntry{} + } + + return result +} + +func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil + } + + var items []DefaultSubscriptionSetting + if err := json.Unmarshal([]byte(raw), &items); err != nil { + return nil + } + + normalized := make([]DefaultSubscriptionSetting, 0, len(items)) + for _, item := range items { + if item.GroupID <= 0 || item.ValidityDays <= 0 { + continue + } + if item.ValidityDays > MaxValidityDays { + item.ValidityDays = MaxValidityDays + } + normalized = append(normalized, item) + } + + return normalized +} + +func parseProviderDefaultGrantSettings(settings map[string]string, keys authSourceDefaultKeySet) ProviderDefaultGrantSettings { + result := ProviderDefaultGrantSettings{ + Balance: defaultAuthSourceBalance, + Concurrency: defaultAuthSourceConcurrency, + Subscriptions: []DefaultSubscriptionSetting{}, + GrantOnSignup: false, + GrantOnFirstBind: false, + } + + if v, err := strconv.ParseFloat(strings.TrimSpace(settings[keys.balance]), 64); err == nil { + result.Balance = v + } + if v, err := strconv.Atoi(strings.TrimSpace(settings[keys.concurrency])); err == nil { + result.Concurrency = v + } + if items := parseDefaultSubscriptions(settings[keys.subscriptions]); items != nil { + result.Subscriptions = items + } + if raw, ok := settings[keys.grantOnSignup]; ok { + result.GrantOnSignup = raw == "true" + } + if raw, ok := settings[keys.grantOnFirstBind]; ok { + result.GrantOnFirstBind = raw == "true" + } + + return result +} + +func writeProviderDefaultGrantUpdates(updates map[string]string, keys authSourceDefaultKeySet, settings ProviderDefaultGrantSettings) { + updates[keys.balance] = strconv.FormatFloat(settings.Balance, 'f', 8, 
64) + updates[keys.concurrency] = strconv.Itoa(settings.Concurrency) + + subscriptions := settings.Subscriptions + if subscriptions == nil { + subscriptions = []DefaultSubscriptionSetting{} + } + raw, err := json.Marshal(subscriptions) + if err != nil { + raw = []byte("[]") + } + updates[keys.subscriptions] = string(raw) + updates[keys.grantOnSignup] = strconv.FormatBool(settings.GrantOnSignup) + updates[keys.grantOnFirstBind] = strconv.FormatBool(settings.GrantOnFirstBind) +} + +func mergeProviderDefaultGrantSettings(globalDefaults ProviderDefaultGrantSettings, providerDefaults ProviderDefaultGrantSettings) ProviderDefaultGrantSettings { + result := ProviderDefaultGrantSettings{ + Balance: globalDefaults.Balance, + Concurrency: globalDefaults.Concurrency, + Subscriptions: append([]DefaultSubscriptionSetting(nil), globalDefaults.Subscriptions...), + GrantOnSignup: providerDefaults.GrantOnSignup, + GrantOnFirstBind: providerDefaults.GrantOnFirstBind, + } + + if providerDefaults.Balance != defaultAuthSourceBalance { + result.Balance = providerDefaults.Balance + } + if providerDefaults.Concurrency > 0 && providerDefaults.Concurrency != defaultAuthSourceConcurrency { + result.Concurrency = providerDefaults.Concurrency + } + if len(providerDefaults.Subscriptions) > 0 { + result.Subscriptions = append([]DefaultSubscriptionSetting(nil), providerDefaults.Subscriptions...) 
+ } + + return result +} + +func parseTablePreferences(defaultPageSizeRaw, optionsRaw string) (int, []int) { + defaultPageSize := 20 + if v, err := strconv.Atoi(strings.TrimSpace(defaultPageSizeRaw)); err == nil { + defaultPageSize = v + } + + var options []int + if strings.TrimSpace(optionsRaw) != "" { + _ = json.Unmarshal([]byte(optionsRaw), &options) + } + + return normalizeTablePreferences(defaultPageSize, options) +} + +func normalizeTablePreferences(defaultPageSize int, options []int) (int, []int) { + const minPageSize = 5 + const maxPageSize = 1000 + const fallbackPageSize = 20 + + seen := make(map[int]struct{}, len(options)) + normalizedOptions := make([]int, 0, len(options)) + for _, option := range options { + if option < minPageSize || option > maxPageSize { + continue + } + if _, ok := seen[option]; ok { + continue + } + seen[option] = struct{}{} + normalizedOptions = append(normalizedOptions, option) + } + sort.Ints(normalizedOptions) + + if defaultPageSize < minPageSize || defaultPageSize > maxPageSize { + defaultPageSize = fallbackPageSize + } + + if len(normalizedOptions) == 0 { + normalizedOptions = []int{10, 20, 50} + } + + return defaultPageSize, normalizedOptions +} diff --git a/backend/internal/service/setting_cache.go b/backend/internal/service/setting_cache.go new file mode 100644 index 00000000000..56d0e5da946 --- /dev/null +++ b/backend/internal/service/setting_cache.go @@ -0,0 +1,169 @@ +package service + +import ( + "context" + "log/slog" + "sync/atomic" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" +) + +// cachedVersionBounds 缓存 Claude Code 版本号上下限(进程内缓存,60s TTL) +type cachedVersionBounds struct { + min string // 空字符串 = 不检查 + max string // 空字符串 = 不检查 + expiresAt int64 // unix nano +} + +// versionBoundsCache 版本号上下限进程内缓存 +var versionBoundsCache atomic.Value + +// cachedBackendMode Backend Mode cache (in-process, 60s TTL) +type cachedBackendMode struct { + value bool + expiresAt int64 // unix nano +} + +var 
backendModeCache atomic.Value + +// cachedGatewayForwardingSettings 缓存网关转发行为设置(进程内缓存,60s TTL) +type cachedGatewayForwardingSettings struct { + fingerprintUnification bool + metadataPassthrough bool + cchSigning bool + anthropicCacheTTL1hInjection bool + rewriteMessageCacheControl bool + expiresAt int64 // unix nano +} + +var gatewayForwardingCache atomic.Value + +// cachedAntigravityUserAgentVersion 缓存 Antigravity UA 版本号(进程内缓存,60s TTL) +type cachedAntigravityUserAgentVersion struct { + version string + expiresAt int64 // unix nano +} + +func (s *SettingService) refreshCachedSettings(settings *SystemSettings) { + if settings == nil { + return + } + + // 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口 + versionBoundsSF.Forget("version_bounds") + versionBoundsCache.Store(&cachedVersionBounds{ + min: settings.MinClaudeCodeVersion, + max: settings.MaxClaudeCodeVersion, + expiresAt: time.Now().Add(versionBoundsCacheTTL).UnixNano(), + }) + backendModeSF.Forget("backend_mode") + backendModeCache.Store(&cachedBackendMode{ + value: settings.BackendModeEnabled, + expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(), + }) + gatewayForwardingSF.Forget("gateway_forwarding") + gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{ + fingerprintUnification: settings.EnableFingerprintUnification, + metadataPassthrough: settings.EnableMetadataPassthrough, + cchSigning: settings.EnableCCHSigning, + anthropicCacheTTL1hInjection: settings.EnableAnthropicCacheTTL1hInjection, + rewriteMessageCacheControl: settings.RewriteMessageCacheControl, + expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(), + }) + s.antigravityUAVersionSF.Forget("antigravity_user_agent_version") + antigravityUserAgentVersion := antigravity.NormalizeUserAgentVersion(settings.AntigravityUserAgentVersion) + if antigravityUserAgentVersion == "" { + antigravityUserAgentVersion = antigravity.GetDefaultUserAgentVersion() + } + s.antigravityUAVersionCache.Store(&cachedAntigravityUserAgentVersion{ + 
version: antigravityUserAgentVersion, + expiresAt: time.Now().Add(antigravityUserAgentVersionCacheTTL).UnixNano(), + }) + openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey) + openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{ + enabled: settings.OpenAIAdvancedSchedulerEnabled, + expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(), + }) + if s.onUpdate != nil { + s.onUpdate() // Invalidate cache after settings update + } +} + +func (s *SettingService) getGatewayForwardingSettingsCached(ctx context.Context) gatewayForwardingSettingsResult { + if cached, ok := gatewayForwardingCache.Load().(*cachedGatewayForwardingSettings); ok && cached != nil { + if time.Now().UnixNano() < cached.expiresAt { + return gatewayForwardingSettingsResult{ + fp: cached.fingerprintUnification, + mp: cached.metadataPassthrough, + cch: cached.cchSigning, + cacheTTL1h: cached.anthropicCacheTTL1hInjection, + rewriteMessageCacheControl: cached.rewriteMessageCacheControl, + } + } + } + val, _, _ := gatewayForwardingSF.Do("gateway_forwarding", func() (any, error) { + if cached, ok := gatewayForwardingCache.Load().(*cachedGatewayForwardingSettings); ok && cached != nil { + if time.Now().UnixNano() < cached.expiresAt { + return gatewayForwardingSettingsResult{ + fp: cached.fingerprintUnification, + mp: cached.metadataPassthrough, + cch: cached.cchSigning, + cacheTTL1h: cached.anthropicCacheTTL1hInjection, + rewriteMessageCacheControl: cached.rewriteMessageCacheControl, + }, nil + } + } + dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), gatewayForwardingDBTimeout) + defer cancel() + values, err := s.settingRepo.GetMultiple(dbCtx, []string{ + SettingKeyEnableFingerprintUnification, + SettingKeyEnableMetadataPassthrough, + SettingKeyEnableCCHSigning, + SettingKeyEnableAnthropicCacheTTL1hInjection, + SettingKeyRewriteMessageCacheControl, + }) + if err != nil { + slog.Warn("failed to get gateway forwarding 
settings", "error", err) + gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{ + fingerprintUnification: true, + metadataPassthrough: false, + cchSigning: false, + anthropicCacheTTL1hInjection: false, + rewriteMessageCacheControl: s.defaultRewriteMessageCacheControl(), + expiresAt: time.Now().Add(gatewayForwardingErrorTTL).UnixNano(), + }) + return gatewayForwardingSettingsResult{fp: true, rewriteMessageCacheControl: s.defaultRewriteMessageCacheControl()}, nil + } + fp := true + if v, ok := values[SettingKeyEnableFingerprintUnification]; ok && v != "" { + fp = v == "true" + } + mp := values[SettingKeyEnableMetadataPassthrough] == "true" + cch := values[SettingKeyEnableCCHSigning] == "true" + cacheTTL1h := values[SettingKeyEnableAnthropicCacheTTL1hInjection] == "true" + rewriteMessageCacheControl := s.defaultRewriteMessageCacheControl() + if v, ok := values[SettingKeyRewriteMessageCacheControl]; ok && v != "" { + rewriteMessageCacheControl = v == "true" + } + gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{ + fingerprintUnification: fp, + metadataPassthrough: mp, + cchSigning: cch, + anthropicCacheTTL1hInjection: cacheTTL1h, + rewriteMessageCacheControl: rewriteMessageCacheControl, + expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(), + }) + return gatewayForwardingSettingsResult{ + fp: fp, + mp: mp, + cch: cch, + cacheTTL1h: cacheTTL1h, + rewriteMessageCacheControl: rewriteMessageCacheControl, + }, nil + }) + if r, ok := val.(gatewayForwardingSettingsResult); ok { + return r + } + return gatewayForwardingSettingsResult{fp: true} +} diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 86978eecc4a..91d729c6283 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -11,7 +11,6 @@ import ( "log/slog" "math" "net/url" - "sort" "strconv" "strings" "sync/atomic" @@ -47,15 +46,7 @@ type SettingRepository interface { Delete(ctx 
context.Context, key string) error } -// cachedVersionBounds 缓存 Claude Code 版本号上下限(进程内缓存,60s TTL) -type cachedVersionBounds struct { - min string // 空字符串 = 不检查 - max string // 空字符串 = 不检查 - expiresAt int64 // unix nano -} - -// versionBoundsCache 版本号上下限进程内缓存 -var versionBoundsCache atomic.Value // *cachedVersionBounds +// *cachedVersionBounds // versionBoundsSF 防止缓存过期时 thundering herd var versionBoundsSF singleflight.Group @@ -69,42 +60,20 @@ const versionBoundsErrorTTL = 5 * time.Second // versionBoundsDBTimeout singleflight 内 DB 查询超时,独立于请求 context const versionBoundsDBTimeout = 5 * time.Second -// cachedBackendMode Backend Mode cache (in-process, 60s TTL) -type cachedBackendMode struct { - value bool - expiresAt int64 // unix nano -} - -var backendModeCache atomic.Value // *cachedBackendMode +// *cachedBackendMode var backendModeSF singleflight.Group const backendModeCacheTTL = 60 * time.Second const backendModeErrorTTL = 5 * time.Second const backendModeDBTimeout = 5 * time.Second -// cachedGatewayForwardingSettings 缓存网关转发行为设置(进程内缓存,60s TTL) -type cachedGatewayForwardingSettings struct { - fingerprintUnification bool - metadataPassthrough bool - cchSigning bool - anthropicCacheTTL1hInjection bool - rewriteMessageCacheControl bool - expiresAt int64 // unix nano -} - -var gatewayForwardingCache atomic.Value // *cachedGatewayForwardingSettings +// *cachedGatewayForwardingSettings var gatewayForwardingSF singleflight.Group const gatewayForwardingCacheTTL = 60 * time.Second const gatewayForwardingErrorTTL = 5 * time.Second const gatewayForwardingDBTimeout = 5 * time.Second -// cachedAntigravityUserAgentVersion 缓存 Antigravity UA 版本号(进程内缓存,60s TTL) -type cachedAntigravityUserAgentVersion struct { - version string - expiresAt int64 // unix nano -} - const antigravityUserAgentVersionCacheTTL = 60 * time.Second const antigravityUserAgentVersionErrorTTL = 5 * time.Second const antigravityUserAgentVersionDBTimeout = 5 * time.Second @@ -223,40 +192,6 @@ const ( 
defaultLoginAgreementDate = "2026-03-31" ) -func normalizeLoginAgreementMode(raw string) string { - switch strings.ToLower(strings.TrimSpace(raw)) { - case "checkbox": - return "checkbox" - default: - return defaultLoginAgreementMode - } -} - -func defaultLoginAgreementDocuments() []LoginAgreementDocument { - return []LoginAgreementDocument{ - { - ID: "terms", - Title: "服务条款", - ContentMD: "", - }, - { - ID: "usage-policy", - Title: "使用政策", - ContentMD: "", - }, - { - ID: "supported-regions", - Title: "支持的国家和地区", - ContentMD: "", - }, - { - ID: "service-specific-terms", - Title: "服务特定条款", - ContentMD: "", - }, - } -} - func normalizeLoginAgreementDocumentID(raw string) string { raw = strings.ToLower(strings.TrimSpace(raw)) var b strings.Builder @@ -353,17 +288,6 @@ func buildLoginAgreementRevision(updatedAt string, docs []LoginAgreementDocument return hex.EncodeToString(sum[:])[:16] } -func normalizeWeChatConnectModeSetting(raw string) string { - switch strings.ToLower(strings.TrimSpace(raw)) { - case "mp": - return "mp" - case "mobile": - return "mobile" - default: - return "open" - } -} - func defaultWeChatConnectScopeForMode(mode string) string { switch normalizeWeChatConnectModeSetting(mode) { case "mp": @@ -392,95 +316,6 @@ func normalizeWeChatConnectScopeSetting(raw, mode string) string { } } -func parseWeChatConnectCapabilitySettings(settings map[string]string, enabled bool, mode string) (bool, bool, bool) { - mode = normalizeWeChatConnectModeSetting(mode) - rawOpen, hasOpen := settings[SettingKeyWeChatConnectOpenEnabled] - rawMP, hasMP := settings[SettingKeyWeChatConnectMPEnabled] - rawMobile, hasMobile := settings[SettingKeyWeChatConnectMobileEnabled] - openConfigured := hasOpen && strings.TrimSpace(rawOpen) != "" - mpConfigured := hasMP && strings.TrimSpace(rawMP) != "" - mobileConfigured := hasMobile && strings.TrimSpace(rawMobile) != "" - - if openConfigured || mpConfigured || mobileConfigured { - openEnabled := strings.TrimSpace(rawOpen) == "true" - 
mpEnabled := strings.TrimSpace(rawMP) == "true" - mobileEnabled := strings.TrimSpace(rawMobile) == "true" - return openEnabled, mpEnabled, mobileEnabled - } - - if !enabled { - return false, false, false - } - if mode == "mp" { - return false, true, false - } - if mode == "mobile" { - return false, false, true - } - return true, false, false -} - -func normalizeWeChatConnectStoredMode(openEnabled, mpEnabled, mobileEnabled bool, mode string) string { - mode = normalizeWeChatConnectModeSetting(mode) - switch mode { - case "open": - if openEnabled { - return "open" - } - case "mp": - if mpEnabled { - return "mp" - } - case "mobile": - if mobileEnabled { - return "mobile" - } - } - switch { - case openEnabled: - return "open" - case mpEnabled: - return "mp" - case mobileEnabled: - return "mobile" - default: - return mode - } -} - -func mergeWeChatConnectCapabilitySettings(settings map[string]string, base config.WeChatConnectConfig, enabled bool, mode string) (bool, bool, bool) { - mode = normalizeWeChatConnectModeSetting(firstNonEmpty(mode, base.Mode)) - rawOpen, hasOpen := settings[SettingKeyWeChatConnectOpenEnabled] - rawMP, hasMP := settings[SettingKeyWeChatConnectMPEnabled] - rawMobile, hasMobile := settings[SettingKeyWeChatConnectMobileEnabled] - openConfigured := hasOpen && strings.TrimSpace(rawOpen) != "" - mpConfigured := hasMP && strings.TrimSpace(rawMP) != "" - mobileConfigured := hasMobile && strings.TrimSpace(rawMobile) != "" - - if openConfigured || mpConfigured || mobileConfigured { - openEnabled := strings.TrimSpace(rawOpen) == "true" - mpEnabled := strings.TrimSpace(rawMP) == "true" - mobileEnabled := strings.TrimSpace(rawMobile) == "true" - _, enabledConfigured := settings[SettingKeyWeChatConnectEnabled] - if !enabledConfigured && - enabled && - !openEnabled && - !mpEnabled && - !mobileEnabled && - (base.OpenEnabled || base.MPEnabled || base.MobileEnabled) { - return base.OpenEnabled, base.MPEnabled, base.MobileEnabled - } - return openEnabled, 
mpEnabled, mobileEnabled - } - if !enabled { - return false, false, false - } - if base.OpenEnabled || base.MPEnabled || base.MobileEnabled { - return base.OpenEnabled, base.MPEnabled, base.MobileEnabled - } - return parseWeChatConnectCapabilitySettings(settings, enabled, mode) -} - func (s *SettingService) effectiveWeChatConnectOAuthConfig(settings map[string]string) WeChatConnectOAuthConfig { base := config.WeChatConnectConfig{} if s != nil && s.cfg != nil { @@ -555,200 +390,6 @@ func (s *SettingService) SetProxyRepository(repo ProxyRepository) { s.proxyRepo = repo } -// GetAllSettings 获取所有系统设置 -func (s *SettingService) GetAllSettings(ctx context.Context) (*SystemSettings, error) { - settings, err := s.settingRepo.GetAll(ctx) - if err != nil { - return nil, fmt.Errorf("get all settings: %w", err) - } - - return s.parseSettings(settings), nil -} - -// GetFrontendURL 获取前端基础URL(数据库优先,fallback 到配置文件) -func (s *SettingService) GetFrontendURL(ctx context.Context) string { - val, err := s.settingRepo.GetValue(ctx, SettingKeyFrontendURL) - if err == nil && strings.TrimSpace(val) != "" { - return strings.TrimSpace(val) - } - return s.cfg.Server.FrontendURL -} - -// GetPublicSettings 获取公开设置(无需登录) -func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings, error) { - keys := []string{ - SettingKeyRegistrationEnabled, - SettingKeyEmailVerifyEnabled, - SettingKeyForceEmailOnThirdPartySignup, - SettingKeyRegistrationEmailSuffixWhitelist, - SettingKeyPromoCodeEnabled, - SettingKeyPasswordResetEnabled, - SettingKeyInvitationCodeEnabled, - SettingKeyTotpEnabled, - SettingKeyLoginAgreementEnabled, - SettingKeyLoginAgreementMode, - SettingKeyLoginAgreementUpdatedAt, - SettingKeyLoginAgreementDocuments, - SettingKeyTurnstileEnabled, - SettingKeyTurnstileSiteKey, - SettingKeySiteName, - SettingKeySiteLogo, - SettingKeySiteSubtitle, - SettingKeyAPIBaseURL, - SettingKeyContactInfo, - SettingKeyDocURL, - SettingKeyHomeContent, - SettingKeyHideCcsImportButton, - 
SettingKeyPurchaseSubscriptionEnabled, - SettingKeyPurchaseSubscriptionURL, - SettingKeyTableDefaultPageSize, - SettingKeyTablePageSizeOptions, - SettingKeyCustomMenuItems, - SettingKeyCustomEndpoints, - SettingKeyLinuxDoConnectEnabled, - SettingKeyWeChatConnectEnabled, - SettingKeyWeChatConnectAppID, - SettingKeyWeChatConnectAppSecret, - SettingKeyWeChatConnectOpenAppID, - SettingKeyWeChatConnectOpenAppSecret, - SettingKeyWeChatConnectMPAppID, - SettingKeyWeChatConnectMPAppSecret, - SettingKeyWeChatConnectMobileAppID, - SettingKeyWeChatConnectMobileAppSecret, - SettingKeyWeChatConnectOpenEnabled, - SettingKeyWeChatConnectMPEnabled, - SettingKeyWeChatConnectMobileEnabled, - SettingKeyWeChatConnectMode, - SettingKeyWeChatConnectScopes, - SettingKeyWeChatConnectRedirectURL, - SettingKeyWeChatConnectFrontendRedirectURL, - SettingKeyBackendModeEnabled, - SettingPaymentEnabled, - SettingKeyOIDCConnectEnabled, - SettingKeyOIDCConnectProviderName, - SettingKeyGitHubOAuthEnabled, - SettingKeyGitHubOAuthClientID, - SettingKeyGitHubOAuthClientSecret, - SettingKeyGoogleOAuthEnabled, - SettingKeyGoogleOAuthClientID, - SettingKeyGoogleOAuthClientSecret, - SettingKeyBalanceLowNotifyEnabled, - SettingKeyBalanceLowNotifyThreshold, - SettingKeyBalanceLowNotifyRechargeURL, - SettingKeyAccountQuotaNotifyEnabled, - SettingKeyChannelMonitorEnabled, - SettingKeyChannelMonitorDefaultIntervalSeconds, - SettingKeyAvailableChannelsEnabled, - SettingKeyAffiliateEnabled, - SettingKeyRiskControlEnabled, - } - - settings, err := s.settingRepo.GetMultiple(ctx, keys) - if err != nil { - return nil, fmt.Errorf("get public settings: %w", err) - } - - linuxDoEnabled := false - if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok { - linuxDoEnabled = raw == "true" - } else { - linuxDoEnabled = s.cfg != nil && s.cfg.LinuxDo.Enabled - } - oidcEnabled := false - if raw, ok := settings[SettingKeyOIDCConnectEnabled]; ok { - oidcEnabled = raw == "true" - } else { - oidcEnabled = s.cfg != nil && 
s.cfg.OIDC.Enabled - } - oidcProviderName := strings.TrimSpace(settings[SettingKeyOIDCConnectProviderName]) - if oidcProviderName == "" && s.cfg != nil { - oidcProviderName = strings.TrimSpace(s.cfg.OIDC.ProviderName) - } - if oidcProviderName == "" { - oidcProviderName = "OIDC" - } - gitHubEnabled := s.emailOAuthPublicEnabled(settings, "github") - googleEnabled := s.emailOAuthPublicEnabled(settings, "google") - weChatEnabled, weChatOpenEnabled, weChatMPEnabled, weChatMobileEnabled := s.weChatOAuthCapabilitiesFromSettings(settings) - - // Password reset requires email verification to be enabled - emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true" - passwordResetEnabled := emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true" - registrationEmailSuffixWhitelist := ParseRegistrationEmailSuffixWhitelist( - settings[SettingKeyRegistrationEmailSuffixWhitelist], - ) - tableDefaultPageSize, tablePageSizeOptions := parseTablePreferences( - settings[SettingKeyTableDefaultPageSize], - settings[SettingKeyTablePageSizeOptions], - ) - loginAgreementDocuments := parseLoginAgreementDocuments(settings[SettingKeyLoginAgreementDocuments]) - loginAgreementUpdatedAt := strings.TrimSpace(settings[SettingKeyLoginAgreementUpdatedAt]) - if loginAgreementUpdatedAt == "" { - loginAgreementUpdatedAt = defaultLoginAgreementDate - } - - var balanceLowNotifyThreshold float64 - if v, err := strconv.ParseFloat(settings[SettingKeyBalanceLowNotifyThreshold], 64); err == nil && v >= 0 { - balanceLowNotifyThreshold = v - } - - return &PublicSettings{ - RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", - EmailVerifyEnabled: emailVerifyEnabled, - ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true", - RegistrationEmailSuffixWhitelist: registrationEmailSuffixWhitelist, - PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用 - PasswordResetEnabled: passwordResetEnabled, - 
InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true", - TotpEnabled: settings[SettingKeyTotpEnabled] == "true", - LoginAgreementEnabled: settings[SettingKeyLoginAgreementEnabled] == "true" && len(loginAgreementDocuments) > 0, - LoginAgreementMode: normalizeLoginAgreementMode(settings[SettingKeyLoginAgreementMode]), - LoginAgreementUpdatedAt: loginAgreementUpdatedAt, - LoginAgreementRevision: buildLoginAgreementRevision(loginAgreementUpdatedAt, loginAgreementDocuments), - LoginAgreementDocuments: loginAgreementDocuments, - TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true", - TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey], - SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"), - SiteLogo: settings[SettingKeySiteLogo], - SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"), - APIBaseURL: settings[SettingKeyAPIBaseURL], - ContactInfo: settings[SettingKeyContactInfo], - DocURL: settings[SettingKeyDocURL], - HomeContent: settings[SettingKeyHomeContent], - HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", - PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true", - PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), - TableDefaultPageSize: tableDefaultPageSize, - TablePageSizeOptions: tablePageSizeOptions, - CustomMenuItems: settings[SettingKeyCustomMenuItems], - CustomEndpoints: settings[SettingKeyCustomEndpoints], - LinuxDoOAuthEnabled: linuxDoEnabled, - WeChatOAuthEnabled: weChatEnabled, - WeChatOAuthOpenEnabled: weChatOpenEnabled, - WeChatOAuthMPEnabled: weChatMPEnabled, - WeChatOAuthMobileEnabled: weChatMobileEnabled, - BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", - PaymentEnabled: settings[SettingPaymentEnabled] == "true", - OIDCOAuthEnabled: oidcEnabled, - OIDCOAuthProviderName: oidcProviderName, - GitHubOAuthEnabled: 
gitHubEnabled, - GoogleOAuthEnabled: googleEnabled, - BalanceLowNotifyEnabled: settings[SettingKeyBalanceLowNotifyEnabled] == "true", - AccountQuotaNotifyEnabled: settings[SettingKeyAccountQuotaNotifyEnabled] == "true", - BalanceLowNotifyThreshold: balanceLowNotifyThreshold, - BalanceLowNotifyRechargeURL: settings[SettingKeyBalanceLowNotifyRechargeURL], - - ChannelMonitorEnabled: !isFalseSettingValue(settings[SettingKeyChannelMonitorEnabled]), - ChannelMonitorDefaultIntervalSeconds: parseChannelMonitorInterval(settings[SettingKeyChannelMonitorDefaultIntervalSeconds]), - - AvailableChannelsEnabled: settings[SettingKeyAvailableChannelsEnabled] == "true", - - AffiliateEnabled: settings[SettingKeyAffiliateEnabled] == "true", - - RiskControlEnabled: settings[SettingKeyRiskControlEnabled] == "true", - }, nil -} - // channelMonitorIntervalMin / channelMonitorIntervalMax bound the default interval // (mirrors the monitor-level constraint but lives here so setting_service stays decoupled). const ( @@ -1018,60 +659,6 @@ func DefaultWeChatConnectScopesForMode(mode string) string { return defaultWeChatConnectScopeForMode(mode) } -func (s *SettingService) parseWeChatConnectOAuthConfig(settings map[string]string) (WeChatConnectOAuthConfig, error) { - cfg := s.effectiveWeChatConnectOAuthConfig(settings) - - if !cfg.Enabled || (!cfg.OpenEnabled && !cfg.MPEnabled) { - return WeChatConnectOAuthConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "wechat oauth is disabled") - } - if cfg.OpenEnabled { - if cfg.AppIDForMode("open") == "" { - return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth pc app id not configured") - } - if cfg.AppSecretForMode("open") == "" { - return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth pc app secret not configured") - } - } - if cfg.MPEnabled { - if cfg.AppIDForMode("mp") == "" { - return WeChatConnectOAuthConfig{}, 
infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth official account app id not configured") - } - if cfg.AppSecretForMode("mp") == "" { - return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth official account app secret not configured") - } - } - if cfg.MobileEnabled { - if cfg.AppIDForMode("mobile") == "" { - return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth mobile app id not configured") - } - if cfg.AppSecretForMode("mobile") == "" { - return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth mobile app secret not configured") - } - } - if v := strings.TrimSpace(cfg.RedirectURL); v != "" { - if err := config.ValidateAbsoluteHTTPURL(v); err != nil { - return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url invalid") - } - } - if err := config.ValidateFrontendRedirectURL(cfg.FrontendRedirectURL); err != nil { - return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth frontend redirect url invalid") - } - return cfg, nil -} - -func (s *SettingService) weChatOAuthCapabilitiesFromSettings(settings map[string]string) (bool, bool, bool, bool) { - cfg := s.effectiveWeChatConnectOAuthConfig(settings) - if !cfg.Enabled { - return false, false, false, false - } - - openReady := cfg.OpenEnabled && cfg.AppIDForMode("open") != "" && cfg.AppSecretForMode("open") != "" - mpReady := cfg.MPEnabled && cfg.AppIDForMode("mp") != "" && cfg.AppSecretForMode("mp") != "" - mobileReady := cfg.MobileEnabled && cfg.AppIDForMode("mobile") != "" && cfg.AppSecretForMode("mobile") != "" - - return openReady || mpReady, openReady, mpReady, mobileReady -} - func (s *SettingService) emailOAuthBaseConfig(provider string) config.EmailOAuthProviderConfig { switch strings.ToLower(strings.TrimSpace(provider)) { case "github": @@ -1136,34 +723,6 @@ func 
mergeEmailOAuthBaseConfig(base, override config.EmailOAuthProviderConfig) c return base } -func (s *SettingService) emailOAuthPublicEnabled(settings map[string]string, provider string) bool { - cfg := s.effectiveEmailOAuthConfig(settings, provider) - return cfg.Enabled && strings.TrimSpace(cfg.ClientID) != "" && strings.TrimSpace(cfg.ClientSecret) != "" -} - -func (s *SettingService) effectiveEmailOAuthConfig(settings map[string]string, provider string) config.EmailOAuthProviderConfig { - cfg := s.emailOAuthBaseConfig(provider) - switch strings.ToLower(strings.TrimSpace(provider)) { - case "github": - if raw, ok := settings[SettingKeyGitHubOAuthEnabled]; ok { - cfg.Enabled = raw == "true" - } - cfg.ClientID = firstNonEmpty(settings[SettingKeyGitHubOAuthClientID], cfg.ClientID) - cfg.ClientSecret = firstNonEmpty(settings[SettingKeyGitHubOAuthClientSecret], cfg.ClientSecret) - cfg.RedirectURL = firstNonEmpty(settings[SettingKeyGitHubOAuthRedirectURL], cfg.RedirectURL) - cfg.FrontendRedirectURL = firstNonEmpty(settings[SettingKeyGitHubOAuthFrontendRedirectURL], cfg.FrontendRedirectURL, defaultGitHubOAuthFrontend) - case "google": - if raw, ok := settings[SettingKeyGoogleOAuthEnabled]; ok { - cfg.Enabled = raw == "true" - } - cfg.ClientID = firstNonEmpty(settings[SettingKeyGoogleOAuthClientID], cfg.ClientID) - cfg.ClientSecret = firstNonEmpty(settings[SettingKeyGoogleOAuthClientSecret], cfg.ClientSecret) - cfg.RedirectURL = firstNonEmpty(settings[SettingKeyGoogleOAuthRedirectURL], cfg.RedirectURL) - cfg.FrontendRedirectURL = firstNonEmpty(settings[SettingKeyGoogleOAuthFrontendRedirectURL], cfg.FrontendRedirectURL, defaultGoogleOAuthFrontend) - } - return cfg -} - // filterUserVisibleMenuItems filters out admin-only menu items from a raw JSON // array string, returning only items with visibility != "admin". 
func filterUserVisibleMenuItems(raw string) json.RawMessage { @@ -1286,44 +845,6 @@ func parseCustomMenuItemURLs(raw string) []string { return urls } -func oidcUsePKCECompatibilityDefault(base config.OIDCConnectConfig) bool { - if base.UsePKCEExplicit { - return base.UsePKCE - } - return true -} - -func oidcValidateIDTokenCompatibilityDefault(base config.OIDCConnectConfig) bool { - if base.ValidateIDTokenExplicit { - return base.ValidateIDToken - } - return true -} - -func oidcCompatibilityWriteDefault(base config.OIDCConnectConfig, configured bool, raw string, explicit bool, explicitValue bool) bool { - if configured { - return strings.TrimSpace(raw) == "true" - } - if explicit { - return explicitValue - } - return false -} - -// UpdateSettings 更新系统设置 -func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error { - updates, err := s.buildSystemSettingsUpdates(ctx, settings) - if err != nil { - return err - } - - err = s.settingRepo.SetMultiple(ctx, updates) - if err == nil { - s.refreshCachedSettings(settings) - } - return err -} - func (s *SettingService) OIDCSecurityWriteDefaults(ctx context.Context) (bool, bool, error) { rawSettings, err := s.settingRepo.GetMultiple(ctx, []string{ SettingKeyOIDCConnectUsePKCE, @@ -1346,442 +867,10 @@ func (s *SettingService) OIDCSecurityWriteDefaults(ctx context.Context) (bool, b nil } -// UpdateSettingsWithAuthSourceDefaults persists system settings and auth-source defaults in a single write. 
-func (s *SettingService) UpdateSettingsWithAuthSourceDefaults(ctx context.Context, settings *SystemSettings, authDefaults *AuthSourceDefaultSettings) error { - updates, err := s.buildSystemSettingsUpdates(ctx, settings) - if err != nil { - return err - } - - authSourceUpdates, err := s.buildAuthSourceDefaultUpdates(ctx, authDefaults) - if err != nil { - return err - } - for key, value := range authSourceUpdates { - updates[key] = value - } - - err = s.settingRepo.SetMultiple(ctx, updates) - if err == nil { - s.refreshCachedSettings(settings) - } - return err -} - -func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, settings *SystemSettings) (map[string]string, error) { - if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil { - return nil, err - } - normalizedWhitelist, err := NormalizeRegistrationEmailSuffixWhitelist(settings.RegistrationEmailSuffixWhitelist) - if err != nil { - return nil, infraerrors.BadRequest("INVALID_REGISTRATION_EMAIL_SUFFIX_WHITELIST", err.Error()) - } - if normalizedWhitelist == nil { - normalizedWhitelist = []string{} - } - settings.RegistrationEmailSuffixWhitelist = normalizedWhitelist - alipaySource, err := normalizeVisibleMethodSettingSource("alipay", settings.PaymentVisibleMethodAlipaySource, settings.PaymentVisibleMethodAlipayEnabled) - if err != nil { - return nil, err - } - wxpaySource, err := normalizeVisibleMethodSettingSource("wxpay", settings.PaymentVisibleMethodWxpaySource, settings.PaymentVisibleMethodWxpayEnabled) - if err != nil { - return nil, err - } - settings.PaymentVisibleMethodAlipaySource = alipaySource - settings.PaymentVisibleMethodWxpaySource = wxpaySource - settings.WeChatConnectAppID = strings.TrimSpace(settings.WeChatConnectAppID) - settings.WeChatConnectAppSecret = strings.TrimSpace(settings.WeChatConnectAppSecret) - settings.WeChatConnectOpenAppID = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectOpenAppID, settings.WeChatConnectAppID)) - 
settings.WeChatConnectOpenAppSecret = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectOpenAppSecret, settings.WeChatConnectAppSecret)) - settings.WeChatConnectMPAppID = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectMPAppID, settings.WeChatConnectAppID)) - settings.WeChatConnectMPAppSecret = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectMPAppSecret, settings.WeChatConnectAppSecret)) - settings.WeChatConnectMobileAppID = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectMobileAppID, settings.WeChatConnectAppID)) - settings.WeChatConnectMobileAppSecret = strings.TrimSpace(firstNonEmpty(settings.WeChatConnectMobileAppSecret, settings.WeChatConnectAppSecret)) - settings.WeChatConnectMode = normalizeWeChatConnectStoredMode( - settings.WeChatConnectOpenEnabled, - settings.WeChatConnectMPEnabled, - settings.WeChatConnectMobileEnabled, - settings.WeChatConnectMode, - ) - settings.WeChatConnectScopes = normalizeWeChatConnectScopeSetting(settings.WeChatConnectScopes, settings.WeChatConnectMode) - settings.WeChatConnectRedirectURL = strings.TrimSpace(settings.WeChatConnectRedirectURL) - settings.WeChatConnectFrontendRedirectURL = strings.TrimSpace(settings.WeChatConnectFrontendRedirectURL) - if settings.WeChatConnectFrontendRedirectURL == "" { - settings.WeChatConnectFrontendRedirectURL = defaultWeChatConnectFrontend - } - settings.GitHubOAuthRedirectURL = strings.TrimSpace(settings.GitHubOAuthRedirectURL) - settings.GitHubOAuthFrontendRedirectURL = strings.TrimSpace(settings.GitHubOAuthFrontendRedirectURL) - if settings.GitHubOAuthFrontendRedirectURL == "" { - settings.GitHubOAuthFrontendRedirectURL = defaultGitHubOAuthFrontend - } - settings.GoogleOAuthRedirectURL = strings.TrimSpace(settings.GoogleOAuthRedirectURL) - settings.GoogleOAuthFrontendRedirectURL = strings.TrimSpace(settings.GoogleOAuthFrontendRedirectURL) - if settings.GoogleOAuthFrontendRedirectURL == "" { - settings.GoogleOAuthFrontendRedirectURL = defaultGoogleOAuthFrontend - } - - 
updates := make(map[string]string) - - // 注册设置 - updates[SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled) - updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled) - registrationEmailSuffixWhitelistJSON, err := json.Marshal(settings.RegistrationEmailSuffixWhitelist) - if err != nil { - return nil, fmt.Errorf("marshal registration email suffix whitelist: %w", err) - } - updates[SettingKeyRegistrationEmailSuffixWhitelist] = string(registrationEmailSuffixWhitelistJSON) - updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled) - updates[SettingKeyPasswordResetEnabled] = strconv.FormatBool(settings.PasswordResetEnabled) - updates[SettingKeyFrontendURL] = settings.FrontendURL - updates[SettingKeyInvitationCodeEnabled] = strconv.FormatBool(settings.InvitationCodeEnabled) - updates[SettingKeyTotpEnabled] = strconv.FormatBool(settings.TotpEnabled) - settings.LoginAgreementMode = normalizeLoginAgreementMode(settings.LoginAgreementMode) - settings.LoginAgreementUpdatedAt = strings.TrimSpace(settings.LoginAgreementUpdatedAt) - if settings.LoginAgreementUpdatedAt == "" { - settings.LoginAgreementUpdatedAt = defaultLoginAgreementDate - } - loginAgreementDocumentsJSON, err := marshalLoginAgreementDocuments(settings.LoginAgreementDocuments) - if err != nil { - return nil, err - } - updates[SettingKeyLoginAgreementEnabled] = strconv.FormatBool(settings.LoginAgreementEnabled) - updates[SettingKeyLoginAgreementMode] = settings.LoginAgreementMode - updates[SettingKeyLoginAgreementUpdatedAt] = settings.LoginAgreementUpdatedAt - updates[SettingKeyLoginAgreementDocuments] = loginAgreementDocumentsJSON - - // 邮件服务设置(只有非空才更新密码) - updates[SettingKeySMTPHost] = settings.SMTPHost - updates[SettingKeySMTPPort] = strconv.Itoa(settings.SMTPPort) - updates[SettingKeySMTPUsername] = settings.SMTPUsername - if settings.SMTPPassword != "" { - updates[SettingKeySMTPPassword] = settings.SMTPPassword - } - 
updates[SettingKeySMTPFrom] = settings.SMTPFrom - updates[SettingKeySMTPFromName] = settings.SMTPFromName - updates[SettingKeySMTPUseTLS] = strconv.FormatBool(settings.SMTPUseTLS) - - // Cloudflare Turnstile 设置(只有非空才更新密钥) - updates[SettingKeyTurnstileEnabled] = strconv.FormatBool(settings.TurnstileEnabled) - updates[SettingKeyTurnstileSiteKey] = settings.TurnstileSiteKey - if settings.TurnstileSecretKey != "" { - updates[SettingKeyTurnstileSecretKey] = settings.TurnstileSecretKey - } - - // LinuxDo Connect OAuth 登录 - updates[SettingKeyLinuxDoConnectEnabled] = strconv.FormatBool(settings.LinuxDoConnectEnabled) - updates[SettingKeyLinuxDoConnectClientID] = settings.LinuxDoConnectClientID - updates[SettingKeyLinuxDoConnectRedirectURL] = settings.LinuxDoConnectRedirectURL - if settings.LinuxDoConnectClientSecret != "" { - updates[SettingKeyLinuxDoConnectClientSecret] = settings.LinuxDoConnectClientSecret - } - - // Generic OIDC OAuth 登录 - updates[SettingKeyOIDCConnectEnabled] = strconv.FormatBool(settings.OIDCConnectEnabled) - updates[SettingKeyOIDCConnectProviderName] = settings.OIDCConnectProviderName - updates[SettingKeyOIDCConnectClientID] = settings.OIDCConnectClientID - updates[SettingKeyOIDCConnectIssuerURL] = settings.OIDCConnectIssuerURL - updates[SettingKeyOIDCConnectDiscoveryURL] = settings.OIDCConnectDiscoveryURL - updates[SettingKeyOIDCConnectAuthorizeURL] = settings.OIDCConnectAuthorizeURL - updates[SettingKeyOIDCConnectTokenURL] = settings.OIDCConnectTokenURL - updates[SettingKeyOIDCConnectUserInfoURL] = settings.OIDCConnectUserInfoURL - updates[SettingKeyOIDCConnectJWKSURL] = settings.OIDCConnectJWKSURL - updates[SettingKeyOIDCConnectScopes] = settings.OIDCConnectScopes - updates[SettingKeyOIDCConnectRedirectURL] = settings.OIDCConnectRedirectURL - updates[SettingKeyOIDCConnectFrontendRedirectURL] = settings.OIDCConnectFrontendRedirectURL - updates[SettingKeyOIDCConnectTokenAuthMethod] = settings.OIDCConnectTokenAuthMethod - 
updates[SettingKeyOIDCConnectUsePKCE] = strconv.FormatBool(settings.OIDCConnectUsePKCE) - updates[SettingKeyOIDCConnectValidateIDToken] = strconv.FormatBool(settings.OIDCConnectValidateIDToken) - updates[SettingKeyOIDCConnectAllowedSigningAlgs] = settings.OIDCConnectAllowedSigningAlgs - updates[SettingKeyOIDCConnectClockSkewSeconds] = strconv.Itoa(settings.OIDCConnectClockSkewSeconds) - updates[SettingKeyOIDCConnectRequireEmailVerified] = strconv.FormatBool(settings.OIDCConnectRequireEmailVerified) - updates[SettingKeyOIDCConnectUserInfoEmailPath] = settings.OIDCConnectUserInfoEmailPath - updates[SettingKeyOIDCConnectUserInfoIDPath] = settings.OIDCConnectUserInfoIDPath - updates[SettingKeyOIDCConnectUserInfoUsernamePath] = settings.OIDCConnectUserInfoUsernamePath - if settings.OIDCConnectClientSecret != "" { - updates[SettingKeyOIDCConnectClientSecret] = settings.OIDCConnectClientSecret - } - - // GitHub / Google 邮箱快捷登录 - updates[SettingKeyGitHubOAuthEnabled] = strconv.FormatBool(settings.GitHubOAuthEnabled) - updates[SettingKeyGitHubOAuthClientID] = strings.TrimSpace(settings.GitHubOAuthClientID) - updates[SettingKeyGitHubOAuthRedirectURL] = settings.GitHubOAuthRedirectURL - updates[SettingKeyGitHubOAuthFrontendRedirectURL] = settings.GitHubOAuthFrontendRedirectURL - if settings.GitHubOAuthClientSecret != "" { - updates[SettingKeyGitHubOAuthClientSecret] = strings.TrimSpace(settings.GitHubOAuthClientSecret) - } - updates[SettingKeyGoogleOAuthEnabled] = strconv.FormatBool(settings.GoogleOAuthEnabled) - updates[SettingKeyGoogleOAuthClientID] = strings.TrimSpace(settings.GoogleOAuthClientID) - updates[SettingKeyGoogleOAuthRedirectURL] = settings.GoogleOAuthRedirectURL - updates[SettingKeyGoogleOAuthFrontendRedirectURL] = settings.GoogleOAuthFrontendRedirectURL - if settings.GoogleOAuthClientSecret != "" { - updates[SettingKeyGoogleOAuthClientSecret] = strings.TrimSpace(settings.GoogleOAuthClientSecret) - } - - // WeChat Connect OAuth 登录 - 
updates[SettingKeyWeChatConnectEnabled] = strconv.FormatBool(settings.WeChatConnectEnabled) - updates[SettingKeyWeChatConnectAppID] = settings.WeChatConnectAppID - updates[SettingKeyWeChatConnectOpenAppID] = settings.WeChatConnectOpenAppID - updates[SettingKeyWeChatConnectMPAppID] = settings.WeChatConnectMPAppID - updates[SettingKeyWeChatConnectMobileAppID] = settings.WeChatConnectMobileAppID - updates[SettingKeyWeChatConnectOpenEnabled] = strconv.FormatBool(settings.WeChatConnectOpenEnabled) - updates[SettingKeyWeChatConnectMPEnabled] = strconv.FormatBool(settings.WeChatConnectMPEnabled) - updates[SettingKeyWeChatConnectMobileEnabled] = strconv.FormatBool(settings.WeChatConnectMobileEnabled) - updates[SettingKeyWeChatConnectMode] = settings.WeChatConnectMode - updates[SettingKeyWeChatConnectScopes] = settings.WeChatConnectScopes - updates[SettingKeyWeChatConnectRedirectURL] = settings.WeChatConnectRedirectURL - updates[SettingKeyWeChatConnectFrontendRedirectURL] = settings.WeChatConnectFrontendRedirectURL - if settings.WeChatConnectAppSecret != "" { - updates[SettingKeyWeChatConnectAppSecret] = settings.WeChatConnectAppSecret - } - if settings.WeChatConnectOpenAppSecret != "" { - updates[SettingKeyWeChatConnectOpenAppSecret] = settings.WeChatConnectOpenAppSecret - } - if settings.WeChatConnectMPAppSecret != "" { - updates[SettingKeyWeChatConnectMPAppSecret] = settings.WeChatConnectMPAppSecret - } - if settings.WeChatConnectMobileAppSecret != "" { - updates[SettingKeyWeChatConnectMobileAppSecret] = settings.WeChatConnectMobileAppSecret - } - - // OEM设置 - updates[SettingKeySiteName] = settings.SiteName - updates[SettingKeySiteLogo] = settings.SiteLogo - updates[SettingKeySiteSubtitle] = settings.SiteSubtitle - updates[SettingKeyAPIBaseURL] = settings.APIBaseURL - updates[SettingKeyContactInfo] = settings.ContactInfo - updates[SettingKeyDocURL] = settings.DocURL - updates[SettingKeyHomeContent] = settings.HomeContent - updates[SettingKeyHideCcsImportButton] = 
strconv.FormatBool(settings.HideCcsImportButton) - updates[SettingKeyPurchaseSubscriptionEnabled] = strconv.FormatBool(settings.PurchaseSubscriptionEnabled) - updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL) - tableDefaultPageSize, tablePageSizeOptions := normalizeTablePreferences( - settings.TableDefaultPageSize, - settings.TablePageSizeOptions, - ) - updates[SettingKeyTableDefaultPageSize] = strconv.Itoa(tableDefaultPageSize) - tablePageSizeOptionsJSON, err := json.Marshal(tablePageSizeOptions) - if err != nil { - return nil, fmt.Errorf("marshal table page size options: %w", err) - } - updates[SettingKeyTablePageSizeOptions] = string(tablePageSizeOptionsJSON) - updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems - updates[SettingKeyCustomEndpoints] = settings.CustomEndpoints - - // 默认配置 - updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency) - updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64) - settings.AffiliateRebateRate = clampAffiliateRebateRate(settings.AffiliateRebateRate) - updates[SettingKeyAffiliateRebateRate] = strconv.FormatFloat(settings.AffiliateRebateRate, 'f', 8, 64) - if settings.AffiliateRebateFreezeHours < 0 { - settings.AffiliateRebateFreezeHours = AffiliateRebateFreezeHoursDefault - } - if settings.AffiliateRebateFreezeHours > AffiliateRebateFreezeHoursMax { - settings.AffiliateRebateFreezeHours = AffiliateRebateFreezeHoursMax - } - updates[SettingKeyAffiliateRebateFreezeHours] = strconv.Itoa(settings.AffiliateRebateFreezeHours) - if settings.AffiliateRebateDurationDays < 0 { - settings.AffiliateRebateDurationDays = AffiliateRebateDurationDaysDefault - } - if settings.AffiliateRebateDurationDays > AffiliateRebateDurationDaysMax { - settings.AffiliateRebateDurationDays = AffiliateRebateDurationDaysMax - } - updates[SettingKeyAffiliateRebateDurationDays] = strconv.Itoa(settings.AffiliateRebateDurationDays) - if 
settings.AffiliateRebatePerInviteeCap < 0 { - settings.AffiliateRebatePerInviteeCap = AffiliateRebatePerInviteeCapDefault - } - updates[SettingKeyAffiliateRebatePerInviteeCap] = strconv.FormatFloat(settings.AffiliateRebatePerInviteeCap, 'f', 8, 64) - updates[SettingKeyDefaultUserRPMLimit] = strconv.Itoa(settings.DefaultUserRPMLimit) - defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions) - if err != nil { - return nil, fmt.Errorf("marshal default subscriptions: %w", err) - } - updates[SettingKeyDefaultSubscriptions] = string(defaultSubsJSON) - - // Model fallback configuration - updates[SettingKeyEnableModelFallback] = strconv.FormatBool(settings.EnableModelFallback) - updates[SettingKeyFallbackModelAnthropic] = settings.FallbackModelAnthropic - updates[SettingKeyFallbackModelOpenAI] = settings.FallbackModelOpenAI - updates[SettingKeyFallbackModelGemini] = settings.FallbackModelGemini - updates[SettingKeyFallbackModelAntigravity] = settings.FallbackModelAntigravity - - // Identity patch configuration (Claude -> Gemini) - updates[SettingKeyEnableIdentityPatch] = strconv.FormatBool(settings.EnableIdentityPatch) - updates[SettingKeyIdentityPatchPrompt] = settings.IdentityPatchPrompt - - // Ops monitoring (vNext) - updates[SettingKeyOpsMonitoringEnabled] = strconv.FormatBool(settings.OpsMonitoringEnabled) - updates[SettingKeyOpsRealtimeMonitoringEnabled] = strconv.FormatBool(settings.OpsRealtimeMonitoringEnabled) - updates[SettingKeyOpsQueryModeDefault] = string(ParseOpsQueryMode(settings.OpsQueryModeDefault)) - if settings.OpsMetricsIntervalSeconds > 0 { - updates[SettingKeyOpsMetricsIntervalSeconds] = strconv.Itoa(settings.OpsMetricsIntervalSeconds) - } - - // Channel monitor feature switch - updates[SettingKeyChannelMonitorEnabled] = strconv.FormatBool(settings.ChannelMonitorEnabled) - if v := clampChannelMonitorInterval(settings.ChannelMonitorDefaultIntervalSeconds); v > 0 { - updates[SettingKeyChannelMonitorDefaultIntervalSeconds] = strconv.Itoa(v) - 
} - - // Available channels feature switch - updates[SettingKeyAvailableChannelsEnabled] = strconv.FormatBool(settings.AvailableChannelsEnabled) - - // Affiliate (邀请返利) feature switch - updates[SettingKeyAffiliateEnabled] = strconv.FormatBool(settings.AffiliateEnabled) - - // 风控中心功能开关 - updates[SettingKeyRiskControlEnabled] = strconv.FormatBool(settings.RiskControlEnabled) - - // Claude Code version check - updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion - updates[SettingKeyMaxClaudeCodeVersion] = settings.MaxClaudeCodeVersion - - // 分组隔离 - updates[SettingKeyAllowUngroupedKeyScheduling] = strconv.FormatBool(settings.AllowUngroupedKeyScheduling) - - // Backend Mode - updates[SettingKeyBackendModeEnabled] = strconv.FormatBool(settings.BackendModeEnabled) - - // Gateway forwarding behavior - updates[SettingKeyEnableFingerprintUnification] = strconv.FormatBool(settings.EnableFingerprintUnification) - updates[SettingKeyEnableMetadataPassthrough] = strconv.FormatBool(settings.EnableMetadataPassthrough) - updates[SettingKeyEnableCCHSigning] = strconv.FormatBool(settings.EnableCCHSigning) - updates[SettingKeyEnableAnthropicCacheTTL1hInjection] = strconv.FormatBool(settings.EnableAnthropicCacheTTL1hInjection) - updates[SettingKeyRewriteMessageCacheControl] = strconv.FormatBool(settings.RewriteMessageCacheControl) - updates[SettingKeyAntigravityUserAgentVersion] = antigravity.NormalizeUserAgentVersion(settings.AntigravityUserAgentVersion) - updates[SettingPaymentVisibleMethodAlipaySource] = settings.PaymentVisibleMethodAlipaySource - updates[SettingPaymentVisibleMethodWxpaySource] = settings.PaymentVisibleMethodWxpaySource - updates[SettingPaymentVisibleMethodAlipayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodAlipayEnabled) - updates[SettingPaymentVisibleMethodWxpayEnabled] = strconv.FormatBool(settings.PaymentVisibleMethodWxpayEnabled) - updates[openAIAdvancedSchedulerSettingKey] = 
strconv.FormatBool(settings.OpenAIAdvancedSchedulerEnabled) - - // Balance low notification - updates[SettingKeyBalanceLowNotifyEnabled] = strconv.FormatBool(settings.BalanceLowNotifyEnabled) - updates[SettingKeyBalanceLowNotifyThreshold] = strconv.FormatFloat(settings.BalanceLowNotifyThreshold, 'f', 8, 64) - updates[SettingKeyBalanceLowNotifyRechargeURL] = settings.BalanceLowNotifyRechargeURL - updates[SettingKeyAccountQuotaNotifyEnabled] = strconv.FormatBool(settings.AccountQuotaNotifyEnabled) - updates[SettingKeyAccountQuotaNotifyEmails] = MarshalNotifyEmails(settings.AccountQuotaNotifyEmails) - - return updates, nil -} - -func (s *SettingService) buildAuthSourceDefaultUpdates(ctx context.Context, settings *AuthSourceDefaultSettings) (map[string]string, error) { - if settings == nil { - return nil, nil - } - - for _, subscriptions := range [][]DefaultSubscriptionSetting{ - settings.Email.Subscriptions, - settings.LinuxDo.Subscriptions, - settings.OIDC.Subscriptions, - settings.WeChat.Subscriptions, - settings.GitHub.Subscriptions, - settings.Google.Subscriptions, - } { - if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil { - return nil, err - } - } - - updates := make(map[string]string, 31) - writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email) - writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo) - writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC) - writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat) - writeProviderDefaultGrantUpdates(updates, gitHubAuthSourceDefaultKeys, settings.GitHub) - writeProviderDefaultGrantUpdates(updates, googleAuthSourceDefaultKeys, settings.Google) - updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup) - return updates, nil -} - -func (s *SettingService) refreshCachedSettings(settings *SystemSettings) { - if 
settings == nil { - return - } - - // 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口 - versionBoundsSF.Forget("version_bounds") - versionBoundsCache.Store(&cachedVersionBounds{ - min: settings.MinClaudeCodeVersion, - max: settings.MaxClaudeCodeVersion, - expiresAt: time.Now().Add(versionBoundsCacheTTL).UnixNano(), - }) - backendModeSF.Forget("backend_mode") - backendModeCache.Store(&cachedBackendMode{ - value: settings.BackendModeEnabled, - expiresAt: time.Now().Add(backendModeCacheTTL).UnixNano(), - }) - gatewayForwardingSF.Forget("gateway_forwarding") - gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{ - fingerprintUnification: settings.EnableFingerprintUnification, - metadataPassthrough: settings.EnableMetadataPassthrough, - cchSigning: settings.EnableCCHSigning, - anthropicCacheTTL1hInjection: settings.EnableAnthropicCacheTTL1hInjection, - rewriteMessageCacheControl: settings.RewriteMessageCacheControl, - expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(), - }) - s.antigravityUAVersionSF.Forget("antigravity_user_agent_version") - antigravityUserAgentVersion := antigravity.NormalizeUserAgentVersion(settings.AntigravityUserAgentVersion) - if antigravityUserAgentVersion == "" { - antigravityUserAgentVersion = antigravity.GetDefaultUserAgentVersion() - } - s.antigravityUAVersionCache.Store(&cachedAntigravityUserAgentVersion{ - version: antigravityUserAgentVersion, - expiresAt: time.Now().Add(antigravityUserAgentVersionCacheTTL).UnixNano(), - }) - openAIAdvancedSchedulerSettingSF.Forget(openAIAdvancedSchedulerSettingKey) - openAIAdvancedSchedulerSettingCache.Store(&cachedOpenAIAdvancedSchedulerSetting{ - enabled: settings.OpenAIAdvancedSchedulerEnabled, - expiresAt: time.Now().Add(openAIAdvancedSchedulerSettingCacheTTL).UnixNano(), - }) - if s.onUpdate != nil { - s.onUpdate() // Invalidate cache after settings update - } -} - func (s *SettingService) defaultRewriteMessageCacheControl() bool { return false } -func (s *SettingService) 
validateDefaultSubscriptionGroups(ctx context.Context, items []DefaultSubscriptionSetting) error { - if len(items) == 0 { - return nil - } - - checked := make(map[int64]struct{}, len(items)) - for _, item := range items { - if item.GroupID <= 0 { - continue - } - if _, ok := checked[item.GroupID]; ok { - return ErrDefaultSubGroupDuplicate.WithMetadata(map[string]string{ - "group_id": strconv.FormatInt(item.GroupID, 10), - }) - } - checked[item.GroupID] = struct{}{} - if s.defaultSubGroupReader == nil { - continue - } - - group, err := s.defaultSubGroupReader.GetByID(ctx, item.GroupID) - if err != nil { - if errors.Is(err, ErrGroupNotFound) { - return ErrDefaultSubGroupInvalid.WithMetadata(map[string]string{ - "group_id": strconv.FormatInt(item.GroupID, 10), - }) - } - return fmt.Errorf("get default subscription group %d: %w", item.GroupID, err) - } - if !group.IsSubscriptionType() { - return ErrDefaultSubGroupInvalid.WithMetadata(map[string]string{ - "group_id": strconv.FormatInt(item.GroupID, 10), - }) - } - } - - return nil -} - func (s *SettingService) GetEmailOAuthProviderConfig(ctx context.Context, provider string) (config.EmailOAuthProviderConfig, error) { provider = strings.ToLower(strings.TrimSpace(provider)) if provider != "github" && provider != "google" { @@ -1897,84 +986,6 @@ type gatewayForwardingSettingsResult struct { fp, mp, cch, cacheTTL1h, rewriteMessageCacheControl bool } -func (s *SettingService) getGatewayForwardingSettingsCached(ctx context.Context) gatewayForwardingSettingsResult { - if cached, ok := gatewayForwardingCache.Load().(*cachedGatewayForwardingSettings); ok && cached != nil { - if time.Now().UnixNano() < cached.expiresAt { - return gatewayForwardingSettingsResult{ - fp: cached.fingerprintUnification, - mp: cached.metadataPassthrough, - cch: cached.cchSigning, - cacheTTL1h: cached.anthropicCacheTTL1hInjection, - rewriteMessageCacheControl: cached.rewriteMessageCacheControl, - } - } - } - val, _, _ := 
gatewayForwardingSF.Do("gateway_forwarding", func() (any, error) { - if cached, ok := gatewayForwardingCache.Load().(*cachedGatewayForwardingSettings); ok && cached != nil { - if time.Now().UnixNano() < cached.expiresAt { - return gatewayForwardingSettingsResult{ - fp: cached.fingerprintUnification, - mp: cached.metadataPassthrough, - cch: cached.cchSigning, - cacheTTL1h: cached.anthropicCacheTTL1hInjection, - rewriteMessageCacheControl: cached.rewriteMessageCacheControl, - }, nil - } - } - dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), gatewayForwardingDBTimeout) - defer cancel() - values, err := s.settingRepo.GetMultiple(dbCtx, []string{ - SettingKeyEnableFingerprintUnification, - SettingKeyEnableMetadataPassthrough, - SettingKeyEnableCCHSigning, - SettingKeyEnableAnthropicCacheTTL1hInjection, - SettingKeyRewriteMessageCacheControl, - }) - if err != nil { - slog.Warn("failed to get gateway forwarding settings", "error", err) - gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{ - fingerprintUnification: true, - metadataPassthrough: false, - cchSigning: false, - anthropicCacheTTL1hInjection: false, - rewriteMessageCacheControl: s.defaultRewriteMessageCacheControl(), - expiresAt: time.Now().Add(gatewayForwardingErrorTTL).UnixNano(), - }) - return gatewayForwardingSettingsResult{fp: true, rewriteMessageCacheControl: s.defaultRewriteMessageCacheControl()}, nil - } - fp := true - if v, ok := values[SettingKeyEnableFingerprintUnification]; ok && v != "" { - fp = v == "true" - } - mp := values[SettingKeyEnableMetadataPassthrough] == "true" - cch := values[SettingKeyEnableCCHSigning] == "true" - cacheTTL1h := values[SettingKeyEnableAnthropicCacheTTL1hInjection] == "true" - rewriteMessageCacheControl := s.defaultRewriteMessageCacheControl() - if v, ok := values[SettingKeyRewriteMessageCacheControl]; ok && v != "" { - rewriteMessageCacheControl = v == "true" - } - gatewayForwardingCache.Store(&cachedGatewayForwardingSettings{ - 
fingerprintUnification: fp, - metadataPassthrough: mp, - cchSigning: cch, - anthropicCacheTTL1hInjection: cacheTTL1h, - rewriteMessageCacheControl: rewriteMessageCacheControl, - expiresAt: time.Now().Add(gatewayForwardingCacheTTL).UnixNano(), - }) - return gatewayForwardingSettingsResult{ - fp: fp, - mp: mp, - cch: cch, - cacheTTL1h: cacheTTL1h, - rewriteMessageCacheControl: rewriteMessageCacheControl, - }, nil - }) - if r, ok := val.(gatewayForwardingSettingsResult); ok { - return r - } - return gatewayForwardingSettingsResult{fp: true} -} - // GetGatewayForwardingSettings returns cached gateway forwarding settings. // Uses in-process atomic.Value cache with 60s TTL, zero-lock hot path. // Returns (fingerprintUnification, metadataPassthrough, cchSigning). @@ -2469,411 +1480,6 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { return s.settingRepo.SetMultiple(ctx, defaults) } -// parseSettings 解析设置到结构体 -func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings { - emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true" - loginAgreementDocuments := parseLoginAgreementDocuments(settings[SettingKeyLoginAgreementDocuments]) - loginAgreementUpdatedAt := strings.TrimSpace(settings[SettingKeyLoginAgreementUpdatedAt]) - if loginAgreementUpdatedAt == "" { - loginAgreementUpdatedAt = defaultLoginAgreementDate - } - result := &SystemSettings{ - RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", - EmailVerifyEnabled: emailVerifyEnabled, - RegistrationEmailSuffixWhitelist: ParseRegistrationEmailSuffixWhitelist(settings[SettingKeyRegistrationEmailSuffixWhitelist]), - PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用 - PasswordResetEnabled: emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true", - FrontendURL: settings[SettingKeyFrontendURL], - InvitationCodeEnabled: settings[SettingKeyInvitationCodeEnabled] == "true", - TotpEnabled: 
settings[SettingKeyTotpEnabled] == "true", - LoginAgreementEnabled: settings[SettingKeyLoginAgreementEnabled] == "true", - LoginAgreementMode: normalizeLoginAgreementMode(settings[SettingKeyLoginAgreementMode]), - LoginAgreementUpdatedAt: loginAgreementUpdatedAt, - LoginAgreementDocuments: loginAgreementDocuments, - SMTPHost: settings[SettingKeySMTPHost], - SMTPUsername: settings[SettingKeySMTPUsername], - SMTPFrom: settings[SettingKeySMTPFrom], - SMTPFromName: settings[SettingKeySMTPFromName], - SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true", - SMTPPasswordConfigured: settings[SettingKeySMTPPassword] != "", - TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true", - TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey], - TurnstileSecretKeyConfigured: settings[SettingKeyTurnstileSecretKey] != "", - SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"), - SiteLogo: settings[SettingKeySiteLogo], - SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"), - APIBaseURL: settings[SettingKeyAPIBaseURL], - ContactInfo: settings[SettingKeyContactInfo], - DocURL: settings[SettingKeyDocURL], - HomeContent: settings[SettingKeyHomeContent], - HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", - PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true", - PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), - CustomMenuItems: settings[SettingKeyCustomMenuItems], - CustomEndpoints: settings[SettingKeyCustomEndpoints], - BackendModeEnabled: settings[SettingKeyBackendModeEnabled] == "true", - } - result.TableDefaultPageSize, result.TablePageSizeOptions = parseTablePreferences( - settings[SettingKeyTableDefaultPageSize], - settings[SettingKeyTablePageSizeOptions], - ) - - // 解析整数类型 - if port, err := strconv.Atoi(settings[SettingKeySMTPPort]); err == nil { - result.SMTPPort = port - } else { - result.SMTPPort 
= 587 - } - - if concurrency, err := strconv.Atoi(settings[SettingKeyDefaultConcurrency]); err == nil { - result.DefaultConcurrency = concurrency - } else { - result.DefaultConcurrency = s.cfg.Default.UserConcurrency - } - - if rpm, err := strconv.Atoi(settings[SettingKeyDefaultUserRPMLimit]); err == nil && rpm >= 0 { - result.DefaultUserRPMLimit = rpm - } - - // 解析浮点数类型 - if balance, err := strconv.ParseFloat(settings[SettingKeyDefaultBalance], 64); err == nil { - result.DefaultBalance = balance - } else { - result.DefaultBalance = s.cfg.Default.UserBalance - } - if rebateRate, err := strconv.ParseFloat(settings[SettingKeyAffiliateRebateRate], 64); err == nil { - result.AffiliateRebateRate = clampAffiliateRebateRate(rebateRate) - } else { - result.AffiliateRebateRate = AffiliateRebateRateDefault - } - if freezeHours, err := strconv.Atoi(settings[SettingKeyAffiliateRebateFreezeHours]); err == nil && freezeHours >= 0 { - if freezeHours > AffiliateRebateFreezeHoursMax { - freezeHours = AffiliateRebateFreezeHoursMax - } - result.AffiliateRebateFreezeHours = freezeHours - } - if durationDays, err := strconv.Atoi(settings[SettingKeyAffiliateRebateDurationDays]); err == nil && durationDays >= 0 { - if durationDays > AffiliateRebateDurationDaysMax { - durationDays = AffiliateRebateDurationDaysMax - } - result.AffiliateRebateDurationDays = durationDays - } - if perInviteeCap, err := strconv.ParseFloat(settings[SettingKeyAffiliateRebatePerInviteeCap], 64); err == nil && perInviteeCap >= 0 { - result.AffiliateRebatePerInviteeCap = perInviteeCap - } - result.DefaultSubscriptions = parseDefaultSubscriptions(settings[SettingKeyDefaultSubscriptions]) - - // 敏感信息直接返回,方便测试连接时使用 - result.SMTPPassword = settings[SettingKeySMTPPassword] - result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey] - - // LinuxDo Connect 设置: - // - 兼容 config.yaml/env(避免老部署因为未迁移到数据库设置而被意外关闭) - // - 支持在后台“系统设置”中覆盖并持久化(存储于 DB) - linuxDoBase := config.LinuxDoConnectConfig{} - if s.cfg != nil { - 
linuxDoBase = s.cfg.LinuxDo - } - - if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok { - result.LinuxDoConnectEnabled = raw == "true" - } else { - result.LinuxDoConnectEnabled = linuxDoBase.Enabled - } - - if v, ok := settings[SettingKeyLinuxDoConnectClientID]; ok && strings.TrimSpace(v) != "" { - result.LinuxDoConnectClientID = strings.TrimSpace(v) - } else { - result.LinuxDoConnectClientID = linuxDoBase.ClientID - } - - if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" { - result.LinuxDoConnectRedirectURL = strings.TrimSpace(v) - } else { - result.LinuxDoConnectRedirectURL = linuxDoBase.RedirectURL - } - - result.LinuxDoConnectClientSecret = strings.TrimSpace(settings[SettingKeyLinuxDoConnectClientSecret]) - if result.LinuxDoConnectClientSecret == "" { - result.LinuxDoConnectClientSecret = strings.TrimSpace(linuxDoBase.ClientSecret) - } - result.LinuxDoConnectClientSecretConfigured = result.LinuxDoConnectClientSecret != "" - - // Generic OIDC 设置: - // - 兼容 config.yaml/env - // - 支持后台系统设置覆盖并持久化(存储于 DB) - oidcBase := config.OIDCConnectConfig{} - if s.cfg != nil { - oidcBase = s.cfg.OIDC - } - - if raw, ok := settings[SettingKeyOIDCConnectEnabled]; ok { - result.OIDCConnectEnabled = raw == "true" - } else { - result.OIDCConnectEnabled = oidcBase.Enabled - } - - if v, ok := settings[SettingKeyOIDCConnectProviderName]; ok && strings.TrimSpace(v) != "" { - result.OIDCConnectProviderName = strings.TrimSpace(v) - } else { - result.OIDCConnectProviderName = strings.TrimSpace(oidcBase.ProviderName) - } - if result.OIDCConnectProviderName == "" { - result.OIDCConnectProviderName = "OIDC" - } - - if v, ok := settings[SettingKeyOIDCConnectClientID]; ok && strings.TrimSpace(v) != "" { - result.OIDCConnectClientID = strings.TrimSpace(v) - } else { - result.OIDCConnectClientID = strings.TrimSpace(oidcBase.ClientID) - } - if v, ok := settings[SettingKeyOIDCConnectIssuerURL]; ok && strings.TrimSpace(v) != "" { - 
result.OIDCConnectIssuerURL = strings.TrimSpace(v) - } else { - result.OIDCConnectIssuerURL = strings.TrimSpace(oidcBase.IssuerURL) - } - if v, ok := settings[SettingKeyOIDCConnectDiscoveryURL]; ok && strings.TrimSpace(v) != "" { - result.OIDCConnectDiscoveryURL = strings.TrimSpace(v) - } else { - result.OIDCConnectDiscoveryURL = strings.TrimSpace(oidcBase.DiscoveryURL) - } - if v, ok := settings[SettingKeyOIDCConnectAuthorizeURL]; ok && strings.TrimSpace(v) != "" { - result.OIDCConnectAuthorizeURL = strings.TrimSpace(v) - } else { - result.OIDCConnectAuthorizeURL = strings.TrimSpace(oidcBase.AuthorizeURL) - } - if v, ok := settings[SettingKeyOIDCConnectTokenURL]; ok && strings.TrimSpace(v) != "" { - result.OIDCConnectTokenURL = strings.TrimSpace(v) - } else { - result.OIDCConnectTokenURL = strings.TrimSpace(oidcBase.TokenURL) - } - if v, ok := settings[SettingKeyOIDCConnectUserInfoURL]; ok && strings.TrimSpace(v) != "" { - result.OIDCConnectUserInfoURL = strings.TrimSpace(v) - } else { - result.OIDCConnectUserInfoURL = strings.TrimSpace(oidcBase.UserInfoURL) - } - if v, ok := settings[SettingKeyOIDCConnectJWKSURL]; ok && strings.TrimSpace(v) != "" { - result.OIDCConnectJWKSURL = strings.TrimSpace(v) - } else { - result.OIDCConnectJWKSURL = strings.TrimSpace(oidcBase.JWKSURL) - } - if v, ok := settings[SettingKeyOIDCConnectScopes]; ok && strings.TrimSpace(v) != "" { - result.OIDCConnectScopes = strings.TrimSpace(v) - } else { - result.OIDCConnectScopes = strings.TrimSpace(oidcBase.Scopes) - } - if v, ok := settings[SettingKeyOIDCConnectRedirectURL]; ok && strings.TrimSpace(v) != "" { - result.OIDCConnectRedirectURL = strings.TrimSpace(v) - } else { - result.OIDCConnectRedirectURL = strings.TrimSpace(oidcBase.RedirectURL) - } - if v, ok := settings[SettingKeyOIDCConnectFrontendRedirectURL]; ok && strings.TrimSpace(v) != "" { - result.OIDCConnectFrontendRedirectURL = strings.TrimSpace(v) - } else { - result.OIDCConnectFrontendRedirectURL = 
strings.TrimSpace(oidcBase.FrontendRedirectURL) - } - if v, ok := settings[SettingKeyOIDCConnectTokenAuthMethod]; ok && strings.TrimSpace(v) != "" { - result.OIDCConnectTokenAuthMethod = strings.ToLower(strings.TrimSpace(v)) - } else { - result.OIDCConnectTokenAuthMethod = strings.ToLower(strings.TrimSpace(oidcBase.TokenAuthMethod)) - } - if raw, ok := settings[SettingKeyOIDCConnectUsePKCE]; ok { - result.OIDCConnectUsePKCE = raw == "true" - } else { - result.OIDCConnectUsePKCE = oidcUsePKCECompatibilityDefault(oidcBase) - } - if raw, ok := settings[SettingKeyOIDCConnectValidateIDToken]; ok { - result.OIDCConnectValidateIDToken = raw == "true" - } else { - result.OIDCConnectValidateIDToken = oidcValidateIDTokenCompatibilityDefault(oidcBase) - } - if v, ok := settings[SettingKeyOIDCConnectAllowedSigningAlgs]; ok && strings.TrimSpace(v) != "" { - result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(v) - } else { - result.OIDCConnectAllowedSigningAlgs = strings.TrimSpace(oidcBase.AllowedSigningAlgs) - } - clockSkewSet := false - if raw, ok := settings[SettingKeyOIDCConnectClockSkewSeconds]; ok && strings.TrimSpace(raw) != "" { - if parsed, err := strconv.Atoi(strings.TrimSpace(raw)); err == nil { - result.OIDCConnectClockSkewSeconds = parsed - clockSkewSet = true - } - } - if !clockSkewSet { - result.OIDCConnectClockSkewSeconds = oidcBase.ClockSkewSeconds - } - if !clockSkewSet && result.OIDCConnectClockSkewSeconds == 0 { - result.OIDCConnectClockSkewSeconds = 120 - } - if raw, ok := settings[SettingKeyOIDCConnectRequireEmailVerified]; ok { - result.OIDCConnectRequireEmailVerified = raw == "true" - } else { - result.OIDCConnectRequireEmailVerified = oidcBase.RequireEmailVerified - } - if v, ok := settings[SettingKeyOIDCConnectUserInfoEmailPath]; ok { - result.OIDCConnectUserInfoEmailPath = strings.TrimSpace(v) - } else { - result.OIDCConnectUserInfoEmailPath = strings.TrimSpace(oidcBase.UserInfoEmailPath) - } - if v, ok := 
settings[SettingKeyOIDCConnectUserInfoIDPath]; ok { - result.OIDCConnectUserInfoIDPath = strings.TrimSpace(v) - } else { - result.OIDCConnectUserInfoIDPath = strings.TrimSpace(oidcBase.UserInfoIDPath) - } - if v, ok := settings[SettingKeyOIDCConnectUserInfoUsernamePath]; ok { - result.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(v) - } else { - result.OIDCConnectUserInfoUsernamePath = strings.TrimSpace(oidcBase.UserInfoUsernamePath) - } - result.OIDCConnectClientSecret = strings.TrimSpace(settings[SettingKeyOIDCConnectClientSecret]) - if result.OIDCConnectClientSecret == "" { - result.OIDCConnectClientSecret = strings.TrimSpace(oidcBase.ClientSecret) - } - result.OIDCConnectClientSecretConfigured = result.OIDCConnectClientSecret != "" - - gitHubEffective := s.effectiveEmailOAuthConfig(settings, "github") - result.GitHubOAuthEnabled = gitHubEffective.Enabled - result.GitHubOAuthClientID = strings.TrimSpace(gitHubEffective.ClientID) - result.GitHubOAuthClientSecret = strings.TrimSpace(gitHubEffective.ClientSecret) - result.GitHubOAuthClientSecretConfigured = result.GitHubOAuthClientSecret != "" - result.GitHubOAuthRedirectURL = strings.TrimSpace(gitHubEffective.RedirectURL) - result.GitHubOAuthFrontendRedirectURL = strings.TrimSpace(gitHubEffective.FrontendRedirectURL) - - googleEffective := s.effectiveEmailOAuthConfig(settings, "google") - result.GoogleOAuthEnabled = googleEffective.Enabled - result.GoogleOAuthClientID = strings.TrimSpace(googleEffective.ClientID) - result.GoogleOAuthClientSecret = strings.TrimSpace(googleEffective.ClientSecret) - result.GoogleOAuthClientSecretConfigured = result.GoogleOAuthClientSecret != "" - result.GoogleOAuthRedirectURL = strings.TrimSpace(googleEffective.RedirectURL) - result.GoogleOAuthFrontendRedirectURL = strings.TrimSpace(googleEffective.FrontendRedirectURL) - - // WeChat Connect 设置: - // - 优先读取 DB 系统设置 - // - 缺失时回退到 config/env,保持升级兼容 - weChatEffective := s.effectiveWeChatConnectOAuthConfig(settings) - 
result.WeChatConnectEnabled = weChatEffective.Enabled - result.WeChatConnectAppID = weChatEffective.LegacyAppID - result.WeChatConnectAppSecret = weChatEffective.LegacyAppSecret - result.WeChatConnectAppSecretConfigured = weChatEffective.LegacyAppSecret != "" - result.WeChatConnectOpenAppID = weChatEffective.OpenAppID - result.WeChatConnectOpenAppSecret = weChatEffective.OpenAppSecret - result.WeChatConnectOpenAppSecretConfigured = weChatEffective.OpenAppSecret != "" - result.WeChatConnectMPAppID = weChatEffective.MPAppID - result.WeChatConnectMPAppSecret = weChatEffective.MPAppSecret - result.WeChatConnectMPAppSecretConfigured = weChatEffective.MPAppSecret != "" - result.WeChatConnectMobileAppID = weChatEffective.MobileAppID - result.WeChatConnectMobileAppSecret = weChatEffective.MobileAppSecret - result.WeChatConnectMobileAppSecretConfigured = weChatEffective.MobileAppSecret != "" - result.WeChatConnectOpenEnabled = weChatEffective.OpenEnabled - result.WeChatConnectMPEnabled = weChatEffective.MPEnabled - result.WeChatConnectMobileEnabled = weChatEffective.MobileEnabled - result.WeChatConnectMode = weChatEffective.Mode - result.WeChatConnectScopes = weChatEffective.Scopes - result.WeChatConnectRedirectURL = weChatEffective.RedirectURL - result.WeChatConnectFrontendRedirectURL = weChatEffective.FrontendRedirectURL - - // Model fallback settings - result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true" - result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022") - result.FallbackModelOpenAI = s.getStringOrDefault(settings, SettingKeyFallbackModelOpenAI, "gpt-4o") - result.FallbackModelGemini = s.getStringOrDefault(settings, SettingKeyFallbackModelGemini, "gemini-2.5-pro") - result.FallbackModelAntigravity = s.getStringOrDefault(settings, SettingKeyFallbackModelAntigravity, "gemini-2.5-pro") - - // Identity patch settings (default: enabled, to preserve existing behavior) - if 
v, ok := settings[SettingKeyEnableIdentityPatch]; ok && v != "" { - result.EnableIdentityPatch = v == "true" - } else { - result.EnableIdentityPatch = true - } - result.IdentityPatchPrompt = settings[SettingKeyIdentityPatchPrompt] - - // Ops monitoring settings (default: enabled, fail-open) - result.OpsMonitoringEnabled = !isFalseSettingValue(settings[SettingKeyOpsMonitoringEnabled]) - result.OpsRealtimeMonitoringEnabled = !isFalseSettingValue(settings[SettingKeyOpsRealtimeMonitoringEnabled]) - result.OpsQueryModeDefault = string(ParseOpsQueryMode(settings[SettingKeyOpsQueryModeDefault])) - result.OpsMetricsIntervalSeconds = 60 - if raw := strings.TrimSpace(settings[SettingKeyOpsMetricsIntervalSeconds]); raw != "" { - if v, err := strconv.Atoi(raw); err == nil { - if v < 60 { - v = 60 - } - if v > 3600 { - v = 3600 - } - result.OpsMetricsIntervalSeconds = v - } - } - - // Channel monitor feature (default: enabled, 60s) - result.ChannelMonitorEnabled = !isFalseSettingValue(settings[SettingKeyChannelMonitorEnabled]) - result.ChannelMonitorDefaultIntervalSeconds = parseChannelMonitorInterval( - settings[SettingKeyChannelMonitorDefaultIntervalSeconds], - ) - - // Available channels feature (default: disabled; strict true) - result.AvailableChannelsEnabled = settings[SettingKeyAvailableChannelsEnabled] == "true" - - // Affiliate (邀请返利) feature (default: disabled; strict true) - result.AffiliateEnabled = settings[SettingKeyAffiliateEnabled] == "true" - - // 风控中心功能(默认关闭,严格 true 才启用) - result.RiskControlEnabled = settings[SettingKeyRiskControlEnabled] == "true" - - // Claude Code version check - result.MinClaudeCodeVersion = settings[SettingKeyMinClaudeCodeVersion] - result.MaxClaudeCodeVersion = settings[SettingKeyMaxClaudeCodeVersion] - - // 分组隔离 - result.AllowUngroupedKeyScheduling = settings[SettingKeyAllowUngroupedKeyScheduling] == "true" - - // Gateway forwarding behavior (defaults: fingerprint=true, metadata_passthrough=false, cch_signing=false) - if v, ok := 
settings[SettingKeyEnableFingerprintUnification]; ok && v != "" { - result.EnableFingerprintUnification = v == "true" - } else { - result.EnableFingerprintUnification = true // default: enabled (current behavior) - } - result.EnableMetadataPassthrough = settings[SettingKeyEnableMetadataPassthrough] == "true" - result.EnableCCHSigning = settings[SettingKeyEnableCCHSigning] == "true" - result.EnableAnthropicCacheTTL1hInjection = settings[SettingKeyEnableAnthropicCacheTTL1hInjection] == "true" - if v, ok := settings[SettingKeyRewriteMessageCacheControl]; ok && v != "" { - result.RewriteMessageCacheControl = v == "true" - } else { - result.RewriteMessageCacheControl = s.defaultRewriteMessageCacheControl() - } - result.AntigravityUserAgentVersion = antigravity.NormalizeUserAgentVersion(settings[SettingKeyAntigravityUserAgentVersion]) - - // Web search emulation: quick enabled check from the JSON config - if raw := settings[SettingKeyWebSearchEmulationConfig]; raw != "" { - var wsCfg WebSearchEmulationConfig - if err := json.Unmarshal([]byte(raw), &wsCfg); err == nil { - result.WebSearchEmulationEnabled = wsCfg.Enabled && len(wsCfg.Providers) > 0 - } - } - result.PaymentVisibleMethodAlipaySource = NormalizeVisibleMethodSource("alipay", settings[SettingPaymentVisibleMethodAlipaySource]) - result.PaymentVisibleMethodWxpaySource = NormalizeVisibleMethodSource("wxpay", settings[SettingPaymentVisibleMethodWxpaySource]) - result.PaymentVisibleMethodAlipayEnabled = settings[SettingPaymentVisibleMethodAlipayEnabled] == "true" - result.PaymentVisibleMethodWxpayEnabled = settings[SettingPaymentVisibleMethodWxpayEnabled] == "true" - result.OpenAIAdvancedSchedulerEnabled = settings[openAIAdvancedSchedulerSettingKey] == "true" - - // Balance low notification - result.BalanceLowNotifyEnabled = settings[SettingKeyBalanceLowNotifyEnabled] == "true" - if v, err := strconv.ParseFloat(settings[SettingKeyBalanceLowNotifyThreshold], 64); err == nil && v >= 0 { - 
result.BalanceLowNotifyThreshold = v - } - result.BalanceLowNotifyRechargeURL = settings[SettingKeyBalanceLowNotifyRechargeURL] - - // Account quota notification - result.AccountQuotaNotifyEnabled = settings[SettingKeyAccountQuotaNotifyEnabled] == "true" - if raw := strings.TrimSpace(settings[SettingKeyAccountQuotaNotifyEmails]); raw != "" { - result.AccountQuotaNotifyEmails = ParseNotifyEmails(raw) - } - if result.AccountQuotaNotifyEmails == nil { - result.AccountQuotaNotifyEmails = []NotifyEmailEntry{} - } - - return result -} - func clampAffiliateRebateRate(value float64) float64 { if math.IsNaN(value) || math.IsInf(value, 0) { return AffiliateRebateRateDefault @@ -2913,142 +1519,6 @@ func normalizeVisibleMethodSettingSource(method, source string, enabled bool) (s return normalized, nil } -func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting { - raw = strings.TrimSpace(raw) - if raw == "" { - return nil - } - - var items []DefaultSubscriptionSetting - if err := json.Unmarshal([]byte(raw), &items); err != nil { - return nil - } - - normalized := make([]DefaultSubscriptionSetting, 0, len(items)) - for _, item := range items { - if item.GroupID <= 0 || item.ValidityDays <= 0 { - continue - } - if item.ValidityDays > MaxValidityDays { - item.ValidityDays = MaxValidityDays - } - normalized = append(normalized, item) - } - - return normalized -} - -func parseProviderDefaultGrantSettings(settings map[string]string, keys authSourceDefaultKeySet) ProviderDefaultGrantSettings { - result := ProviderDefaultGrantSettings{ - Balance: defaultAuthSourceBalance, - Concurrency: defaultAuthSourceConcurrency, - Subscriptions: []DefaultSubscriptionSetting{}, - GrantOnSignup: false, - GrantOnFirstBind: false, - } - - if v, err := strconv.ParseFloat(strings.TrimSpace(settings[keys.balance]), 64); err == nil { - result.Balance = v - } - if v, err := strconv.Atoi(strings.TrimSpace(settings[keys.concurrency])); err == nil { - result.Concurrency = v - } - if items := 
parseDefaultSubscriptions(settings[keys.subscriptions]); items != nil { - result.Subscriptions = items - } - if raw, ok := settings[keys.grantOnSignup]; ok { - result.GrantOnSignup = raw == "true" - } - if raw, ok := settings[keys.grantOnFirstBind]; ok { - result.GrantOnFirstBind = raw == "true" - } - - return result -} - -func writeProviderDefaultGrantUpdates(updates map[string]string, keys authSourceDefaultKeySet, settings ProviderDefaultGrantSettings) { - updates[keys.balance] = strconv.FormatFloat(settings.Balance, 'f', 8, 64) - updates[keys.concurrency] = strconv.Itoa(settings.Concurrency) - - subscriptions := settings.Subscriptions - if subscriptions == nil { - subscriptions = []DefaultSubscriptionSetting{} - } - raw, err := json.Marshal(subscriptions) - if err != nil { - raw = []byte("[]") - } - updates[keys.subscriptions] = string(raw) - updates[keys.grantOnSignup] = strconv.FormatBool(settings.GrantOnSignup) - updates[keys.grantOnFirstBind] = strconv.FormatBool(settings.GrantOnFirstBind) -} - -func mergeProviderDefaultGrantSettings(globalDefaults ProviderDefaultGrantSettings, providerDefaults ProviderDefaultGrantSettings) ProviderDefaultGrantSettings { - result := ProviderDefaultGrantSettings{ - Balance: globalDefaults.Balance, - Concurrency: globalDefaults.Concurrency, - Subscriptions: append([]DefaultSubscriptionSetting(nil), globalDefaults.Subscriptions...), - GrantOnSignup: providerDefaults.GrantOnSignup, - GrantOnFirstBind: providerDefaults.GrantOnFirstBind, - } - - if providerDefaults.Balance != defaultAuthSourceBalance { - result.Balance = providerDefaults.Balance - } - if providerDefaults.Concurrency > 0 && providerDefaults.Concurrency != defaultAuthSourceConcurrency { - result.Concurrency = providerDefaults.Concurrency - } - if len(providerDefaults.Subscriptions) > 0 { - result.Subscriptions = append([]DefaultSubscriptionSetting(nil), providerDefaults.Subscriptions...) 
- } - - return result -} - -func parseTablePreferences(defaultPageSizeRaw, optionsRaw string) (int, []int) { - defaultPageSize := 20 - if v, err := strconv.Atoi(strings.TrimSpace(defaultPageSizeRaw)); err == nil { - defaultPageSize = v - } - - var options []int - if strings.TrimSpace(optionsRaw) != "" { - _ = json.Unmarshal([]byte(optionsRaw), &options) - } - - return normalizeTablePreferences(defaultPageSize, options) -} - -func normalizeTablePreferences(defaultPageSize int, options []int) (int, []int) { - const minPageSize = 5 - const maxPageSize = 1000 - const fallbackPageSize = 20 - - seen := make(map[int]struct{}, len(options)) - normalizedOptions := make([]int, 0, len(options)) - for _, option := range options { - if option < minPageSize || option > maxPageSize { - continue - } - if _, ok := seen[option]; ok { - continue - } - seen[option] = struct{}{} - normalizedOptions = append(normalizedOptions, option) - } - sort.Ints(normalizedOptions) - - if defaultPageSize < minPageSize || defaultPageSize > maxPageSize { - defaultPageSize = fallbackPageSize - } - - if len(normalizedOptions) == 0 { - normalizedOptions = []int{10, 20, 50} - } - - return defaultPageSize, normalizedOptions -} - // getStringOrDefault 获取字符串值或默认值 func (s *SettingService) getStringOrDefault(settings map[string]string, key, defaultValue string) string { if value, ok := settings[key]; ok && value != "" { diff --git a/backend/internal/service/setting_validation.go b/backend/internal/service/setting_validation.go new file mode 100644 index 00000000000..3ae9fec60d9 --- /dev/null +++ b/backend/internal/service/setting_validation.go @@ -0,0 +1,291 @@ +package service + +import ( + "context" + "errors" + "fmt" + "strconv" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/config" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +func normalizeLoginAgreementMode(raw string) string { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "checkbox": + return "checkbox" + 
default: + return defaultLoginAgreementMode + } +} + +func defaultLoginAgreementDocuments() []LoginAgreementDocument { + return []LoginAgreementDocument{ + { + ID: "terms", + Title: "服务条款", + ContentMD: "", + }, + { + ID: "usage-policy", + Title: "使用政策", + ContentMD: "", + }, + { + ID: "supported-regions", + Title: "支持的国家和地区", + ContentMD: "", + }, + { + ID: "service-specific-terms", + Title: "服务特定条款", + ContentMD: "", + }, + } +} + +func normalizeWeChatConnectModeSetting(raw string) string { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "mp": + return "mp" + case "mobile": + return "mobile" + default: + return "open" + } +} + +func parseWeChatConnectCapabilitySettings(settings map[string]string, enabled bool, mode string) (bool, bool, bool) { + mode = normalizeWeChatConnectModeSetting(mode) + rawOpen, hasOpen := settings[SettingKeyWeChatConnectOpenEnabled] + rawMP, hasMP := settings[SettingKeyWeChatConnectMPEnabled] + rawMobile, hasMobile := settings[SettingKeyWeChatConnectMobileEnabled] + openConfigured := hasOpen && strings.TrimSpace(rawOpen) != "" + mpConfigured := hasMP && strings.TrimSpace(rawMP) != "" + mobileConfigured := hasMobile && strings.TrimSpace(rawMobile) != "" + + if openConfigured || mpConfigured || mobileConfigured { + openEnabled := strings.TrimSpace(rawOpen) == "true" + mpEnabled := strings.TrimSpace(rawMP) == "true" + mobileEnabled := strings.TrimSpace(rawMobile) == "true" + return openEnabled, mpEnabled, mobileEnabled + } + + if !enabled { + return false, false, false + } + if mode == "mp" { + return false, true, false + } + if mode == "mobile" { + return false, false, true + } + return true, false, false +} + +func normalizeWeChatConnectStoredMode(openEnabled, mpEnabled, mobileEnabled bool, mode string) string { + mode = normalizeWeChatConnectModeSetting(mode) + switch mode { + case "open": + if openEnabled { + return "open" + } + case "mp": + if mpEnabled { + return "mp" + } + case "mobile": + if mobileEnabled { + return 
"mobile" + } + } + switch { + case openEnabled: + return "open" + case mpEnabled: + return "mp" + case mobileEnabled: + return "mobile" + default: + return mode + } +} + +func mergeWeChatConnectCapabilitySettings(settings map[string]string, base config.WeChatConnectConfig, enabled bool, mode string) (bool, bool, bool) { + mode = normalizeWeChatConnectModeSetting(firstNonEmpty(mode, base.Mode)) + rawOpen, hasOpen := settings[SettingKeyWeChatConnectOpenEnabled] + rawMP, hasMP := settings[SettingKeyWeChatConnectMPEnabled] + rawMobile, hasMobile := settings[SettingKeyWeChatConnectMobileEnabled] + openConfigured := hasOpen && strings.TrimSpace(rawOpen) != "" + mpConfigured := hasMP && strings.TrimSpace(rawMP) != "" + mobileConfigured := hasMobile && strings.TrimSpace(rawMobile) != "" + + if openConfigured || mpConfigured || mobileConfigured { + openEnabled := strings.TrimSpace(rawOpen) == "true" + mpEnabled := strings.TrimSpace(rawMP) == "true" + mobileEnabled := strings.TrimSpace(rawMobile) == "true" + _, enabledConfigured := settings[SettingKeyWeChatConnectEnabled] + if !enabledConfigured && + enabled && + !openEnabled && + !mpEnabled && + !mobileEnabled && + (base.OpenEnabled || base.MPEnabled || base.MobileEnabled) { + return base.OpenEnabled, base.MPEnabled, base.MobileEnabled + } + return openEnabled, mpEnabled, mobileEnabled + } + if !enabled { + return false, false, false + } + if base.OpenEnabled || base.MPEnabled || base.MobileEnabled { + return base.OpenEnabled, base.MPEnabled, base.MobileEnabled + } + return parseWeChatConnectCapabilitySettings(settings, enabled, mode) +} + +func (s *SettingService) parseWeChatConnectOAuthConfig(settings map[string]string) (WeChatConnectOAuthConfig, error) { + cfg := s.effectiveWeChatConnectOAuthConfig(settings) + + if !cfg.Enabled || (!cfg.OpenEnabled && !cfg.MPEnabled) { + return WeChatConnectOAuthConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "wechat oauth is disabled") + } + if cfg.OpenEnabled { + if 
cfg.AppIDForMode("open") == "" { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth pc app id not configured") + } + if cfg.AppSecretForMode("open") == "" { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth pc app secret not configured") + } + } + if cfg.MPEnabled { + if cfg.AppIDForMode("mp") == "" { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth official account app id not configured") + } + if cfg.AppSecretForMode("mp") == "" { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth official account app secret not configured") + } + } + if cfg.MobileEnabled { + if cfg.AppIDForMode("mobile") == "" { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth mobile app id not configured") + } + if cfg.AppSecretForMode("mobile") == "" { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth mobile app secret not configured") + } + } + if v := strings.TrimSpace(cfg.RedirectURL); v != "" { + if err := config.ValidateAbsoluteHTTPURL(v); err != nil { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth redirect url invalid") + } + } + if err := config.ValidateFrontendRedirectURL(cfg.FrontendRedirectURL); err != nil { + return WeChatConnectOAuthConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "wechat oauth frontend redirect url invalid") + } + return cfg, nil +} + +func (s *SettingService) weChatOAuthCapabilitiesFromSettings(settings map[string]string) (bool, bool, bool, bool) { + cfg := s.effectiveWeChatConnectOAuthConfig(settings) + if !cfg.Enabled { + return false, false, false, false + } + + openReady := cfg.OpenEnabled && cfg.AppIDForMode("open") != "" && cfg.AppSecretForMode("open") != "" + mpReady := cfg.MPEnabled 
&& cfg.AppIDForMode("mp") != "" && cfg.AppSecretForMode("mp") != "" + mobileReady := cfg.MobileEnabled && cfg.AppIDForMode("mobile") != "" && cfg.AppSecretForMode("mobile") != "" + + return openReady || mpReady, openReady, mpReady, mobileReady +} + +func (s *SettingService) emailOAuthPublicEnabled(settings map[string]string, provider string) bool { + cfg := s.effectiveEmailOAuthConfig(settings, provider) + return cfg.Enabled && strings.TrimSpace(cfg.ClientID) != "" && strings.TrimSpace(cfg.ClientSecret) != "" +} + +func (s *SettingService) effectiveEmailOAuthConfig(settings map[string]string, provider string) config.EmailOAuthProviderConfig { + cfg := s.emailOAuthBaseConfig(provider) + switch strings.ToLower(strings.TrimSpace(provider)) { + case "github": + if raw, ok := settings[SettingKeyGitHubOAuthEnabled]; ok { + cfg.Enabled = raw == "true" + } + cfg.ClientID = firstNonEmpty(settings[SettingKeyGitHubOAuthClientID], cfg.ClientID) + cfg.ClientSecret = firstNonEmpty(settings[SettingKeyGitHubOAuthClientSecret], cfg.ClientSecret) + cfg.RedirectURL = firstNonEmpty(settings[SettingKeyGitHubOAuthRedirectURL], cfg.RedirectURL) + cfg.FrontendRedirectURL = firstNonEmpty(settings[SettingKeyGitHubOAuthFrontendRedirectURL], cfg.FrontendRedirectURL, defaultGitHubOAuthFrontend) + case "google": + if raw, ok := settings[SettingKeyGoogleOAuthEnabled]; ok { + cfg.Enabled = raw == "true" + } + cfg.ClientID = firstNonEmpty(settings[SettingKeyGoogleOAuthClientID], cfg.ClientID) + cfg.ClientSecret = firstNonEmpty(settings[SettingKeyGoogleOAuthClientSecret], cfg.ClientSecret) + cfg.RedirectURL = firstNonEmpty(settings[SettingKeyGoogleOAuthRedirectURL], cfg.RedirectURL) + cfg.FrontendRedirectURL = firstNonEmpty(settings[SettingKeyGoogleOAuthFrontendRedirectURL], cfg.FrontendRedirectURL, defaultGoogleOAuthFrontend) + } + return cfg +} + +func oidcUsePKCECompatibilityDefault(base config.OIDCConnectConfig) bool { + if base.UsePKCEExplicit { + return base.UsePKCE + } + return true +} + 
+func oidcValidateIDTokenCompatibilityDefault(base config.OIDCConnectConfig) bool { + if base.ValidateIDTokenExplicit { + return base.ValidateIDToken + } + return true +} + +func oidcCompatibilityWriteDefault(base config.OIDCConnectConfig, configured bool, raw string, explicit bool, explicitValue bool) bool { + if configured { + return strings.TrimSpace(raw) == "true" + } + if explicit { + return explicitValue + } + return false +} + +func (s *SettingService) validateDefaultSubscriptionGroups(ctx context.Context, items []DefaultSubscriptionSetting) error { + if len(items) == 0 { + return nil + } + + checked := make(map[int64]struct{}, len(items)) + for _, item := range items { + if item.GroupID <= 0 { + continue + } + if _, ok := checked[item.GroupID]; ok { + return ErrDefaultSubGroupDuplicate.WithMetadata(map[string]string{ + "group_id": strconv.FormatInt(item.GroupID, 10), + }) + } + checked[item.GroupID] = struct{}{} + if s.defaultSubGroupReader == nil { + continue + } + + group, err := s.defaultSubGroupReader.GetByID(ctx, item.GroupID) + if err != nil { + if errors.Is(err, ErrGroupNotFound) { + return ErrDefaultSubGroupInvalid.WithMetadata(map[string]string{ + "group_id": strconv.FormatInt(item.GroupID, 10), + }) + } + return fmt.Errorf("get default subscription group %d: %w", item.GroupID, err) + } + if !group.IsSubscriptionType() { + return ErrDefaultSubGroupInvalid.WithMetadata(map[string]string{ + "group_id": strconv.FormatInt(item.GroupID, 10), + }) + } + } + + return nil +} diff --git a/backend/internal/service/sticky_session_test.go b/backend/internal/service/sticky_session_test.go index 11ace7bd8e0..02369b19ef5 100644 --- a/backend/internal/service/sticky_session_test.go +++ b/backend/internal/service/sticky_session_test.go @@ -122,8 +122,8 @@ func TestShouldClearStickySession(t *testing.T) { { name: "overloaded account", account: &Account{ - Status: StatusActive, - Schedulable: true, + Status: StatusActive, + Schedulable: true, OverloadUntil: &future, 
}, requestedModel: "", diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index f84e6f0ab06..208a05dbdd9 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -8,8 +8,6 @@ import ( "encoding/base64" "encoding/hex" "fmt" - infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" - "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "image" "image/color" stddraw "image/draw" @@ -24,6 +22,9 @@ import ( "sync" "time" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + xdraw "golang.org/x/image/draw" "golang.org/x/sync/singleflight" ) diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go index 775dd6028df..0fef0f87d8c 100644 --- a/backend/internal/service/user_service_test.go +++ b/backend/internal/service/user_service_test.go @@ -202,7 +202,7 @@ func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int func (m *mockUserRepo) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } func (m *mockUserRepo) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } -func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil } +func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil } func (m *mockUserRepo) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) { out := make([]UserAuthIdentityRecord, len(m.identities)) copy(out, m.identities)