diff --git a/database/billing.go b/database/billing.go index 826a9329..a1c80e2f 100644 --- a/database/billing.go +++ b/database/billing.go @@ -41,12 +41,12 @@ var ( modelPricingRules = []modelPricingRule{ {model: "gpt-5.5", pricing: ModelPricing{ - InputPricePerMToken: 5.0, - InputPricePerMTokenPriority: 12.5, - OutputPricePerMToken: 30.0, - OutputPricePerMTokenPriority: 75.0, - CacheReadPricePerMToken: 0.5, - CacheReadPricePerMTokenPriority: 1.25, + InputPricePerMToken: 5.0, + InputPricePerMTokenPriority: 12.5, + OutputPricePerMToken: 30.0, + OutputPricePerMTokenPriority: 75.0, + CacheReadPricePerMToken: 0.5, + CacheReadPricePerMTokenPriority: 1.25, LongInputPricePerMToken: 10.0, LongInputPricePerMTokenPriority: 25.0, LongOutputPricePerMToken: 45.0, @@ -55,10 +55,10 @@ var ( LongCacheReadPricePerMTokenPriority: 2.5, }}, {model: "gpt-5.5-pro", pricing: ModelPricing{ - InputPricePerMToken: 30.0, - InputPricePerMTokenPriority: 75.0, - OutputPricePerMToken: 180.0, - OutputPricePerMTokenPriority: 450.0, + InputPricePerMToken: 30.0, + InputPricePerMTokenPriority: 75.0, + OutputPricePerMToken: 180.0, + OutputPricePerMTokenPriority: 450.0, LongInputPricePerMToken: 60.0, LongInputPricePerMTokenPriority: 150.0, LongOutputPricePerMToken: 270.0, @@ -67,12 +67,12 @@ var ( {model: "gpt-5.4-mini", pricing: ModelPricing{InputPricePerMToken: 0.75, OutputPricePerMToken: 4.5, CacheReadPricePerMToken: 0.075}}, {model: "gpt-5.4-nano", pricing: ModelPricing{InputPricePerMToken: 0.2, OutputPricePerMToken: 1.25, CacheReadPricePerMToken: 0.02}}, {model: "gpt-5.4", pricing: ModelPricing{ - InputPricePerMToken: 2.5, - InputPricePerMTokenPriority: 5.0, - OutputPricePerMToken: 15.0, - OutputPricePerMTokenPriority: 30.0, - CacheReadPricePerMToken: 0.25, - CacheReadPricePerMTokenPriority: 0.5, + InputPricePerMToken: 2.5, + InputPricePerMTokenPriority: 5.0, + OutputPricePerMToken: 15.0, + OutputPricePerMTokenPriority: 30.0, + CacheReadPricePerMToken: 0.25, + CacheReadPricePerMTokenPriority: 0.5, LongInputPricePerMToken: 5.0, LongInputPricePerMTokenPriority: 10.0, LongOutputPricePerMToken: 22.5, @@ -81,10 +81,10 @@ var ( LongCacheReadPricePerMTokenPriority: 1.0, }}, {model: "gpt-5.4-pro", pricing: ModelPricing{ - InputPricePerMToken: 30.0, - InputPricePerMTokenPriority: 75.0, - OutputPricePerMToken: 180.0, - OutputPricePerMTokenPriority: 450.0, + InputPricePerMToken: 30.0, + InputPricePerMTokenPriority: 75.0, + OutputPricePerMToken: 180.0, + OutputPricePerMTokenPriority: 450.0, LongInputPricePerMToken: 60.0, LongInputPricePerMTokenPriority: 150.0, LongOutputPricePerMToken: 270.0, @@ -326,7 +326,8 @@ func geminiFamilyPricing(model string) *ModelPricing { } func usePriorityPricing(serviceTier string, pricing *ModelPricing) bool { - if normalizeServiceTier(serviceTier) != "priority" { + tier := normalizeServiceTier(serviceTier) + if tier != "priority" && tier != "fast" { return false } return pricing.InputPricePerMTokenPriority > 0 || @@ -336,8 +337,6 @@ func usePriorityPricing(serviceTier string, pricing *ModelPricing) bool { func serviceTierCostMultiplier(serviceTier string) float64 { switch normalizeServiceTier(serviceTier) { - case "priority": - return 2.0 case "flex": return 0.5 default: diff --git a/database/billing_test.go b/database/billing_test.go index e7d65396..9b41f84c 100644 --- a/database/billing_test.go +++ b/database/billing_test.go @@ -100,6 +100,33 @@ func TestCalculateCostHandlesCachedTokensAndServiceTier(t *testing.T) { cachedTokens: 200, want: 0.0191, }, + { + name: "uses priority prices for fast tier", + model: "gpt-5.4", + serviceTier: "fast", + inputTokens: 1000, + outputTokens: 500, + cachedTokens: 200, + want: 0.0191, + }, + { + name: "does not invent priority multiplier when priority price is unknown", + model: "gpt-4o", + serviceTier: "priority", + inputTokens: 1000, + outputTokens: 500, + cachedTokens: 200, + want: 0.0075, + }, + { + name: "fast tier falls back to standard pricing when priority price is unknown", + model: "gpt-4o", + serviceTier: "fast", + inputTokens: 1000, + outputTokens: 500, + cachedTokens: 200, + want: 0.0075, + }, { name: "applies flex multiplier", model: "gpt-5.4", @@ -277,7 +304,7 @@ func TestCodexAutoReviewModelNormalizesToGPT54(t *testing.T) { func TestCodexAutoReviewLongContextPricing(t *testing.T) { // codex-auto-review maps to gpt-5.4 which has long context pricing. long := CalculateCostBreakdown(300000, 500, 100, "codex-auto-review", "") - assertFloatEqual(t, long.InputPricePerMToken, 5.0) // long input price + assertFloatEqual(t, long.InputPricePerMToken, 5.0) // long input price assertFloatEqual(t, long.OutputPricePerMToken, 22.5) // long output price assertFloatEqual(t, long.CacheReadPricePerMToken, 0.5) // long cache read price } diff --git a/database/postgres.go b/database/postgres.go index dc56414f..2b8c8567 100644 --- a/database/postgres.go +++ b/database/postgres.go @@ -1543,8 +1543,13 @@ func (db *DB) InsertUsageLog(ctx context.Context, log *UsageLogInput) error { billingModel = log.Model } - // 计算账号计费金额(标准费用) - accountBilled := calculateCost(log.InputTokens, log.OutputTokens, log.CachedTokens, billingModel, log.ServiceTier) + billingServiceTier := log.BillingServiceTier + if billingServiceTier == "" { + billingServiceTier = log.ServiceTier + } + + // 计算账号计费金额(基于上游实际 service tier) + accountBilled := calculateCost(log.InputTokens, log.OutputTokens, log.CachedTokens, billingModel, billingServiceTier) // 用户计费金额与账号计费金额相同(简化版,未来可支持倍率) userBilled := accountBilled @@ -1598,38 +1603,39 @@ func (db *DB) InsertUsageLog(ctx context.Context, log *UsageLogInput) error { // UsageLogInput 日志写入参数 type UsageLogInput struct { - AccountID int64 - Endpoint string - Model string - EffectiveModel string - PromptTokens int - CompletionTokens int - TotalTokens int - StatusCode int - DurationMs int - InputTokens int - OutputTokens int - ReasoningTokens int - FirstTokenMs int - ReasoningEffort string - InboundEndpoint string - UpstreamEndpoint string - Stream bool - CachedTokens int - ServiceTier string - APIKeyID int64 - APIKeyName string - APIKeyMasked string - ImageCount int - ImageWidth int - ImageHeight int - ImageBytes int - ImageFormat string - ImageSize string - IsRetryAttempt bool - AttemptIndex int - UpstreamErrorKind string - ErrorMessage string + AccountID int64 + Endpoint string + Model string + EffectiveModel string + PromptTokens int + CompletionTokens int + TotalTokens int + StatusCode int + DurationMs int + InputTokens int + OutputTokens int + ReasoningTokens int + FirstTokenMs int + ReasoningEffort string + InboundEndpoint string + UpstreamEndpoint string + Stream bool + CachedTokens int + ServiceTier string + BillingServiceTier string + APIKeyID int64 + APIKeyName string + APIKeyMasked string + ImageCount int + ImageWidth int + ImageHeight int + ImageBytes int + ImageFormat string + ImageSize string + IsRetryAttempt bool + AttemptIndex int + UpstreamErrorKind string + ErrorMessage string } func (l *UsageLog) populateBillingBreakdown() { diff --git a/database/sqlite_test.go b/database/sqlite_test.go index 46045fbf..c637035b 100644 --- a/database/sqlite_test.go +++ b/database/sqlite_test.go @@ -737,6 +737,74 @@ func TestUsageLogsReturnBillingFields(t *testing.T) { } } +func TestUsageLogsBillFastByActualServiceTier(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "codex2api.db") + + db, err := New("sqlite", dbPath) + if err != nil { + t.Fatalf("New(sqlite) 返回错误: %v", err) + } + defer db.Close() + + ctx := context.Background() + if err := db.InsertUsageLog(ctx, &UsageLogInput{ + AccountID: 1, + Endpoint: "/v1/responses", + Model: "gpt-5.4", + StatusCode: 200, + InputTokens: 1000, + OutputTokens: 500, + CachedTokens: 200, + ServiceTier: "fast", + BillingServiceTier: "default", + }); err != nil { + t.Fatalf("InsertUsageLog 返回错误: %v", err) + } + if err := db.InsertUsageLog(ctx, &UsageLogInput{ + AccountID: 1, + Endpoint: "/v1/responses", + Model: "gpt-5.4", + StatusCode: 200, + InputTokens: 1000, + OutputTokens: 500, + CachedTokens: 200, + ServiceTier: "fast", + BillingServiceTier: "priority", + }); err != nil { + t.Fatalf("InsertUsageLog 返回错误: %v", err) + } + db.flushLogs() + + logs, err := db.ListRecentUsageLogs(ctx, 10) + if err != nil { + t.Fatalf("ListRecentUsageLogs 返回错误: %v", err) + } + if len(logs) != 2 { + t.Fatalf("len(logs) = %d, want 2", len(logs)) + } + + wantPriority := calculateCost(1000, 500, 200, "gpt-5.4", "priority") + wantDefault := calculateCost(1000, 500, 200, "gpt-5.4", "default") + seenPriority := false + seenDefault := false + for _, log := range logs { + if log.ServiceTier != "fast" { + t.Fatalf("log tier = %q, want fast", log.ServiceTier) + } + switch log.AccountBilled { + case wantPriority: + seenPriority = true + case wantDefault: + seenDefault = true + default: + t.Fatalf("unexpected billed amount %.12f, want %.12f or %.12f", log.AccountBilled, wantPriority, wantDefault) + } + } + if !seenPriority || !seenDefault { + t.Fatalf("billing tiers seen priority=%v default=%v, want both", seenPriority, seenDefault) + } +} + func TestUsageLogsReturnErrorMessage(t *testing.T) { dbPath := filepath.Join(t.TempDir(), "codex2api.db") diff --git a/proxy/handler.go b/proxy/handler.go index 984e45e9..a683dea1 100644 --- a/proxy/handler.go +++ b/proxy/handler.go @@ -1448,19 +1448,21 @@ func (h *Handler) Responses(c *gin.Context) { } resolvedServiceTier := resolveServiceTier(actualServiceTier, serviceTier) + billingServiceTier := resolveBillingServiceTier(actualServiceTier, serviceTier) c.Set("x-service-tier", resolvedServiceTier) logInput := &database.UsageLogInput{ - AccountID: account.ID(), - Endpoint: "/v1/responses", - Model: model, - StatusCode: outcome.logStatusCode, - DurationMs: totalDuration, - FirstTokenMs: firstTokenMs, - ReasoningEffort: reasoningEffort, - InboundEndpoint: "/v1/responses", - UpstreamEndpoint: upstreamEndpoint, - Stream: isStream, - ServiceTier: resolvedServiceTier, + AccountID: account.ID(), + Endpoint: "/v1/responses", + Model: model, + StatusCode: outcome.logStatusCode, + DurationMs: totalDuration, + FirstTokenMs: firstTokenMs, + ReasoningEffort: reasoningEffort, + InboundEndpoint: "/v1/responses", + UpstreamEndpoint: upstreamEndpoint, + Stream: isStream, + ServiceTier: resolvedServiceTier, + BillingServiceTier: billingServiceTier, } if outcome.logStatusCode != http.StatusOK { logInput.ErrorMessage = usageLogErrorMessage(outcome.logStatusCode, []byte(outcome.failureMessage)) @@ -1780,20 +1782,22 @@ func (h *Handler) Responses(c *gin.Context) { } resolvedServiceTier := resolveServiceTier(actualServiceTier, serviceTier) + billingServiceTier := resolveBillingServiceTier(actualServiceTier, serviceTier) c.Set("x-service-tier", resolvedServiceTier) logInput := &database.UsageLogInput{ - AccountID: account.ID(), - Endpoint: "/v1/responses", - Model: model, - StatusCode: logStatusCode, - DurationMs: totalDuration, - FirstTokenMs: firstTokenMs, - ReasoningEffort: reasoningEffort, - InboundEndpoint: "/v1/responses", - UpstreamEndpoint: "/v1/responses", - Stream: isStream, - ServiceTier: resolvedServiceTier, + AccountID: account.ID(), + Endpoint: "/v1/responses", + Model: model, + StatusCode: logStatusCode, + DurationMs: totalDuration, + FirstTokenMs: firstTokenMs, + ReasoningEffort: reasoningEffort, + InboundEndpoint: "/v1/responses", + UpstreamEndpoint: "/v1/responses", + Stream: isStream, + ServiceTier: resolvedServiceTier, + BillingServiceTier: billingServiceTier, } if logStatusCode != http.StatusOK { logInput.ErrorMessage = usageLogErrorMessage(logStatusCode, []byte(outcome.failureMessage)) @@ -2035,25 +2039,27 @@ func (h *Handler) ResponsesCompact(c *gin.Context) { actualServiceTier := gjson.GetBytes(respBody, "service_tier").String() resolvedServiceTier := resolveServiceTier(actualServiceTier, serviceTier) + billingServiceTier := resolveBillingServiceTier(actualServiceTier, serviceTier) totalDuration := int(time.Since(start).Milliseconds()) h.logUsageForRequest(c, &database.UsageLogInput{ - AccountID: account.ID(), - Endpoint: "/v1/responses/compact", - Model: model, - StatusCode: http.StatusOK, - DurationMs: totalDuration, - PromptTokens: promptTokens, - CompletionTokens: completionTokens, - TotalTokens: totalTokens, - InputTokens: promptTokens, - OutputTokens: completionTokens, - ReasoningTokens: reasoningTokens, - CachedTokens: cachedTokens, - ReasoningEffort: reasoningEffort, - InboundEndpoint: "/v1/responses/compact", - UpstreamEndpoint: "/v1/responses/compact", - ServiceTier: resolvedServiceTier, + AccountID: account.ID(), + Endpoint: "/v1/responses/compact", + Model: model, + StatusCode: http.StatusOK, + DurationMs: totalDuration, + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + TotalTokens: totalTokens, + InputTokens: promptTokens, + OutputTokens: completionTokens, + ReasoningTokens: reasoningTokens, + CachedTokens: cachedTokens, + ReasoningEffort: reasoningEffort, + InboundEndpoint: "/v1/responses/compact", + UpstreamEndpoint: "/v1/responses/compact", + ServiceTier: resolvedServiceTier, + BillingServiceTier: billingServiceTier, }) h.store.Release(account) @@ -2445,20 +2451,22 @@ func (h *Handler) ChatCompletions(c *gin.Context) { } resolvedServiceTier := resolveServiceTier(actualServiceTier, serviceTier) + billingServiceTier := resolveBillingServiceTier(actualServiceTier, serviceTier) c.Set("x-service-tier", resolvedServiceTier) logInput := &database.UsageLogInput{ - AccountID: account.ID(), - Endpoint: "/v1/chat/completions", - Model: model, - StatusCode: logStatusCode, - DurationMs: totalDuration, - FirstTokenMs: firstTokenMs, - ReasoningEffort: reasoningEffort, - InboundEndpoint: "/v1/chat/completions", - UpstreamEndpoint: "/v1/responses", - Stream: isStream, - ServiceTier: resolvedServiceTier, + AccountID: account.ID(), + Endpoint: "/v1/chat/completions", + Model: model, + StatusCode: logStatusCode, + DurationMs: totalDuration, + FirstTokenMs: firstTokenMs, + ReasoningEffort: reasoningEffort, + InboundEndpoint: "/v1/chat/completions", + UpstreamEndpoint: "/v1/responses", + Stream: isStream, + ServiceTier: resolvedServiceTier, + BillingServiceTier: billingServiceTier, } if logStatusCode != http.StatusOK { logInput.ErrorMessage = usageLogErrorMessage(logStatusCode, []byte(outcome.failureMessage)) diff --git a/proxy/translator.go b/proxy/translator.go index 9960e4e0..dc1f7363 100644 --- a/proxy/translator.go +++ b/proxy/translator.go @@ -1801,6 +1801,28 @@ func resolveServiceTier(actualTier, requestedTier string) string { return final } +// resolveBillingServiceTier keeps UI tier normalization separate from billing: +// fast/priority intent is billed as priority only when the upstream does not +// report a concrete tier, or when it confirms fast/priority. Any concrete +// upstream tier wins so billing follows the actual tier reported by upstream. +func resolveBillingServiceTier(actualTier, requestedTier string) string { + actualTier = strings.ToLower(strings.TrimSpace(actualTier)) + if actualTier != "" { + if actualTier == "priority" || actualTier == "fast" { + return "priority" + } + return actualTier + } + + requestedTier = strings.ToLower(strings.TrimSpace(requestedTier)) + switch requestedTier { + case "priority", "fast": + return "priority" + default: + return requestedTier + } +} + // 上游不支持的 JSON Schema 验证约束关键字 var unsupportedSchemaKeys = map[string]bool{ "uniqueItems": true, diff --git a/proxy/translator_test.go b/proxy/translator_test.go index a5f1b542..80ff7c5b 100644 --- a/proxy/translator_test.go +++ b/proxy/translator_test.go @@ -57,6 +57,30 @@ func TestResolveServiceTier(t *testing.T) { } } +func TestResolveBillingServiceTier(t *testing.T) { + tests := []struct { + name string + actual string + requested string + want string + }{ + {name: "actual priority wins", actual: "priority", requested: "fast", want: "priority"}, + {name: "actual default downgrade wins", actual: "default", requested: "fast", want: "default"}, + {name: "unknown concrete actual tier wins", actual: "burst", requested: "fast", want: "burst"}, + {name: "requested fast fallback bills priority", actual: "", requested: "fast", want: "priority"}, + {name: "requested priority fallback bills priority", actual: "", requested: "priority", want: "priority"}, + {name: "default stays default", actual: "default", requested: "", want: "default"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := resolveBillingServiceTier(tt.actual, tt.requested); got != tt.want { + t.Fatalf("resolveBillingServiceTier(%q, %q) = %q, want %q", tt.actual, tt.requested, got, tt.want) + } + }) + } +} + func TestSanitizeServiceTierForUpstream_FastToPriority(t *testing.T) { raw := []byte(`{ "model":"gpt-5.4",