Skip to content

Commit 8a9ce9e

Browse files
sanwzzzslh-dev
authored and committed
feat: enrich /v1/models with context_length, pricing, and max_output_tokens
1 parent 78f691d commit 8a9ce9e

6 files changed

Lines changed: 6427 additions & 20 deletions

File tree

backend/cmd/server/wire_gen.go

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

backend/internal/handler/gateway_handler.go

Lines changed: 109 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ type GatewayHandler struct {
5151
maxAccountSwitchesGemini int
5252
cfg *config.Config
5353
settingService *service.SettingService
54+
pricingService *service.PricingService
5455
}
5556

5657
// NewGatewayHandler creates a new GatewayHandler
@@ -68,6 +69,7 @@ func NewGatewayHandler(
6869
userMsgQueueService *service.UserMessageQueueService,
6970
cfg *config.Config,
7071
settingService *service.SettingService,
72+
pricingService *service.PricingService,
7173
) *GatewayHandler {
7274
pingInterval := time.Duration(0)
7375
maxAccountSwitches := 10
@@ -104,6 +106,7 @@ func NewGatewayHandler(
104106
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
105107
cfg: cfg,
106108
settingService: settingService,
109+
pricingService: pricingService,
107110
}
108111
}
109112

@@ -869,38 +872,125 @@ func (h *GatewayHandler) Models(c *gin.Context) {
869872
availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "")
870873

871874
if len(availableModels) > 0 {
872-
// Build model list from whitelist
873-
models := make([]claude.Model, 0, len(availableModels))
874-
for _, modelID := range availableModels {
875-
models = append(models, claude.Model{
876-
ID: modelID,
877-
Type: "model",
878-
DisplayName: modelID,
879-
CreatedAt: "2024-01-01T00:00:00Z",
880-
})
881-
}
882-
c.JSON(http.StatusOK, gin.H{
883-
"object": "list",
884-
"data": models,
885-
})
875+
h.writeModelsResponse(c, availableModels)
886876
return
887877
}
888878

889879
// Fallback to default models
890880
if platform == "openai" {
891-
c.JSON(http.StatusOK, gin.H{
892-
"object": "list",
893-
"data": openai.DefaultModels,
894-
})
881+
h.writeDefaultOpenAIModelsResponse(c)
895882
return
896883
}
897884

885+
h.writeDefaultClaudeModelsResponse(c)
886+
}
887+
888+
type modelPricingResponse struct {
889+
InputCostPerToken *float64 `json:"input_cost_per_token,omitempty"`
890+
OutputCostPerToken *float64 `json:"output_cost_per_token,omitempty"`
891+
CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost,omitempty"`
892+
CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost,omitempty"`
893+
}
894+
895+
type modelResponse struct {
896+
ID string `json:"id"`
897+
Type string `json:"type"`
898+
DisplayName string `json:"display_name"`
899+
CreatedAt string `json:"created_at"`
900+
ContextLength *int `json:"context_length,omitempty"`
901+
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
902+
Pricing *modelPricingResponse `json:"pricing,omitempty"`
903+
}
904+
905+
func (h *GatewayHandler) writeModelsResponse(c *gin.Context, modelIDs []string) {
906+
models := make([]modelResponse, 0, len(modelIDs))
907+
for _, modelID := range modelIDs {
908+
models = append(models, h.buildModelResponse(modelID, modelID, "2024-01-01T00:00:00Z"))
909+
}
910+
c.JSON(http.StatusOK, gin.H{
911+
"object": "list",
912+
"data": models,
913+
})
914+
}
915+
916+
func (h *GatewayHandler) writeDefaultOpenAIModelsResponse(c *gin.Context) {
917+
models := make([]modelResponse, 0, len(openai.DefaultModels))
918+
for _, model := range openai.DefaultModels {
919+
createdAt := time.Unix(model.Created, 0).UTC().Format(time.RFC3339)
920+
models = append(models, h.buildModelResponse(model.ID, model.DisplayName, createdAt))
921+
}
922+
c.JSON(http.StatusOK, gin.H{
923+
"object": "list",
924+
"data": models,
925+
})
926+
}
927+
928+
func (h *GatewayHandler) writeDefaultClaudeModelsResponse(c *gin.Context) {
929+
models := make([]modelResponse, 0, len(claude.DefaultModels))
930+
for _, model := range claude.DefaultModels {
931+
models = append(models, h.buildModelResponse(model.ID, model.DisplayName, model.CreatedAt))
932+
}
898933
c.JSON(http.StatusOK, gin.H{
899934
"object": "list",
900-
"data": claude.DefaultModels,
935+
"data": models,
901936
})
902937
}
903938

