diff --git a/model/log.go b/model/log.go index 9203ff28be1..7e315aeb89e 100644 --- a/model/log.go +++ b/model/log.go @@ -253,17 +253,23 @@ func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) } type RecordTaskBillingLogParams struct { - UserId int - LogType int - Content string - ChannelId int - ModelName string - Quota int - TokenId int - Group string - Other map[string]interface{} + UserId int + LogType int + Content string + ChannelId int + ModelName string + Quota int + PromptTokens int + CompletionTokens int + TokenId int + Group string + Other map[string]interface{} } +// RecordTaskBillingLog 仅写入 logs 表的一条 task 计费记录(消费 / 退款)。 +// 注意:本函数 *不会* 自动调整 quota_data 统计 —— 因为 task 的"部分退款 + 仍需补全 token" +// 等场景下,金额方向(refund)和 token 方向(增加)并不对称,无法只靠 LogType 推断符号。 +// quota_data 的同步由上层 service 显式调用 LogQuotaDataAdjust 完成。 func RecordTaskBillingLog(params RecordTaskBillingLogParams) { if params.LogType == LogTypeConsume && !common.LogConsumeEnabled { return @@ -276,18 +282,20 @@ func RecordTaskBillingLog(params RecordTaskBillingLogParams) { } } log := &Log{ - UserId: params.UserId, - Username: username, - CreatedAt: common.GetTimestamp(), - Type: params.LogType, - Content: params.Content, - TokenName: tokenName, - ModelName: params.ModelName, - Quota: params.Quota, - ChannelId: params.ChannelId, - TokenId: params.TokenId, - Group: params.Group, - Other: common.MapToJsonStr(params.Other), + UserId: params.UserId, + Username: username, + CreatedAt: common.GetTimestamp(), + Type: params.LogType, + Content: params.Content, + TokenName: tokenName, + ModelName: params.ModelName, + Quota: params.Quota, + PromptTokens: params.PromptTokens, + CompletionTokens: params.CompletionTokens, + ChannelId: params.ChannelId, + TokenId: params.TokenId, + Group: params.Group, + Other: common.MapToJsonStr(params.Other), } err := LOG_DB.Create(log).Error if err != nil { diff --git a/model/usedata.go b/model/usedata.go index f0ea055ae39..26b0176a8a6 100644 --- a/model/usedata.go +++ b/model/usedata.go @@ -34,11 +34,11 @@ func UpdateQuotaData() { var CacheQuotaData = make(map[string]*QuotaData) var CacheQuotaDataLock = sync.Mutex{} -func logQuotaDataCache(userId int, username string, modelName string, quota int, createdAt int64, tokenUsed int) { +func logQuotaDataCache(userId int, username string, modelName string, quota int, createdAt int64, tokenUsed int, countDelta int) { key := fmt.Sprintf("%d-%s-%s-%d", userId, username, modelName, createdAt) quotaData, ok := CacheQuotaData[key] if ok { - quotaData.Count += 1 + quotaData.Count += countDelta quotaData.Quota += quota quotaData.TokenUsed += tokenUsed } else { @@ -47,7 +47,7 @@ func logQuotaDataCache(userId int, username string, modelName string, quota int, Username: username, ModelName: modelName, CreatedAt: createdAt, - Count: 1, + Count: countDelta, Quota: quota, TokenUsed: tokenUsed, } @@ -61,7 +61,22 @@ func LogQuotaData(userId int, username string, modelName string, quota int, crea CacheQuotaDataLock.Lock() defer CacheQuotaDataLock.Unlock() - logQuotaDataCache(userId, username, modelName, quota, createdAt, tokenUsed) + logQuotaDataCache(userId, username, modelName, quota, createdAt, tokenUsed, 1) +} + +// LogQuotaDataAdjust 用于异步任务的退款/补扣场景:仅调整 quota / token_used,不变化 count。 +// quotaDelta、tokenDelta 支持正负:退款传负值,补扣传正值。 +// - 钱不动只补 token 统计(如 token 重算 delta=0 但 totalTokens>0)也可以直接调用。 +// - 调用前置开关由调用方负责(一般为 common.DataExportEnabled)。 +func LogQuotaDataAdjust(userId int, username string, modelName string, quotaDelta int, createdAt int64, tokenDelta int) { + if quotaDelta == 0 && tokenDelta == 0 { + return + } + createdAt = createdAt - (createdAt % 3600) + + CacheQuotaDataLock.Lock() + defer CacheQuotaDataLock.Unlock() + logQuotaDataCache(userId, username, modelName, quotaDelta, createdAt, tokenDelta, 0) } func SaveQuotaDataCache() { diff --git a/model/user.go b/model/user.go index 79e63e8fd59..ab6a6bd36d7 100644 --- a/model/user.go +++ b/model/user.go @@ -960,6 +960,24 @@ func UpdateUserUsedQuotaAndRequestCount(id int, quota int) { updateUserUsedQuotaAndRequestCount(id, quota, 1) } +// UpdateUserUsedQuotaDelta 仅调整 used_quota(不动 request_count),支持负值。 +// 用于异步任务的退款 / 补扣场景,避免 request_count 被错误累加: +// - 退款:传入负值,used_quota 回退; +// - 补扣:传入正值,used_quota 增加。 +// +// 与 UpdateUserUsedQuotaAndRequestCount 走同一批量通道(BatchUpdateTypeUsedQuota), +// 底层 SQL 是 `used_quota + ?`,天然支持正负。 +func UpdateUserUsedQuotaDelta(id int, delta int) { + if delta == 0 { + return + } + if common.BatchUpdateEnabled { + addNewRecord(BatchUpdateTypeUsedQuota, id, delta) + return + } + updateUserUsedQuota(id, delta) +} + func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) { err := DB.Model(&User{}).Where("id = ?", id).Updates( map[string]interface{}{ diff --git a/relay/channel/task/doubao/adaptor.go b/relay/channel/task/doubao/adaptor.go index a6dabb5f108..7ccf2b4e99e 100644 --- a/relay/channel/task/doubao/adaptor.go +++ b/relay/channel/task/doubao/adaptor.go @@ -114,9 +114,42 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { } // ValidateRequestAndSetAction parses body, validates fields and sets default action. +// 额外职责:对登记了二维定价的模型(seedance 2.0 系列),前置拦截上游 +// "暂不支持"的档位组合(如 pro 的 1080p+含视频、fast 的 1080p), +// 避免预扣后还要走上游失败 + 退款流程。 func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { - // Accept only POST /v1/video/generations as "generate" action. - return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate) + if err := relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate); err != nil { + return err + } + + req, err := relaycommon.GetTaskRequest(c) + if err != nil { + return nil + } + modelName := req.Model + if modelName == "" { + return nil + } + if _, supported := ResolveBillingRatios(modelName, extractResolution(req.Metadata), hasVideoInMetadata(req.Metadata)); !supported { + return service.TaskErrorWrapperLocal( + fmt.Errorf("model %s does not support the requested resolution / video-input combination", modelName), + "unsupported_resolution_combination", + http.StatusBadRequest, + ) + } + return nil +} + +// extractResolution 从请求 metadata 中读取 resolution 字段;未显式指定时返回空串, +// 由 normalizeResolution 在定价查询时回退到豆包默认分辨率(720p)。 +func extractResolution(metadata map[string]interface{}) string { + if metadata == nil { + return "" + } + if v, ok := metadata["resolution"].(string); ok { + return v + } + return "" } // BuildRequestURL constructs the upstream URL. @@ -132,18 +165,20 @@ func (a *TaskAdaptor) BuildRequestHeader(_ *gin.Context, req *http.Request, _ *r return nil } -// EstimateBilling 检测请求 metadata 中是否包含视频输入,返回视频折扣 OtherRatio。 +// EstimateBilling 根据请求的输出分辨率与是否包含视频输入, +// 返回相对基线档的二维 OtherRatios(resolution / video_input)。 +// "暂不支持"的组合已经在 ValidateRequestAndSetAction 阶段被拦截, +// 这里只会命中合法档位。 func (a *TaskAdaptor) EstimateBilling(c *gin.Context, info *relaycommon.RelayInfo) map[string]float64 { req, err := relaycommon.GetTaskRequest(c) if err != nil { return nil } - if hasVideoInMetadata(req.Metadata) { - if ratio, ok := GetVideoInputRatio(info.OriginModelName); ok { - return map[string]float64{"video_input": ratio} - } + ratios, supported := ResolveBillingRatios(info.OriginModelName, extractResolution(req.Metadata), hasVideoInMetadata(req.Metadata)) + if !supported { + return nil } - return nil + return ratios } // hasVideoInMetadata 直接检查 metadata 的 content 数组是否包含 video_url 条目, diff --git a/relay/channel/task/doubao/constants.go b/relay/channel/task/doubao/constants.go index d65773d3068..ccd4a025812 100644 --- a/relay/channel/task/doubao/constants.go +++ b/relay/channel/task/doubao/constants.go @@ -1,5 +1,7 @@ package doubao +import "strings" + var ModelList = []string{ "doubao-seedance-1-0-pro-250528", "doubao-seedance-1-0-lite-t2v", @@ -11,15 +13,123 @@ var ModelList = []string{ var ChannelName = "doubao-video" -// videoInputRatioMap 视频输入折扣比率(含视频单价 / 不含视频单价)。 -// 管理员应将 ModelRatio 设置为"不含视频"的较高费率, -// 系统在检测到视频输入时自动乘以此折扣。 -var videoInputRatioMap = map[string]float64{ - "doubao-seedance-2-0-260128": 28.0 / 46.0, // ~0.6087 - "doubao-seedance-2-0-fast-260128": 22.0 / 37.0, // ~0.5946 +// Resolution 档位的标准化常量。 +const ( + resolution480p = "480p" + resolution720p = "720p" + resolution1080p = "1080p" +) + +// pricingKey 由 (模型名, 输出分辨率, 是否包含视频输入) 三元组唯一定位一档实际单价。 +type pricingKey struct { + Model string + Resolution string + WithVideoIn bool +} + +// pricingRatioMap 登记"该档位实际单价 / 该模型基线单价(480p-720p 不含视频档)"的比例。 +// 管理员应以每个模型的"基线档"ModelRatio 作为 1.0 基准, +// 当请求命中非基线档(1080p 或包含视频输入)时,本表给出对应倍率。 +// +// 豆包 2026-01-28 价格表: +// +// doubao-seedance-2-0-260128(基线 46 元 / 百万 tokens): +// +// 480p / 720p 不含视频:46 → 1.0 +// 480p / 720p 含视频 :28 → 28/46 +// 1080p 不含视频 :51 → 51/46 +// 1080p 含视频 :31 → 31/46 +// +// doubao-seedance-2-0-fast-260128(基线 37 元 / 百万 tokens): +// +// 480p / 720p 不含视频:37 → 1.0 +// 480p / 720p 含视频 :22 → 22/37 +// 1080p :暂不支持 +var pricingRatioMap = map[pricingKey]float64{ + {Model: "doubao-seedance-2-0-260128", Resolution: resolution480p, WithVideoIn: false}: 1.0, + {Model: "doubao-seedance-2-0-260128", Resolution: resolution720p, WithVideoIn: false}: 1.0, + {Model: "doubao-seedance-2-0-260128", Resolution: resolution480p, WithVideoIn: true}: 28.0 / 46.0, + {Model: "doubao-seedance-2-0-260128", Resolution: resolution720p, WithVideoIn: true}: 28.0 / 46.0, + {Model: "doubao-seedance-2-0-260128", Resolution: resolution1080p, WithVideoIn: false}: 51.0 / 46.0, + {Model: "doubao-seedance-2-0-260128", Resolution: resolution1080p, WithVideoIn: true}: 31.0 / 46.0, + + {Model: "doubao-seedance-2-0-fast-260128", Resolution: resolution480p, WithVideoIn: false}: 1.0, + {Model: "doubao-seedance-2-0-fast-260128", Resolution: resolution720p, WithVideoIn: false}: 1.0, + {Model: "doubao-seedance-2-0-fast-260128", Resolution: resolution480p, WithVideoIn: true}: 22.0 / 37.0, + {Model: "doubao-seedance-2-0-fast-260128", Resolution: resolution720p, WithVideoIn: true}: 22.0 / 37.0, } -func GetVideoInputRatio(modelName string) (float64, bool) { - r, ok := videoInputRatioMap[modelName] - return r, ok +// hasPricingConfig 指明哪些模型走二维倍率计算; +// 未登记的模型(如 seedance-1-x 系列)跳过倍率处理,保持原 ModelRatio 全额计费。 +var hasPricingConfig = map[string]bool{ + "doubao-seedance-2-0-260128": true, + "doubao-seedance-2-0-fast-260128": true, +} + +// normalizeResolution 将 "480P"/"1080p" 等大小写变体统一为标准形式。 +// 未指定或未知分辨率一律回退为 720p(豆包 seedance 默认输出分辨率)。 +func normalizeResolution(r string) string { + switch strings.ToLower(strings.TrimSpace(r)) { + case resolution480p: + return resolution480p + case resolution1080p: + return resolution1080p + case "", resolution720p: + return resolution720p + default: + return resolution720p + } +} + +// ResolveBillingRatios 根据模型、输出分辨率与是否含视频输入, +// 返回相对"基线档(480p/720p 不含视频)"的二维倍率 (resolution / video_input)。 +// +// - supported == false 表示当前组合被上游标注为"暂不支持",调用方应直接拒绝请求。 +// - 对未登记定价配置的模型返回 (nil, true),表示"不需要任何折扣,走基础 ModelRatio"。 +// +// 拆维语义(保证两维乘积严格等于该档位 totalRatio): +// +// - resolution = pricingRatioMap[(model, res, 不含视频)] // 横向:分辨率溢价(以"不含视频"档为基准) +// - video_input = pricingRatioMap[(model, res, 含视频)] / resolution +// // 纵向:同分辨率下"含视频"相对"不含视频"的折扣 +// +// 验证(pro 2.0 为例): +// +// 480p/720p, 不含视频:res=1.0, vid=1.0 → 1.0 +// 480p/720p, 含视频 :res=1.0, vid=28/46 → 28/46 ✓ +// 1080p, 不含视频:res=51/46, vid=1.0 → 51/46 ✓ +// 1080p, 含视频 :res=51/46, vid=(31/46)/(51/46)=31/51 → 31/46 ✓ +func ResolveBillingRatios(modelName, resolution string, withVideoIn bool) (map[string]float64, bool) { + if !hasPricingConfig[modelName] { + return nil, true + } + res := normalizeResolution(resolution) + totalRatio, ok := pricingRatioMap[pricingKey{Model: modelName, Resolution: res, WithVideoIn: withVideoIn}] + if !ok { + return nil, false + } + + // 该分辨率下"不含视频"档的绝对倍率(横向基准)。 + baseForRes, hasBase := pricingRatioMap[pricingKey{Model: modelName, Resolution: res, WithVideoIn: false}] + if !hasBase { + // 理论上不会发生:有含视频档就必有不含视频档(本体定价表保持这条约束)。 + // 兜底按 totalRatio 单键上报,避免返回错误。 + ratios := map[string]float64{} + if totalRatio != 1.0 { + ratios["pricing_tier"] = totalRatio + } + return ratios, true + } + + ratios := map[string]float64{} + if baseForRes != 1.0 { + ratios["resolution"] = baseForRes + } + if withVideoIn && baseForRes > 0 { + videoRatio := totalRatio / baseForRes + if videoRatio != 1.0 { + ratios["video_input"] = videoRatio + } + } + return ratios, true } diff --git a/service/task_billing.go b/service/task_billing.go index 6cf7a965c8e..2be949a3414 100644 --- a/service/task_billing.go +++ b/service/task_billing.go @@ -147,44 +147,89 @@ func taskModelName(task *model.Task) string { return task.Properties.OriginModelName } +// taskAdjustQuotaData 同步调整 quota_data 统计(/api/data/ 数据来源)。 +// quotaDelta / tokenDelta 均为有符号增量: +// - 失败退款:quotaDelta=-quota, tokenDelta=0(原始记录 token 就是 0,对称回退即可) +// - 补扣 / 部分退款 + 真实 token:quotaDelta=±delta, tokenDelta=+totalTokens +// (token 之所以 *永远是 +*:原始 LogTaskConsumption 写入的 token 都是 0, +// 现在要补到 totalTokens,所以是单向增量,不随 quota 取号) +func taskAdjustQuotaData(task *model.Task, quotaDelta, tokenDelta int) { + if !common.DataExportEnabled { + return + } + if quotaDelta == 0 && tokenDelta == 0 { + return + } + username, _ := model.GetUsernameById(task.UserId, false) + model.LogQuotaDataAdjust(task.UserId, username, taskModelName(task), quotaDelta, common.GetTimestamp(), tokenDelta) +} + // RefundTaskQuota 统一的任务失败退款逻辑。 // 当异步任务失败时,将预扣的 quota 退还给用户(支持钱包和订阅),并退还令牌额度。 +// +// 守恒保证: +// - 用户表「总额度」(Quota + UsedQuota) 不变 — Quota +quota / UsedQuota -quota; +// - 令牌「剩余额度」回退(IncreaseTokenQuota 内部自动 RemainQuota+ / UsedQuota-); +// - 渠道用量同步回退; +// - quota_data 统计同步反向回退(token 字段保持 0 — 失败任务无实际 token 消耗)。 func RefundTaskQuota(ctx context.Context, task *model.Task, reason string) { quota := task.Quota if quota == 0 { return } - // 1. 退还资金来源(钱包或订阅) + // 1. 退还资金来源(钱包或订阅):影响 User.Quota if err := taskAdjustFunding(task, -quota); err != nil { logger.LogWarn(ctx, fmt.Sprintf("退还资金来源失败 task %s: %s", task.TaskID, err.Error())) return } - // 2. 退还令牌额度 + // 2. 退还令牌额度:影响 Token.RemainQuota / Token.UsedQuota taskAdjustTokenQuota(ctx, task, -quota) - // 3. 记录日志 + // 3. 回退用户「已用额度」与渠道用量,使总额度守恒 + model.UpdateUserUsedQuotaDelta(task.UserId, -quota) + if task.ChannelId > 0 { + model.UpdateChannelUsedQuota(task.ChannelId, -quota) + } + + // 4. 反向调整 /api/data/ 数据看板统计(quota -quota、tokens 不动) + taskAdjustQuotaData(task, -quota, 0) + + // 5. 记录退款日志 other := taskBillingOther(task) other["task_id"] = task.TaskID other["reason"] = reason model.RecordTaskBillingLog(model.RecordTaskBillingLogParams{ - UserId: task.UserId, - LogType: model.LogTypeRefund, - Content: "", - ChannelId: task.ChannelId, - ModelName: taskModelName(task), - Quota: quota, - TokenId: task.PrivateData.TokenId, - Group: task.Group, - Other: other, + UserId: task.UserId, + LogType: model.LogTypeRefund, + Content: "", + ChannelId: task.ChannelId, + ModelName: taskModelName(task), + Quota: quota, + PromptTokens: 0, + CompletionTokens: 0, + TokenId: task.PrivateData.TokenId, + Group: task.Group, + Other: other, }) } // RecalculateTaskQuota 通用的异步差额结算。 // actualQuota 是任务完成后的实际应扣额度,与预扣额度 (task.Quota) 做差额结算。 +// totalTokens 为本次任务实际消耗的 token 数(视频生成模型 input=0,故全部计入 CompletionTokens); +// +// 若上游未返回 token 用量则传 0,不会污染统计。 +// // reason 用于日志记录(例如 "token重算" 或 "adaptor调整")。 -func RecalculateTaskQuota(ctx context.Context, task *model.Task, actualQuota int, reason string) { +// +// 守恒保证(与 RefundTaskQuota 对称): +// - 补扣 (delta>0):User.Quota -delta / User.UsedQuota +delta;Channel.UsedQuota +delta; +// - 退还 (delta<0):User.Quota +|delta| / User.UsedQuota -|delta|;Channel.UsedQuota -|delta|; +// - 不增 request_count(这只是结算,不是新请求); +// - quota_data 统计 quota 跟随 delta 同向变化,token_used 单向 +totalTokens +// (原始 LogTaskConsumption 时 token 永远是 0,需要在终态补到 totalTokens)。 +func RecalculateTaskQuota(ctx context.Context, task *model.Task, actualQuota int, totalTokens int, reason string) { if actualQuota <= 0 { return } @@ -194,6 +239,11 @@ func RecalculateTaskQuota(ctx context.Context, task *model.Task, actualQuota int if quotaDelta == 0 { logger.LogInfo(ctx, fmt.Sprintf("任务 %s 预扣费准确(%s,%s)", task.TaskID, logger.LogQuota(actualQuota), reason)) + // 钱不动也要把 token 用量补到统计里:原始 LogTaskConsumption 时 token 一律是 0, + // 现在拿到了上游真实 totalTokens,补一行统计调整即可。 + if totalTokens > 0 { + taskAdjustQuotaData(task, 0, totalTokens) + } return } @@ -205,15 +255,24 @@ func RecalculateTaskQuota(ctx context.Context, task *model.Task, actualQuota int reason, )) - // 调整资金来源 + // 调整资金来源(钱包 or 订阅):影响 User.Quota if err := taskAdjustFunding(task, quotaDelta); err != nil { logger.LogError(ctx, fmt.Sprintf("差额结算资金调整失败 task %s: %s", task.TaskID, err.Error())) return } - // 调整令牌额度 + // 调整令牌额度:Token.RemainQuota / Token.UsedQuota 内部已对称 taskAdjustTokenQuota(ctx, task, quotaDelta) + // User.UsedQuota 与 Channel.UsedQuota 跟随 delta 同向变化(不动 request_count) + model.UpdateUserUsedQuotaDelta(task.UserId, quotaDelta) + if task.ChannelId > 0 { + model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta) + } + + // /api/data/ 统计:quota 同向变化、token_used 单向 +totalTokens + taskAdjustQuotaData(task, quotaDelta, totalTokens) + task.Quota = actualQuota var logType int @@ -221,8 +280,6 @@ func RecalculateTaskQuota(ctx context.Context, task *model.Task, actualQuota int if quotaDelta > 0 { logType = model.LogTypeConsume logQuota = quotaDelta - model.UpdateUserUsedQuotaAndRequestCount(task.UserId, quotaDelta) - model.UpdateChannelUsedQuota(task.ChannelId, quotaDelta) } else { logType = model.LogTypeRefund logQuota = -quotaDelta @@ -231,6 +288,9 @@ func RecalculateTaskQuota(ctx context.Context, task *model.Task, actualQuota int other["task_id"] = task.TaskID other["pre_consumed_quota"] = preConsumedQuota other["actual_quota"] = actualQuota + if totalTokens > 0 { + other["total_tokens"] = totalTokens + } model.RecordTaskBillingLog(model.RecordTaskBillingLogParams{ UserId: task.UserId, LogType: logType, @@ -238,9 +298,13 @@ func RecalculateTaskQuota(ctx context.Context, task *model.Task, actualQuota int ChannelId: task.ChannelId, ModelName: taskModelName(task), Quota: logQuota, - TokenId: task.PrivateData.TokenId, - Group: task.Group, - Other: other, + // 视频/图像类异步任务上游不区分 input/output token:input=0,output=total + // (参见 doubao seedance 文档:total_tokens = completion_tokens) + PromptTokens: 0, + CompletionTokens: totalTokens, + TokenId: task.PrivateData.TokenId, + Group: task.Group, + Other: other, }) } @@ -297,5 +361,5 @@ func RecalculateTaskQuotaByTokens(ctx context.Context, task *model.Task, totalTo actualQuota := int(float64(totalTokens) * modelRatio * finalGroupRatio * otherMultiplier) reason := fmt.Sprintf("token重算:tokens=%d, modelRatio=%.2f, groupRatio=%.2f, otherMultiplier=%.4f", totalTokens, modelRatio, finalGroupRatio, otherMultiplier) - RecalculateTaskQuota(ctx, task, actualQuota, reason) + RecalculateTaskQuota(ctx, task, actualQuota, totalTokens, reason) } diff --git a/service/task_billing_test.go b/service/task_billing_test.go index 39cb8f1da1a..b5eb45bc644 100644 --- a/service/task_billing_test.go +++ b/service/task_billing_test.go @@ -44,6 +44,7 @@ func TestMain(m *testing.M) { &model.Channel{}, &model.TopUp{}, &model.UserSubscription{}, + &model.QuotaData{}, ); err != nil { panic("failed to migrate: " + err.Error()) } @@ -65,6 +66,11 @@ func truncate(t *testing.T) { model.DB.Exec("DELETE FROM channels") model.DB.Exec("DELETE FROM top_ups") model.DB.Exec("DELETE FROM user_subscriptions") + model.DB.Exec("DELETE FROM quota_data") + // 把内存里残留的 cache 也清掉,避免上一用例的统计污染下一用例 + model.CacheQuotaDataLock.Lock() + model.CacheQuotaData = make(map[string]*model.QuotaData) + model.CacheQuotaDataLock.Unlock() }) } @@ -74,6 +80,21 @@ func seedUser(t *testing.T, id int, quota int) { require.NoError(t, model.DB.Create(user).Error) } +// seedUserWithUsed 同 seedUser,但允许设置 used_quota / request_count 初值, +// 用于验证退款/补扣时 used_quota 守恒、request_count 不被污染。 +func seedUserWithUsed(t *testing.T, id int, quota int, used int, requestCount int) { + t.Helper() + user := &model.User{ + Id: id, + Username: "test_user", + Quota: quota, + UsedQuota: used, + RequestCount: requestCount, + Status: common.UserStatusEnabled, + } + require.NoError(t, model.DB.Create(user).Error) +} + func seedToken(t *testing.T, id int, userId int, key string, remainQuota int) { t.Helper() token := &model.Token{ @@ -108,6 +129,19 @@ func seedChannel(t *testing.T, id int) { require.NoError(t, model.DB.Create(ch).Error) } +// seedChannelWithUsed 在创建渠道的同时写入 used_quota,用于验证渠道用量统计同步守恒。 +func seedChannelWithUsed(t *testing.T, id int, usedQuota int64) { + t.Helper() + ch := &model.Channel{ + Id: id, + Name: "test_channel", + Key: "sk-test", + Status: common.ChannelStatusEnabled, + UsedQuota: usedQuota, + } + require.NoError(t, model.DB.Create(ch).Error) +} + func makeTask(userId, channelId, quota, tokenId int, billingSource string, subscriptionId int) *model.Task { return &model.Task{ TaskID: "task_" + time.Now().Format("150405.000"), @@ -146,6 +180,27 @@ func getUserQuota(t *testing.T, id int) int { return user.Quota } +func getUserUsedQuota(t *testing.T, id int) int { + t.Helper() + var user model.User + require.NoError(t, model.DB.Select("used_quota").Where("id = ?", id).First(&user).Error) + return user.UsedQuota +} + +func getUserRequestCount(t *testing.T, id int) int { + t.Helper() + var user model.User + require.NoError(t, model.DB.Select("request_count").Where("id = ?", id).First(&user).Error) + return user.RequestCount +} + +func getChannelUsedQuota(t *testing.T, id int) int64 { + t.Helper() + var ch model.Channel + require.NoError(t, model.DB.Select("used_quota").Where("id = ?", id).First(&ch).Error) + return ch.UsedQuota +} + func getTokenRemainQuota(t *testing.T, id int) int { t.Helper() var token model.Token @@ -193,25 +248,42 @@ func TestRefundTaskQuota_Wallet(t *testing.T) { ctx := context.Background() const userID, tokenID, channelID = 1, 1, 1 - const initQuota, preConsumed = 10000, 3000 + // 模拟「LogTaskConsumption 已记账后任务失败」的真实状态: + // - 钱包预扣了 preConsumed → User.Quota=initQuota(已减完)/ UsedQuota=preConsumed + // - request_count 已 +1 + // - 渠道 used_quota 已 +preConsumed + const walletAfterPre, preConsumed = 7000, 3000 + const userInitTotal = walletAfterPre + preConsumed // 总额度(守恒目标) const tokenRemain = 5000 + const tokenUsedAfterPre = preConsumed + const requestCountBefore = 1 - seedUser(t, userID, initQuota) + seedUserWithUsed(t, userID, walletAfterPre, preConsumed, requestCountBefore) seedToken(t, tokenID, userID, "sk-test-key", tokenRemain) - seedChannel(t, channelID) + seedChannelWithUsed(t, channelID, int64(preConsumed)) + + // 把 token 的 used_quota 也对齐到 preConsumed(模拟 DecreaseTokenQuota 的副作用) + require.NoError(t, model.DB.Model(&model.Token{}).Where("id = ?", tokenID).Update("used_quota", tokenUsedAfterPre).Error) task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) RefundTaskQuota(ctx, task, "task failed: upstream error") - // User quota should increase by preConsumed - assert.Equal(t, initQuota+preConsumed, getUserQuota(t, userID)) + // 用户表:Quota 回退 + UsedQuota 回退 → 总额度守恒 + assert.Equal(t, walletAfterPre+preConsumed, getUserQuota(t, userID)) + assert.Equal(t, 0, getUserUsedQuota(t, userID)) + assert.Equal(t, userInitTotal, getUserQuota(t, userID)+getUserUsedQuota(t, userID)) + // request_count 不应被退款污染 + assert.Equal(t, requestCountBefore, getUserRequestCount(t, userID)) - // Token remain_quota should increase, used_quota should decrease + // 令牌:剩余额度回涨;已用额度回到 0 assert.Equal(t, tokenRemain+preConsumed, getTokenRemainQuota(t, tokenID)) - assert.Equal(t, -preConsumed, getTokenUsedQuota(t, tokenID)) + assert.Equal(t, 0, getTokenUsedQuota(t, tokenID)) + + // 渠道用量:回退到预扣前 + assert.Equal(t, int64(0), getChannelUsedQuota(t, channelID)) - // A refund log should be created + // 退款日志 log := getLastLog(t) require.NotNil(t, log) assert.Equal(t, model.LogTypeRefund, log.Type) @@ -298,32 +370,44 @@ func TestRecalculate_PositiveDelta(t *testing.T) { ctx := context.Background() const userID, tokenID, channelID = 10, 10, 10 - const initQuota, preConsumed = 10000, 2000 - const actualQuota = 3000 // under-charged by 1000 + // 模拟预扣后的真实状态 + const walletAfterPre, preConsumed = 8000, 2000 + const actualQuota = 3000 // under-charged by 1000 (need to charge an extra 1000) + const userInitTotal = walletAfterPre + preConsumed const tokenRemain = 5000 + const requestCountBefore = 1 - seedUser(t, userID, initQuota) + seedUserWithUsed(t, userID, walletAfterPre, preConsumed, requestCountBefore) seedToken(t, tokenID, userID, "sk-recalc-pos", tokenRemain) - seedChannel(t, channelID) + seedChannelWithUsed(t, channelID, int64(preConsumed)) task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) - RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment") + RecalculateTaskQuota(ctx, task, actualQuota, 0, "adaptor adjustment") - // User quota should decrease by the delta (1000 additional charge) - assert.Equal(t, initQuota-(actualQuota-preConsumed), getUserQuota(t, userID)) + delta := actualQuota - preConsumed - // Token should also be charged the delta - assert.Equal(t, tokenRemain-(actualQuota-preConsumed), getTokenRemainQuota(t, tokenID)) + // 用户表:Quota -delta、UsedQuota +delta,总额度守恒 + assert.Equal(t, walletAfterPre-delta, getUserQuota(t, userID)) + assert.Equal(t, preConsumed+delta, getUserUsedQuota(t, userID)) + assert.Equal(t, userInitTotal, getUserQuota(t, userID)+getUserUsedQuota(t, userID)) + // request_count 不被结算污染 + assert.Equal(t, requestCountBefore, getUserRequestCount(t, userID)) - // task.Quota should be updated to actualQuota + // 令牌 + assert.Equal(t, tokenRemain-delta, getTokenRemainQuota(t, tokenID)) + + // 渠道用量随补扣同向变化 + assert.Equal(t, int64(actualQuota), getChannelUsedQuota(t, channelID)) + + // task.Quota 落到 actualQuota assert.Equal(t, actualQuota, task.Quota) - // Log type should be Consume (additional charge) + // 日志记 Consume + delta log := getLastLog(t) require.NotNil(t, log) assert.Equal(t, model.LogTypeConsume, log.Type) - assert.Equal(t, actualQuota-preConsumed, log.Quota) + assert.Equal(t, delta, log.Quota) } func TestRecalculate_NegativeDelta(t *testing.T) { @@ -331,32 +415,42 @@ func TestRecalculate_NegativeDelta(t *testing.T) { ctx := context.Background() const userID, tokenID, channelID = 11, 11, 11 - const initQuota, preConsumed = 10000, 5000 + const walletAfterPre, preConsumed = 5000, 5000 const actualQuota = 3000 // over-charged by 2000 + const userInitTotal = walletAfterPre + preConsumed const tokenRemain = 5000 + const requestCountBefore = 1 - seedUser(t, userID, initQuota) + seedUserWithUsed(t, userID, walletAfterPre, preConsumed, requestCountBefore) seedToken(t, tokenID, userID, "sk-recalc-neg", tokenRemain) - seedChannel(t, channelID) + seedChannelWithUsed(t, channelID, int64(preConsumed)) task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) - RecalculateTaskQuota(ctx, task, actualQuota, "adaptor adjustment") + RecalculateTaskQuota(ctx, task, actualQuota, 0, "adaptor adjustment") - // User quota should increase by abs(delta) = 2000 (refund overpayment) - assert.Equal(t, initQuota+(preConsumed-actualQuota), getUserQuota(t, userID)) + refund := preConsumed - actualQuota - // Token should be refunded the difference - assert.Equal(t, tokenRemain+(preConsumed-actualQuota), getTokenRemainQuota(t, tokenID)) + // 用户表:Quota +refund、UsedQuota -refund,总额度守恒 + assert.Equal(t, walletAfterPre+refund, getUserQuota(t, userID)) + assert.Equal(t, preConsumed-refund, getUserUsedQuota(t, userID)) + assert.Equal(t, userInitTotal, getUserQuota(t, userID)+getUserUsedQuota(t, userID)) + assert.Equal(t, requestCountBefore, getUserRequestCount(t, userID)) + + // 令牌:剩余额度增加 + assert.Equal(t, tokenRemain+refund, getTokenRemainQuota(t, tokenID)) - // task.Quota updated + // 渠道用量同步退还 + assert.Equal(t, int64(actualQuota), getChannelUsedQuota(t, channelID)) + + // task.Quota 落到 actualQuota assert.Equal(t, actualQuota, task.Quota) - // Log type should be Refund + // 日志记 Refund + refund log := getLastLog(t) require.NotNil(t, log) assert.Equal(t, model.LogTypeRefund, log.Type) - assert.Equal(t, preConsumed-actualQuota, log.Quota) + assert.Equal(t, refund, log.Quota) } func TestRecalculate_ZeroDelta(t *testing.T) { @@ -370,7 +464,7 @@ func TestRecalculate_ZeroDelta(t *testing.T) { task := makeTask(userID, 0, preConsumed, 0, BillingSourceWallet, 0) - RecalculateTaskQuota(ctx, task, preConsumed, "exact match") + RecalculateTaskQuota(ctx, task, preConsumed, 0, "exact match") // No change to user quota assert.Equal(t, initQuota, getUserQuota(t, userID)) @@ -390,13 +484,216 @@ func TestRecalculate_ActualQuotaZero(t *testing.T) { task := makeTask(userID, 0, 5000, 0, BillingSourceWallet, 0) - RecalculateTaskQuota(ctx, task, 0, "zero actual") + RecalculateTaskQuota(ctx, task, 0, 0, "zero actual") // No change (early return) assert.Equal(t, initQuota, getUserQuota(t, userID)) assert.Equal(t, int64(0), countLogs(t)) } +// TestQuotaData_RefundFullyReverts 端到端验证: +// 1. LogTaskConsumption 的 quota_data 进入 cache (+pre, count=1, tokens=0); +// 2. 任务失败触发 RefundTaskQuota → 反向 adjust 写入 cache; +// 3. 落库后该 hour bucket 的 quota / count / tokens 全部归零(守恒)。 +func TestQuotaData_RefundFullyReverts(t *testing.T) { + truncate(t) + ctx := context.Background() + + prevDataExport := common.DataExportEnabled + prevLogConsume := common.LogConsumeEnabled + common.DataExportEnabled = true + common.LogConsumeEnabled = true + t.Cleanup(func() { + common.DataExportEnabled = prevDataExport + common.LogConsumeEnabled = prevLogConsume + }) + + const userID, tokenID, channelID = 40, 40, 40 + const walletAfterPre, preConsumed = 7000, 3000 + + seedUserWithUsed(t, userID, walletAfterPre, preConsumed, 1) + seedToken(t, tokenID, userID, "sk-qdata-refund", 5000) + seedChannelWithUsed(t, channelID, int64(preConsumed)) + + // 模拟 LogTaskConsumption 已经为本任务写过一笔正向 quota_data(count=1, quota=preConsumed) + username, err := model.GetUsernameById(userID, false) + require.NoError(t, err) + model.LogQuotaData(userID, username, "test-model", preConsumed, time.Now().Unix(), 0) + + // 任务失败 → 退款(内部应同步反向 adjust quota_data) + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + RefundTaskQuota(ctx, task, "task failed") + + // RefundTaskQuota 内部 RecordTaskBillingLog 走 gopool.Go 异步刷 quota_data, + // 等待最多 1 秒让 goroutine 把 cache 写完再 flush。 + require.Eventually(t, func() bool { + model.CacheQuotaDataLock.Lock() + defer model.CacheQuotaDataLock.Unlock() + // 反向 adjust 落入 cache 后,quota / token_used 净值应当为 0 + for _, qd := range model.CacheQuotaData { + if qd.UserID != userID || qd.ModelName != "test-model" { + continue + } + return qd.Quota == 0 && qd.TokenUsed == 0 + } + return false + }, time.Second, 20*time.Millisecond) + + model.SaveQuotaDataCache() + + // 落库后断言:该 user/model 维度下 sum(quota) 和 sum(token_used) 都应为 0(守恒) + var sumQuota, sumTokens int64 + require.NoError(t, model.DB.Table("quota_data"). + Select("COALESCE(SUM(quota), 0)"). + Where("user_id = ? and model_name = ?", userID, "test-model"). + Scan(&sumQuota).Error) + require.NoError(t, model.DB.Table("quota_data"). + Select("COALESCE(SUM(token_used), 0)"). + Where("user_id = ? and model_name = ?", userID, "test-model"). + Scan(&sumTokens).Error) + assert.Equal(t, int64(0), sumQuota, "quota_data.quota 应在退款后净值归零") + assert.Equal(t, int64(0), sumTokens, "quota_data.token_used 应在退款后净值归零") +} + +// TestQuotaData_NegativeDeltaStillRecordsTokens 防回归用例: +// 当上游实际花费 < 预扣(部分退款),但仍然返回了 totalTokens 时, +// quota_data 必须做到「钱往负方向走,token 仍向正方向加到 totalTokens」。 +// 这是历史 sign bug 的高发场景:金额 delta 是负的,token 却必须是正的。 +func TestQuotaData_NegativeDeltaStillRecordsTokens(t *testing.T) { + truncate(t) + ctx := context.Background() + + prevDataExport := common.DataExportEnabled + prevLogConsume := common.LogConsumeEnabled + common.DataExportEnabled = true + common.LogConsumeEnabled = true + t.Cleanup(func() { + common.DataExportEnabled = prevDataExport + common.LogConsumeEnabled = prevLogConsume + }) + + const userID, tokenID, channelID = 42, 42, 42 + const walletAfterPre, preConsumed = 5000, 5000 + const actualQuota = 3000 // delta = -2000,部分退款 + const totalTokens = 1234 + + seedUserWithUsed(t, userID, walletAfterPre, preConsumed, 1) + seedToken(t, tokenID, userID, "sk-qdata-neg-tokens", 5000) + seedChannelWithUsed(t, channelID, int64(preConsumed)) + + username, err := model.GetUsernameById(userID, false) + require.NoError(t, err) + model.LogQuotaData(userID, username, "test-model", preConsumed, time.Now().Unix(), 0) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + RecalculateTaskQuota(ctx, task, actualQuota, totalTokens, "token重算-负delta") + + require.Eventually(t, func() bool { + model.CacheQuotaDataLock.Lock() + defer model.CacheQuotaDataLock.Unlock() + for _, qd := range model.CacheQuotaData { + if qd.UserID != userID || qd.ModelName != "test-model" { + continue + } + // quota: preConsumed + (actualQuota - preConsumed) = actualQuota + // token_used: 0 + totalTokens = totalTokens(关键:不是 -totalTokens) + return qd.Quota == actualQuota && qd.TokenUsed == totalTokens + } + return false + }, time.Second, 20*time.Millisecond) + + model.SaveQuotaDataCache() + + var sumQuota, sumTokens int64 + require.NoError(t, model.DB.Table("quota_data"). + Select("COALESCE(SUM(quota), 0)"). + Where("user_id = ? and model_name = ?", userID, "test-model"). + Scan(&sumQuota).Error) + require.NoError(t, model.DB.Table("quota_data"). + Select("COALESCE(SUM(token_used), 0)"). + Where("user_id = ? and model_name = ?", userID, "test-model"). + Scan(&sumTokens).Error) + assert.Equal(t, int64(actualQuota), sumQuota, "quota 应当落在 actualQuota") + assert.Equal(t, int64(totalTokens), sumTokens, "token_used 应当向正方向加到 totalTokens") + + // Log 表的 token 字段也应正确填到 CompletionTokens(即便日志类型是 Refund) + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeRefund, log.Type) + assert.Equal(t, 0, log.PromptTokens) + assert.Equal(t, totalTokens, log.CompletionTokens) +} + +// TestQuotaData_RecalcByTokensRecordsTokens 验证 token 重算路径下: +// 1. quota_data.quota 净值 = actualQuota(pre + delta) +// 2. quota_data.token_used 由 0 增加到 totalTokens(视频任务 input=0 全部记 completion) +// 3. Log 的 prompt_tokens=0 / completion_tokens=totalTokens +func TestQuotaData_RecalcByTokensRecordsTokens(t *testing.T) { + truncate(t) + ctx := context.Background() + + prevDataExport := common.DataExportEnabled + prevLogConsume := common.LogConsumeEnabled + common.DataExportEnabled = true + common.LogConsumeEnabled = true + t.Cleanup(func() { + common.DataExportEnabled = prevDataExport + common.LogConsumeEnabled = prevLogConsume + }) + + const userID, tokenID, channelID = 41, 41, 41 + const walletAfterPre, preConsumed = 8000, 2000 + const actualQuota = 3000 // delta = +1000 + const totalTokens = 1234 + + seedUserWithUsed(t, userID, walletAfterPre, preConsumed, 1) + seedToken(t, tokenID, userID, "sk-qdata-tokens", 5000) + seedChannelWithUsed(t, channelID, int64(preConsumed)) + + username, err := model.GetUsernameById(userID, false) + require.NoError(t, err) + model.LogQuotaData(userID, username, "test-model", preConsumed, time.Now().Unix(), 0) + + task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceWallet, 0) + RecalculateTaskQuota(ctx, task, actualQuota, totalTokens, "token重算") + + // 等异步 LogQuotaDataAdjust 刷进 cache + require.Eventually(t, func() bool { + model.CacheQuotaDataLock.Lock() + defer model.CacheQuotaDataLock.Unlock() + for _, qd := range model.CacheQuotaData { + if qd.UserID != userID || qd.ModelName != "test-model" { + continue + } + // cache 里此时 quota = preConsumed + (actualQuota - preConsumed) = actualQuota + // token_used = 0 + totalTokens = totalTokens + return qd.Quota == actualQuota && qd.TokenUsed == totalTokens + } + return false + }, time.Second, 20*time.Millisecond) + + model.SaveQuotaDataCache() + + var sumQuota, sumTokens int64 + require.NoError(t, model.DB.Table("quota_data"). + Select("COALESCE(SUM(quota), 0)"). + Where("user_id = ? and model_name = ?", userID, "test-model"). + Scan(&sumQuota).Error) + require.NoError(t, model.DB.Table("quota_data"). + Select("COALESCE(SUM(token_used), 0)"). + Where("user_id = ? and model_name = ?", userID, "test-model"). + Scan(&sumTokens).Error) + assert.Equal(t, int64(actualQuota), sumQuota) + assert.Equal(t, int64(totalTokens), sumTokens) + + // Log 的 token 字段被正确填写 + log := getLastLog(t) + require.NotNil(t, log) + assert.Equal(t, model.LogTypeConsume, log.Type) + assert.Equal(t, 0, log.PromptTokens) + assert.Equal(t, totalTokens, log.CompletionTokens) +} + func TestRecalculate_Subscription_NegativeDelta(t *testing.T) { truncate(t) ctx := context.Background() @@ -414,7 +711,7 @@ func TestRecalculate_Subscription_NegativeDelta(t *testing.T) { task := makeTask(userID, channelID, preConsumed, tokenID, BillingSourceSubscription, subID) - RecalculateTaskQuota(ctx, task, actualQuota, "subscription over-charge") + RecalculateTaskQuota(ctx, task, actualQuota, 0, "subscription over-charge") // Subscription used should decrease by delta (refund 3000) assert.Equal(t, subUsed-int64(preConsumed-actualQuota), getSubscriptionUsed(t, subID)) @@ -476,7 +773,7 @@ func simulatePollBilling(ctx context.Context, task *model.Task, newStatus model. } if shouldSettle && actualQuota > 0 { - RecalculateTaskQuota(ctx, task, actualQuota, "test settle") + RecalculateTaskQuota(ctx, task, actualQuota, 0, "test settle") } if shouldRefund { RefundTaskQuota(ctx, task, task.FailReason) diff --git a/service/task_polling.go b/service/task_polling.go index dc85e579e8c..b747b6d4d55 100644 --- a/service/task_polling.go +++ b/service/task_polling.go @@ -548,7 +548,9 @@ func settleTaskBillingOnComplete(ctx context.Context, adaptor TaskPollingAdaptor } // 1. 优先让 adaptor 决定最终额度 if actualQuota := adaptor.AdjustBillingOnComplete(task, taskResult); actualQuota > 0 { - RecalculateTaskQuota(ctx, task, actualQuota, "adaptor计费调整") + // adaptor 路径若上游也给出了 totalTokens(例如 Doubao seedance 的 usage.total_tokens), + // 一并透传到统计,否则传 0 不影响。 + RecalculateTaskQuota(ctx, task, actualQuota, taskResult.TotalTokens, "adaptor计费调整") return } // 2. 回退到 token 重算