939+
func (h *GatewayHandler) buildModelResponse(modelID, displayName, createdAt string) modelResponse {
940+
resp := modelResponse{
941+
ID: modelID,
942+
Type: "model",
943+
DisplayName: displayName,
944+
CreatedAt: createdAt,
945+
}
946+
947+
pricing := h.lookupModelMetadata(modelID)
948+
if pricing == nil {
949+
return resp
950+
}
951+
952+
if pricing.MaxInputTokens > 0 {
953+
v := pricing.MaxInputTokens
954+
resp.ContextLength = &v
955+
}
956+
if pricing.MaxOutputTokens > 0 {
957+
v := pricing.MaxOutputTokens
958+
resp.MaxOutputTokens = &v
959+
}
960+
961+
metaPricing := &modelPricingResponse{}
962+
if pricing.InputCostPerToken > 0 {
963+
v := pricing.InputCostPerToken
964+
metaPricing.InputCostPerToken = &v
965+
}
966+
if pricing.OutputCostPerToken > 0 {
967+
v := pricing.OutputCostPerToken
968+
metaPricing.OutputCostPerToken = &v
969+
}
970+
if pricing.CacheReadInputTokenCost > 0 {
971+
v := pricing.CacheReadInputTokenCost
972+
metaPricing.CacheReadInputTokenCost = &v
973+
}
974+
if pricing.CacheCreationInputTokenCost > 0 {
975+
v := pricing.CacheCreationInputTokenCost
976+
metaPricing.CacheCreationInputTokenCost = &v
977+
}
978+
if metaPricing.InputCostPerToken != nil || metaPricing.OutputCostPerToken != nil || metaPricing.CacheReadInputTokenCost != nil || metaPricing.CacheCreationInputTokenCost != nil {
979+
resp.Pricing = metaPricing
980+
}
981+
982+
return resp
983+
}
984+
985+
func (h *GatewayHandler) lookupModelMetadata(modelID string) *service.LiteLLMModelPricing {
986+
if h != nil && h.pricingService != nil {
987+
if pricing := h.pricingService.GetModelPricing(modelID); pricing != nil {
988+
return pricing
989+
}
990+
}
991+
return service.GetDefaultModelMetadata(modelID)
992+
}
993+
904994
// AntigravityModels 返回 Antigravity 支持的全部模型
905995
// GET /antigravity/models
906996
func (h *GatewayHandler) AntigravityModels(c *gin.Context) {
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
//go:build unit
2+
3+
package handler
4+
5+
import (
6+
"net/http"
7+
"net/http/httptest"
8+
"testing"
9+
10+
"github.com/gin-gonic/gin"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
func TestBuildModelResponse_EnrichesOptionalMetadata(t *testing.T) {
15+
t.Parallel()
16+
17+
h := &GatewayHandler{}
18+
resp := h.buildModelResponse("gpt-5.4", "gpt-5.4", "2024-01-01T00:00:00Z")
19+
20+
require.Equal(t, "gpt-5.4", resp.ID)
21+
require.Equal(t, "model", resp.Type)
22+
require.Equal(t, "gpt-5.4", resp.DisplayName)
23+
require.Equal(t, "2024-01-01T00:00:00Z", resp.CreatedAt)
24+
require.NotNil(t, resp.ContextLength)
25+
require.Equal(t, 1050000, *resp.ContextLength)
26+
require.NotNil(t, resp.MaxOutputTokens)
27+
require.Equal(t, 128000, *resp.MaxOutputTokens)
28+
require.NotNil(t, resp.Pricing)
29+
require.NotNil(t, resp.Pricing.InputCostPerToken)
30+
require.InDelta(t, 2.5e-06, *resp.Pricing.InputCostPerToken, 1e-12)
31+
require.NotNil(t, resp.Pricing.OutputCostPerToken)
32+
require.InDelta(t, 1.5e-05, *resp.Pricing.OutputCostPerToken, 1e-12)
33+
}
34+
35+
func TestBuildModelResponse_OmitsUnknownOptionalMetadata(t *testing.T) {
36+
t.Parallel()
37+
38+
h := &GatewayHandler{}
39+
resp := h.buildModelResponse("custom-unknown-model", "custom-unknown-model", "2024-01-01T00:00:00Z")
40+
41+
require.Nil(t, resp.ContextLength)
42+
require.Nil(t, resp.MaxOutputTokens)
43+
require.Nil(t, resp.Pricing)
44+
}
45+
46+
func TestModels_WhitelistResponseIncludesMetadata(t *testing.T) {
47+
gin.SetMode(gin.TestMode)
48+
w := httptest.NewRecorder()
49+
c, _ := gin.CreateTestContext(w)
50+
c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil)
51+
52+
h := &GatewayHandler{}
53+
h.writeModelsResponse(c, []string{"gpt-5.4"})
54+
55+
require.Equal(t, http.StatusOK, w.Code)
56+
body := w.Body.String()
57+
require.Contains(t, body, `"id":"gpt-5.4"`)
58+
require.Contains(t, body, `"context_length":1050000`)
59+
require.Contains(t, body, `"max_output_tokens":128000`)
60+
require.Contains(t, body, `"pricing"`)
61+
}
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
package modelmetadata
2+
3+
import (
4+
_ "embed"
5+
"encoding/json"
6+
)
7+
8+
//go:embed model_prices_and_context_window.json
9+
var defaultModelMetadataJSON []byte
10+
11+
type LiteLLMRawEntry struct {
12+
InputCostPerToken *float64 `json:"input_cost_per_token"`
13+
OutputCostPerToken *float64 `json:"output_cost_per_token"`
14+
CacheReadInputTokenCost *float64 `json:"cache_read_input_token_cost"`
15+
CacheCreationInputTokenCost *float64 `json:"cache_creation_input_token_cost"`
16+
MaxInputTokens *int `json:"max_input_tokens"`
17+
MaxOutputTokens *int `json:"max_output_tokens"`
18+
MaxTokens *int `json:"max_tokens"`
19+
}
20+
21+
type LiteLLMModelMetadata struct {
22+
InputCostPerToken float64
23+
OutputCostPerToken float64
24+
CacheReadInputTokenCost float64
25+
CacheCreationInputTokenCost float64
26+
MaxInputTokens int
27+
MaxOutputTokens int
28+
MaxTokens int
29+
}
30+
31+
func GetDefaultModelMetadata(modelName string) *LiteLLMModelMetadata {
32+
var rawData map[string]json.RawMessage
33+
if err := json.Unmarshal(defaultModelMetadataJSON, &rawData); err != nil {
34+
return nil
35+
}
36+
entry, ok := rawData[modelName]
37+
if !ok {
38+
return nil
39+
}
40+
var raw LiteLLMRawEntry
41+
if err := json.Unmarshal(entry, &raw); err != nil {
42+
return nil
43+
}
44+
meta := &LiteLLMModelMetadata{}
45+
if raw.InputCostPerToken != nil {
46+
meta.InputCostPerToken = *raw.InputCostPerToken
47+
}
48+
if raw.OutputCostPerToken != nil {
49+
meta.OutputCostPerToken = *raw.OutputCostPerToken
50+
}
51+
if raw.CacheReadInputTokenCost != nil {
52+
meta.CacheReadInputTokenCost = *raw.CacheReadInputTokenCost
53+
}
54+
if raw.CacheCreationInputTokenCost != nil {
55+
meta.CacheCreationInputTokenCost = *raw.CacheCreationInputTokenCost
56+
}
57+
if raw.MaxInputTokens != nil {
58+
meta.MaxInputTokens = *raw.MaxInputTokens
59+
}
60+
if raw.MaxOutputTokens != nil {
61+
meta.MaxOutputTokens = *raw.MaxOutputTokens
62+
}
63+
if raw.MaxTokens != nil {
64+
meta.MaxTokens = *raw.MaxTokens
65+
}
66+
if meta.InputCostPerToken == 0 && meta.OutputCostPerToken == 0 && meta.CacheReadInputTokenCost == 0 && meta.CacheCreationInputTokenCost == 0 && meta.MaxInputTokens == 0 && meta.MaxOutputTokens == 0 && meta.MaxTokens == 0 {
67+
return nil
68+
}
69+
return meta
70+
}

0 commit comments

Comments
 (0)