diff --git a/aigateway/component/openai.go b/aigateway/component/openai.go index 8093942a..61ed2252 100644 --- a/aigateway/component/openai.go +++ b/aigateway/component/openai.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "log/slog" + "slices" "strconv" "strings" "time" @@ -67,6 +68,8 @@ func (m *openaiComponentImpl) GetAvailableModels(c context.Context, userName str externalModels := m.getExternalModels(c) models = append(models, externalModels...) + models = m.enrichModelsWithPrice(c, models) + // Save models to cache asynchronously go func(modelList []types.Model) { if len(modelList) == 0 { @@ -103,56 +106,72 @@ func (m *openaiComponentImpl) ListModels(c context.Context, userName string, req return filterAndPaginateModels(models, req), nil } -func filterAndPaginateModels(models []types.Model, req types.ListModelsReq) types.ModelList { - // Apply fuzzy search filter if model_id is provided - searchQuery := req.ModelID - if searchQuery != "" { - filtered := make([]types.Model, 0, len(models)) - sq := strings.ToLower(searchQuery) - for _, model := range models { - if strings.Contains(strings.ToLower(model.ID), sq) { - filtered = append(filtered, model) - } - } - models = filtered +type modelFilter func(m *types.Model) bool + +func filterByModelID(query string) modelFilter { + return func(m *types.Model) bool { + return strings.Contains(strings.ToLower(m.ID), query) } +} - // Apply public filter if provided and parseable - if req.Public != "" { - if isPublic, err := strconv.ParseBool(req.Public); err == nil { - filtered := make([]types.Model, 0, len(models)) - for _, model := range models { - if model.Public == isPublic { - filtered = append(filtered, model) - } - } - models = filtered +func filterBySource(source string) modelFilter { + return func(m *types.Model) bool { + switch source { + case string(types.ModelSourceCSGHub): + return m.CSGHubModelID != "" + case string(types.ModelSourceExternal): + return m.Provider != "" + default: + return true } } +} - // Apply source filter if provided - if req.Source != "" { - source := strings.ToLower(req.Source) - filtered := make([]types.Model, 0, len(models)) - for _, model := range models { - switch source { - case string(types.ModelSourceCSGHub): - if model.CSGHubModelID != "" { - filtered = append(filtered, model) - } - case string(types.ModelSourceExternal): - if model.Provider != "" { - filtered = append(filtered, model) - } - default: - // Unknown source value, include all - filtered = append(filtered, model) +func filterByTask(task string) modelFilter { + return func(m *types.Model) bool { + modelTasks := strings.FieldsFunc(strings.ToLower(m.Task), func(r rune) bool { + return r == ',' + }) + return slices.Contains(modelTasks, task) + } +} + +func applyFilters(models []types.Model, filters []modelFilter) []types.Model { + if len(filters) == 0 { + return models + } + filtered := make([]types.Model, 0, len(models)) + for i := range models { + m := &models[i] + keep := true + for _, f := range filters { + if !f(m) { + keep = false + break } } - models = filtered + if keep { + filtered = append(filtered, *m) + } } + return filtered +} + +func filterAndPaginateModels(models []types.Model, req types.ListModelsReq) types.ModelList { + var filters []modelFilter + + if searchQuery := strings.ToLower(req.ModelID); searchQuery != "" { + filters = append(filters, filterByModelID(searchQuery)) + } + if source := strings.ToLower(req.Source); source != "" { + filters = append(filters, filterBySource(source)) + } + if task := strings.ToLower(req.Task); task != "" { + filters = append(filters, filterByTask(task)) + } + + models = applyFilters(models, filters) - // Parse pagination parameters (defaults match previous handler behavior) per := 20 page := 1 if req.Per != "" { @@ -170,8 +189,7 @@ func filterAndPaginateModels(models []types.Model, req types.ListModelsReq) type } totalCount := len(models) - offset := (page - 1) * per - startIndex := offset + startIndex := (page - 1) * per if startIndex > totalCount { startIndex = totalCount } @@ -187,18 +205,29 @@ func filterAndPaginateModels(models []types.Model, req types.ListModelsReq) type firstID = &paginated[0].ID lastID = &paginated[len(paginated)-1].ID } - hasMore := endIndex < totalCount return types.ModelList{ Object: "list", Data: paginated, FirstID: firstID, LastID: lastID, - HasMore: hasMore, + HasMore: endIndex < totalCount, TotalCount: totalCount, } } +// providerTypeFromDeployType maps a deploy type integer to the LLM type string (MetaKeyLLMType). +func providerTypeFromDeployType(t int) string { + switch t { + case commontypes.ServerlessType: + return types.ProviderTypeServerless + case commontypes.InferenceType: + return types.ProviderTypeInference + default: + return types.ProviderTypeInference + } +} + func (m *openaiComponentImpl) getCSGHubModels(c context.Context, userID int64) ([]types.Model, error) { runningDeploys, err := m.deployStore.RunningVisibleToUser(c, userID) if err != nil { @@ -212,11 +241,6 @@ func (m *openaiComponentImpl) getCSGHubModels(c context.Context, userID int64) ( } // Check if engine_args contains tool-call-parser parameter supportFunctionCall := strings.Contains(deploy.EngineArgs, "tool-call-parser") - // Determine public/private based on deployment type, ownership and secure level. - isPublic := true - if deploy.Type == commontypes.InferenceType && deploy.SecureLevel == commontypes.EndpointPrivate && deploy.UserID == userID { - isPublic = false // private - user's own deployment with private secure level - } repoName := deploy.Repository.Name m := types.Model{ BaseModel: types.BaseModel{ @@ -225,7 +249,9 @@ func (m *openaiComponentImpl) getCSGHubModels(c context.Context, userID int64) ( SupportFunctionCall: supportFunctionCall, Task: string(deploy.Task), DisplayName: repoName, - Public: isPublic, + Metadata: map[string]any{ + types.MetaKeyLLMType: providerTypeFromDeployType(deploy.Type), + }, }, InternalModelInfo: types.InternalModelInfo{ CSGHubModelID: deploy.Repository.Path, @@ -266,6 +292,8 @@ func (m *openaiComponentImpl) getExternalModels(c context.Context) []types.Model search := &commontypes.SearchLLMConfig{} searchType := 16 search.Type = &searchType + enabled := true + search.Enabled = &enabled per := 50 page := 1 @@ -278,15 +306,31 @@ func (m *openaiComponentImpl) getExternalModels(c context.Context) []types.Model } for _, extModel := range extModels { + // Extract tasks from metadata if present + task := "" + if extModel.Metadata != nil { + if tasks, ok := extModel.Metadata[types.MetaKeyTasks].([]any); ok && len(tasks) > 0 { + tasksStrings := make([]string, 0, len(tasks)) + for _, t := range tasks { + if s, ok := t.(string); ok { + tasksStrings = append(tasksStrings, s) + } + } + task = strings.Join(tasksStrings, ",") + } + } + if extModel.Metadata == nil { + extModel.Metadata = map[string]any{} + } + extModel.Metadata[types.MetaKeyLLMType] = types.ProviderTypeExternalLLM m := types.Model{ BaseModel: types.BaseModel{ Object: "model", ID: extModel.ModelName, OwnedBy: extModel.Provider, DisplayName: extModel.DisplayName, - // Metadata is allowed to be nil; JSON will contain `null` for nil maps. - Metadata: extModel.Metadata, - Public: true, // external models are always public + Metadata: extModel.Metadata, + Task: task, }, Endpoint: extModel.ApiEndpoint, ExternalModelInfo: types.ExternalModelInfo{ diff --git a/aigateway/component/openai_ce.go b/aigateway/component/openai_ce.go index b3df1dfe..160fc70f 100644 --- a/aigateway/component/openai_ce.go +++ b/aigateway/component/openai_ce.go @@ -48,3 +48,7 @@ func parseScene(sceneValue string) common_types.SceneType { func (e *extendOpenai) CheckBalance(ctx context.Context, username, userUUID string) error { return nil } + +func (e *extendOpenai) enrichModelsWithPrice(_ context.Context, models []types.Model) []types.Model { + return models +} diff --git a/aigateway/component/openai_ce_test.go b/aigateway/component/openai_ce_test.go index 229f2e3e..0bfcf64c 100644 --- a/aigateway/component/openai_ce_test.go +++ b/aigateway/component/openai_ce_test.go @@ -87,6 +87,8 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { mockLLMConfigStore.EXPECT().Index(mock.Anything, 50, 1, mock.Anything). Return([]*database.LLMConfig{}, 0, nil) + // Must match JSON produced by saveModelsToCache (getCSGHubModels + ForInternalUse): + // DisplayName is Repository.Name; NeedSensitiveCheck is unset (false) on CSGHub models. expectModels := []types.Model{ { BaseModel: types.BaseModel{ @@ -95,8 +97,10 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { Object: "model", Created: deploys[0].CreatedAt.Unix(), Task: "text-generation", - DisplayName: "model1", - Public: true, + DisplayName: deploys[0].Repository.Name, + Metadata: map[string]any{ + types.MetaKeyLLMType: types.ProviderTypeInference, + }, }, Endpoint: "endpoint1", InternalModelInfo: types.InternalModelInfo{ @@ -111,12 +115,15 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { }, { BaseModel: types.BaseModel{ - ID: "hf-model2:svc2", - OwnedBy: "OpenCSG", - Object: "model", - Created: deploys[1].CreatedAt.Unix(), - Task: "text-to-image", - Public: true, + ID: "hf-model2:svc2", + OwnedBy: "OpenCSG", + Object: "model", + Created: deploys[1].CreatedAt.Unix(), + Task: "text-to-image", + DisplayName: deploys[1].Repository.Name, + Metadata: map[string]any{ + types.MetaKeyLLMType: types.ProviderTypeServerless, + }, }, Endpoint: "endpoint2", InternalModelInfo: types.InternalModelInfo{ @@ -147,10 +154,8 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { require.Len(t, models, 2) assert.Equal(t, "model1:svc1", models[0].ID) assert.Equal(t, "publicuser", models[0].OwnedBy) - assert.True(t, models[0].Public) assert.Equal(t, "hf-model2:svc2", models[1].ID) assert.Equal(t, "OpenCSG", models[1].OwnedBy) - assert.True(t, models[1].Public) wg.Wait() }) @@ -211,31 +216,42 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { Object: "model", Created: deploys[0].CreatedAt.Unix(), Task: "text-generation", - DisplayName: "model1", - Public: true, + DisplayName: deploys[0].Repository.Name, + Metadata: map[string]any{ + types.MetaKeyLLMType: types.ProviderTypeInference, + }, }, Endpoint: "endpoint1", InternalModelInfo: types.InternalModelInfo{ - ClusterID: deploys[0].ClusterID, - SvcName: deploys[0].SvcName, - ImageID: deploys[0].ImageID, + CSGHubModelID: deploys[0].Repository.Path, + OwnerUUID: deploys[0].User.UUID, + ClusterID: deploys[0].ClusterID, + SvcName: deploys[0].SvcName, + SvcType: deploys[0].Type, + ImageID: deploys[0].ImageID, }, InternalUse: true, }, { BaseModel: types.BaseModel{ - ID: "hf-model2:svc2", - OwnedBy: "OpenCSG", - Object: "model", - Created: deploys[1].CreatedAt.Unix(), - Task: "text-to-image", - Public: true, + ID: "hf-model2:svc2", + OwnedBy: "OpenCSG", + Object: "model", + Created: deploys[1].CreatedAt.Unix(), + Task: "text-to-image", + DisplayName: deploys[1].Repository.Name, + Metadata: map[string]any{ + types.MetaKeyLLMType: types.ProviderTypeServerless, + }, }, Endpoint: "endpoint2", InternalModelInfo: types.InternalModelInfo{ - ClusterID: deploys[1].ClusterID, - SvcName: deploys[1].SvcName, - ImageID: deploys[1].ImageID, + CSGHubModelID: deploys[1].Repository.Path, + OwnerUUID: deploys[1].User.UUID, + ClusterID: deploys[1].ClusterID, + SvcName: deploys[1].SvcName, + SvcType: deploys[1].Type, + ImageID: deploys[1].ImageID, }, InternalUse: true, }, @@ -314,14 +330,19 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { Object: "model", Created: deploys[0].CreatedAt.Unix(), Task: "text-generation", - DisplayName: "model3", - Public: false, + DisplayName: deploys[0].Repository.Name, + Metadata: map[string]any{ + types.MetaKeyLLMType: types.ProviderTypeInference, + }, }, Endpoint: "endpoint3", InternalModelInfo: types.InternalModelInfo{ - ClusterID: deploys[0].ClusterID, - SvcName: deploys[0].SvcName, - ImageID: deploys[0].ImageID, + CSGHubModelID: deploys[0].Repository.Path, + OwnerUUID: deploys[0].User.UUID, + ClusterID: deploys[0].ClusterID, + SvcName: deploys[0].SvcName, + SvcType: deploys[0].Type, + ImageID: deploys[0].ImageID, }, InternalUse: true, }, @@ -344,7 +365,6 @@ func TestOpenAIComponent_GetAvailableModels(t *testing.T) { assert.NoError(t, err) assert.Len(t, models, 1) assert.Equal(t, "model3:svc3", models[0].ID) - assert.False(t, models[0].Public) wg.Wait() }) @@ -400,14 +420,20 @@ func TestOpenAIComponent_GetModelByID(t *testing.T) { OwnedBy: "testuser", Object: "model", Created: deploys[0].CreatedAt.Unix(), - DisplayName: "model1", - Public: true, + Task: string(deploys[0].Task), + DisplayName: deploys[0].Repository.Name, + Metadata: map[string]any{ + types.MetaKeyLLMType: types.ProviderTypeInference, + }, }, Endpoint: "endpoint1", InternalModelInfo: types.InternalModelInfo{ - ClusterID: deploys[0].ClusterID, - SvcName: deploys[0].SvcName, - ImageID: deploys[0].ImageID, + CSGHubModelID: deploys[0].Repository.Path, + OwnerUUID: deploys[0].User.UUID, + ClusterID: deploys[0].ClusterID, + SvcName: deploys[0].SvcName, + SvcType: deploys[0].Type, + ImageID: deploys[0].ImageID, }, InternalUse: true, }, @@ -480,13 +506,14 @@ func TestOpenAIComponent_GetModelByID(t *testing.T) { deploys[0].CreatedAt = now expectModel := types.Model{ BaseModel: types.BaseModel{ - ID: "model1:svc1", - OwnedBy: "testuser", - Object: "model", - Created: deploys[0].CreatedAt.Unix(), - Task: "text-generation", - DisplayName: "model1", - Public: true, + ID: "model1:svc1", + OwnedBy: "testuser", + Object: "model", + Created: deploys[0].CreatedAt.Unix(), + Task: "text-generation", + Metadata: map[string]any{ + types.MetaKeyLLMType: types.ProviderTypeInference, + }, }, Endpoint: "endpoint1", } @@ -514,8 +541,10 @@ func TestOpenAIComponent_ExtGetAvailableModels_Error(t *testing.T) { modelListCache: mockCache, } searchType := 16 + enabled := true search := &commontypes.SearchLLMConfig{ - Type: &searchType, + Type: &searchType, + Enabled: &enabled, } mockLLMConfigStore.EXPECT().Index(ctx, 50, 1, search). Return(nil, 0, errors.New("test error")).Once() @@ -563,7 +592,9 @@ func TestOpenAIComponent_ExtGetAvailableModels_SinglePage(t *testing.T) { ID: "test-model-1", OwnedBy: "OpenAI", Object: "model", - Public: true, + Metadata: map[string]any{ + types.MetaKeyLLMType: types.ProviderTypeExternalLLM, + }, }, Endpoint: "http://test-endpoint-1.com", ExternalModelInfo: types.ExternalModelInfo{ @@ -584,8 +615,10 @@ func TestOpenAIComponent_ExtGetAvailableModels_SinglePage(t *testing.T) { mockDeployStore.EXPECT().RunningVisibleToUser(mock.Anything, user.ID). Return([]database.Deploy{}, nil) searchType := 16 + enabled := true search := &commontypes.SearchLLMConfig{ - Type: &searchType, + Type: &searchType, + Enabled: &enabled, } mockLLMConfigStore.EXPECT().Index(ctx, 50, 1, search).Return(mockModels, 1, nil) mockCache.EXPECT().HSet(mock.Anything, modelCacheKey, "test-model-1", string(expectJson)). @@ -604,39 +637,3 @@ func TestOpenAIComponent_ExtGetAvailableModels_SinglePage(t *testing.T) { require.Equal(t, "test-model-1", models[0].ID) wg.Wait() } - -func TestParseScene(t *testing.T) { - tests := []struct { - name string - sceneValue string - expected commontypes.SceneType - }{ - { - name: "any scene value returns SceneModelServerless", - sceneValue: commontypes.SceneHeaderCSGHub, - expected: commontypes.SceneModelServerless, - }, - { - name: "empty scene returns SceneModelServerless", - sceneValue: "", - expected: commontypes.SceneModelServerless, - }, - { - name: "agentichub scene returns SceneModelServerless", - sceneValue: commontypes.SceneHeaderAgenticHub, - expected: commontypes.SceneModelServerless, - }, - { - name: "unknown scene returns SceneModelServerless", - sceneValue: "unknown", - expected: commontypes.SceneModelServerless, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := parseScene(tt.sceneValue) - assert.Equal(t, tt.expected, result) - }) - } -} diff --git a/aigateway/component/openai_test.go b/aigateway/component/openai_test.go index 413b35e6..5751cbdc 100644 --- a/aigateway/component/openai_test.go +++ b/aigateway/component/openai_test.go @@ -57,10 +57,10 @@ func TestGetSceneFromSvcType(t *testing.T) { func TestFilterAndPaginateModels(t *testing.T) { models := []types.Model{ - {BaseModel: types.BaseModel{ID: "gpt-4:svc1", Object: "model", OwnedBy: "u1", Public: true}}, - {BaseModel: types.BaseModel{ID: "gpt-3.5:svc2", Object: "model", OwnedBy: "u1", Public: false}}, - {BaseModel: types.BaseModel{ID: "claude:svc3", Object: "model", OwnedBy: "u2", Public: true}}, - {BaseModel: types.BaseModel{ID: "gpt-4o:svc4", Object: "model", OwnedBy: "u3", Public: true}}, + {BaseModel: types.BaseModel{ID: "gpt-4:svc1", Object: "model", OwnedBy: "u1"}}, + {BaseModel: types.BaseModel{ID: "gpt-3.5:svc2", Object: "model", OwnedBy: "u1"}}, + {BaseModel: types.BaseModel{ID: "claude:svc3", Object: "model", OwnedBy: "u2"}}, + {BaseModel: types.BaseModel{ID: "gpt-4o:svc4", Object: "model", OwnedBy: "u3"}}, } t.Run("no filters default pagination", func(t *testing.T) { @@ -81,20 +81,6 @@ func TestFilterAndPaginateModels(t *testing.T) { assert.Len(t, resp.Data, 3) }) - t.Run("public filter parses bool", func(t *testing.T) { - resp := filterAndPaginateModels(models, types.ListModelsReq{Public: "false"}) - assert.Equal(t, 1, resp.TotalCount) - require.Len(t, resp.Data, 1) - assert.False(t, resp.Data[0].Public) - assert.Equal(t, "gpt-3.5:svc2", resp.Data[0].ID) - }) - - t.Run("invalid public filter is ignored", func(t *testing.T) { - resp := filterAndPaginateModels(models, types.ListModelsReq{Public: "notabool"}) - assert.Equal(t, 4, resp.TotalCount) - assert.Len(t, resp.Data, 4) - }) - t.Run("pagination per/page applied after filters", func(t *testing.T) { resp := filterAndPaginateModels(models, types.ListModelsReq{ModelID: "gpt", Per: "2", Page: "2"}) // gpt matches 3 models; page=2 per=2 yields 1 item @@ -106,9 +92,9 @@ func TestFilterAndPaginateModels(t *testing.T) { t.Run("source filter csghub", func(t *testing.T) { modelsWithSource := []types.Model{ - {BaseModel: types.BaseModel{ID: "csghub-model:svc1", Object: "model", OwnedBy: "u1", Public: true}, InternalModelInfo: types.InternalModelInfo{CSGHubModelID: "user/model1"}}, - {BaseModel: types.BaseModel{ID: "external-model", Object: "model", OwnedBy: "openai", Public: true}, ExternalModelInfo: types.ExternalModelInfo{Provider: "openai"}}, - {BaseModel: types.BaseModel{ID: "csghub-model:svc2", Object: "model", OwnedBy: "u2", Public: false}, InternalModelInfo: types.InternalModelInfo{CSGHubModelID: "org/model2"}}, + {BaseModel: types.BaseModel{ID: "csghub-model:svc1", Object: "model", OwnedBy: "u1"}, InternalModelInfo: types.InternalModelInfo{CSGHubModelID: "user/model1"}}, + {BaseModel: types.BaseModel{ID: "external-model", Object: "model", OwnedBy: "openai"}, ExternalModelInfo: types.ExternalModelInfo{Provider: "openai"}}, + {BaseModel: types.BaseModel{ID: "csghub-model:svc2", Object: "model", OwnedBy: "u2"}, InternalModelInfo: types.InternalModelInfo{CSGHubModelID: "org/model2"}}, } resp := filterAndPaginateModels(modelsWithSource, types.ListModelsReq{Source: string(types.ModelSourceCSGHub)}) assert.Equal(t, 2, resp.TotalCount) @@ -119,9 +105,9 @@ func TestFilterAndPaginateModels(t *testing.T) { t.Run("source filter external", func(t *testing.T) { modelsWithSource := []types.Model{ - {BaseModel: types.BaseModel{ID: "csghub-model:svc1", Object: "model", OwnedBy: "u1", Public: true}, InternalModelInfo: types.InternalModelInfo{CSGHubModelID: "user/model1"}}, - {BaseModel: types.BaseModel{ID: "gpt-4", Object: "model", OwnedBy: "openai", Public: true}, ExternalModelInfo: types.ExternalModelInfo{Provider: "openai"}}, - {BaseModel: types.BaseModel{ID: "claude", Object: "model", OwnedBy: "anthropic", Public: true}, ExternalModelInfo: types.ExternalModelInfo{Provider: "anthropic"}}, + {BaseModel: types.BaseModel{ID: "csghub-model:svc1", Object: "model", OwnedBy: "u1"}, InternalModelInfo: types.InternalModelInfo{CSGHubModelID: "user/model1"}}, + {BaseModel: types.BaseModel{ID: "gpt-4", Object: "model", OwnedBy: "openai"}, ExternalModelInfo: types.ExternalModelInfo{Provider: "openai"}}, + {BaseModel: types.BaseModel{ID: "claude", Object: "model", OwnedBy: "anthropic"}, ExternalModelInfo: types.ExternalModelInfo{Provider: "anthropic"}}, } resp := filterAndPaginateModels(modelsWithSource, types.ListModelsReq{Source: string(types.ModelSourceExternal)}) assert.Equal(t, 2, resp.TotalCount) @@ -132,8 +118,8 @@ func TestFilterAndPaginateModels(t *testing.T) { t.Run("source filter is case-insensitive", func(t *testing.T) { modelsWithSource := []types.Model{ - {BaseModel: types.BaseModel{ID: "csghub-model:svc1", Object: "model", OwnedBy: "u1", Public: true}, InternalModelInfo: types.InternalModelInfo{CSGHubModelID: "user/model1"}}, - {BaseModel: types.BaseModel{ID: "gpt-4", Object: "model", OwnedBy: "openai", Public: true}, ExternalModelInfo: types.ExternalModelInfo{Provider: "openai"}}, + {BaseModel: types.BaseModel{ID: "csghub-model:svc1", Object: "model", OwnedBy: "u1"}, InternalModelInfo: types.InternalModelInfo{CSGHubModelID: "user/model1"}}, + {BaseModel: types.BaseModel{ID: "gpt-4", Object: "model", OwnedBy: "openai"}, ExternalModelInfo: types.ExternalModelInfo{Provider: "openai"}}, } resp := filterAndPaginateModels(modelsWithSource, types.ListModelsReq{Source: "CSGHub"}) assert.Equal(t, 1, resp.TotalCount) @@ -143,24 +129,110 @@ func TestFilterAndPaginateModels(t *testing.T) { t.Run("unknown source filter includes all", func(t *testing.T) { modelsWithSource := []types.Model{ - {BaseModel: types.BaseModel{ID: "csghub-model:svc1", Object: "model", OwnedBy: "u1", Public: true}, InternalModelInfo: types.InternalModelInfo{CSGHubModelID: "user/model1"}}, - {BaseModel: types.BaseModel{ID: "gpt-4", Object: "model", OwnedBy: "openai", Public: true}, ExternalModelInfo: types.ExternalModelInfo{Provider: "openai"}}, + {BaseModel: types.BaseModel{ID: "csghub-model:svc1", Object: "model", OwnedBy: "u1"}, InternalModelInfo: types.InternalModelInfo{CSGHubModelID: "user/model1"}}, + {BaseModel: types.BaseModel{ID: "gpt-4", Object: "model", OwnedBy: "openai"}, ExternalModelInfo: types.ExternalModelInfo{Provider: "openai"}}, } resp := filterAndPaginateModels(modelsWithSource, types.ListModelsReq{Source: "unknown"}) assert.Equal(t, 2, resp.TotalCount) assert.Len(t, resp.Data, 2) }) - t.Run("source filter combined with public filter", func(t *testing.T) { + t.Run("source filter csghub includes public and private deployments", func(t *testing.T) { modelsWithSource := []types.Model{ - {BaseModel: types.BaseModel{ID: "csghub-public", Object: "model", OwnedBy: "u1", Public: true}, InternalModelInfo: types.InternalModelInfo{CSGHubModelID: "user/model1"}}, - {BaseModel: types.BaseModel{ID: "csghub-private", Object: "model", OwnedBy: "u1", Public: false}, InternalModelInfo: types.InternalModelInfo{CSGHubModelID: "user/model2"}}, - {BaseModel: types.BaseModel{ID: "external-public", Object: "model", OwnedBy: "openai", Public: true}, ExternalModelInfo: types.ExternalModelInfo{Provider: "openai"}}, + {BaseModel: types.BaseModel{ID: "csghub-public", Object: "model", OwnedBy: "u1"}, InternalModelInfo: types.InternalModelInfo{CSGHubModelID: "user/model1"}}, + {BaseModel: types.BaseModel{ID: "csghub-private", Object: "model", OwnedBy: "u1"}, InternalModelInfo: types.InternalModelInfo{CSGHubModelID: "user/model2"}}, + {BaseModel: types.BaseModel{ID: "external-public", Object: "model", OwnedBy: "openai"}, ExternalModelInfo: types.ExternalModelInfo{Provider: "openai"}}, + } + resp := filterAndPaginateModels(modelsWithSource, types.ListModelsReq{Source: string(types.ModelSourceCSGHub)}) + assert.Equal(t, 2, resp.TotalCount) + assert.Len(t, resp.Data, 2) + assert.Equal(t, "csghub-public", resp.Data[0].ID) + assert.Equal(t, "csghub-private", resp.Data[1].ID) + }) + + t.Run("task filter text-generation", func(t *testing.T) { + modelsWithTask := []types.Model{ + {BaseModel: types.BaseModel{ID: "model-1", Object: "model", OwnedBy: "u1", Task: "text-generation"}}, + {BaseModel: types.BaseModel{ID: "model-2", Object: "model", OwnedBy: "u1", Task: "text-to-image"}}, + {BaseModel: types.BaseModel{ID: "model-3", Object: "model", OwnedBy: "u2", Task: "text-generation"}}, + } + resp := filterAndPaginateModels(modelsWithTask, types.ListModelsReq{Task: "text-generation"}) + assert.Equal(t, 2, resp.TotalCount) + assert.Len(t, resp.Data, 2) + assert.Equal(t, "model-1", resp.Data[0].ID) + assert.Equal(t, "model-3", resp.Data[1].ID) + }) + + t.Run("task filter text-to-image", func(t *testing.T) { + modelsWithTask := []types.Model{ + {BaseModel: types.BaseModel{ID: "model-1", Object: "model", OwnedBy: "u1", Task: "text-generation"}}, + {BaseModel: types.BaseModel{ID: "model-2", Object: "model", OwnedBy: "u1", Task: "text-to-image"}}, + {BaseModel: types.BaseModel{ID: "model-3", Object: "model", OwnedBy: "u2", Task: "text-generation"}}, } - resp := filterAndPaginateModels(modelsWithSource, types.ListModelsReq{Source: string(types.ModelSourceCSGHub), Public: "true"}) + resp := filterAndPaginateModels(modelsWithTask, types.ListModelsReq{Task: "text-to-image"}) assert.Equal(t, 1, resp.TotalCount) assert.Len(t, resp.Data, 1) - assert.Equal(t, "csghub-public", resp.Data[0].ID) + assert.Equal(t, "model-2", resp.Data[0].ID) + }) + + t.Run("task filter is case-insensitive", func(t *testing.T) { + modelsWithTask := []types.Model{ + {BaseModel: types.BaseModel{ID: "model-1", Object: "model", OwnedBy: "u1", Task: "Text-Generation"}}, + {BaseModel: types.BaseModel{ID: "model-2", Object: "model", OwnedBy: "u1", Task: "TEXT-TO-IMAGE"}}, + } + resp := filterAndPaginateModels(modelsWithTask, types.ListModelsReq{Task: "text-generation"}) + assert.Equal(t, 1, resp.TotalCount) + assert.Len(t, resp.Data, 1) + assert.Equal(t, "model-1", resp.Data[0].ID) + }) + + t.Run("task filter with no matches", func(t *testing.T) { + modelsWithTask := []types.Model{ + {BaseModel: types.BaseModel{ID: "model-1", Object: "model", OwnedBy: "u1", Task: "text-generation"}}, + {BaseModel: types.BaseModel{ID: "model-2", Object: "model", OwnedBy: "u1", Task: "text-to-image"}}, + } + resp := filterAndPaginateModels(modelsWithTask, types.ListModelsReq{Task: "non-existent-task"}) + assert.Equal(t, 0, resp.TotalCount) + assert.Len(t, resp.Data, 0) + }) + + t.Run("task filter matches any comma-separated task on model", func(t *testing.T) { + modelsWithTask := []types.Model{ + {BaseModel: types.BaseModel{ID: "multi", Object: "model", OwnedBy: "u1", Task: "text-generation,text-to-image,summarization"}}, + {BaseModel: types.BaseModel{ID: "single", Object: "model", OwnedBy: "u1", Task: "text-to-image"}}, + {BaseModel: types.BaseModel{ID: "other", Object: "model", OwnedBy: "u1", Task: "embedding"}}, + } + resp := filterAndPaginateModels(modelsWithTask, types.ListModelsReq{Task: "text-generation"}) + assert.Equal(t, 1, resp.TotalCount) + require.Len(t, resp.Data, 1) + assert.Equal(t, "multi", resp.Data[0].ID) + + resp = filterAndPaginateModels(modelsWithTask, types.ListModelsReq{Task: "summarization"}) + assert.Equal(t, 1, resp.TotalCount) + require.Len(t, resp.Data, 1) + assert.Equal(t, "multi", resp.Data[0].ID) + }) + + t.Run("task filter collapses repeated commas in model Task", func(t *testing.T) { + modelsWithTask := []types.Model{ + {BaseModel: types.BaseModel{ID: "sparse", Object: "model", OwnedBy: "u1", Task: "text-generation,,text-to-image"}}, + } + resp := filterAndPaginateModels(modelsWithTask, types.ListModelsReq{Task: "text-generation"}) + assert.Equal(t, 1, resp.TotalCount) + resp = filterAndPaginateModels(modelsWithTask, types.ListModelsReq{Task: "text-to-image"}) + assert.Equal(t, 1, resp.TotalCount) + }) + + t.Run("task filter combined with source filter", func(t *testing.T) { + modelsWithTask := []types.Model{ + {BaseModel: types.BaseModel{ID: "csghub-gen", Object: "model", OwnedBy: "u1", Task: "text-generation"}, InternalModelInfo: types.InternalModelInfo{CSGHubModelID: "user/model1"}}, + {BaseModel: types.BaseModel{ID: "csghub-image", Object: "model", OwnedBy: "u1", Task: "text-to-image"}, InternalModelInfo: types.InternalModelInfo{CSGHubModelID: "user/model2"}}, + {BaseModel: types.BaseModel{ID: "external-gen", Object: "model", OwnedBy: "openai", Task: "text-generation"}, ExternalModelInfo: types.ExternalModelInfo{Provider: "openai"}}, + } + resp := filterAndPaginateModels(modelsWithTask, types.ListModelsReq{Source: string(types.ModelSourceCSGHub), Task: "text-generation"}) + assert.Equal(t, 1, resp.TotalCount) + assert.Len(t, resp.Data, 1) + assert.Equal(t, "csghub-gen", resp.Data[0].ID) }) } diff --git a/aigateway/handler/openai.go b/aigateway/handler/openai.go index 5bf9fd4b..e6fa4fa2 100644 --- a/aigateway/handler/openai.go +++ b/aigateway/handler/openai.go @@ -140,13 +140,13 @@ type OpenAIHandlerImpl struct { // ListModels godoc // @Summary List available models -// @Description Returns a list of available models, supports fuzzy search by model_id query parameter and filtering by public status +// @Description Returns a list of available models, supports fuzzy search by model_id query parameter and filtering by source and task // @Tags AIGateway // @Accept json // @Produce json // @Param model_id query string false "Model ID for fuzzy search" -// @Param public query bool false "Filter by public status (true for public models, false for private models)" // @Param source query string false "Filter by source (csghub for CSGHub models, external for external models)" Enums(csghub, external) +// @Param task query string false "Filter by task (e.g., text-generation, text-to-image)" // @Param per query int false "Models per page (default 20, max 100)" // @Param page query int false "Page number (1-based, default 1)" // @Success 200 {object} types.ModelList "OK" @@ -173,8 +173,8 @@ func (h *OpenAIHandlerImpl) ListModels(c *gin.Context) { resp, err := h.openaiComponent.ListModels(c.Request.Context(), currentUser, types.ListModelsReq{ ModelID: c.Query("model_id"), - Public: c.Query("public"), Source: source, + Task: c.Query("task"), Per: c.Query("per"), Page: c.Query("page"), }) @@ -206,6 +206,7 @@ func (h *OpenAIHandlerImpl) ListModels(c *gin.Context) { func (h *OpenAIHandlerImpl) GetModel(c *gin.Context) { username := httpbase.GetCurrentUser(c) modelID := c.Param("model") + modelID = strings.TrimPrefix(modelID, "/") if modelID == "" { c.JSON(http.StatusBadRequest, gin.H{ "error": types.Error{ diff --git a/aigateway/handler/openai_test.go b/aigateway/handler/openai_test.go index 88c1ee94..0847507c 100644 --- a/aigateway/handler/openai_test.go +++ b/aigateway/handler/openai_test.go @@ -75,7 +75,7 @@ func TestOpenAIHandler_ListModels(t *testing.T) { t.Run("successful passthrough", func(t *testing.T) { tester, c, w := setupTest(t) models := []types.Model{ - {BaseModel: types.BaseModel{ID: "model1:svc1", Object: "model", OwnedBy: "testuser", Public: true}}, + {BaseModel: types.BaseModel{ID: "model1:svc1", Object: "model", OwnedBy: "testuser"}}, } expect := types.ModelList{ Object: "list", @@ -102,14 +102,12 @@ func TestOpenAIHandler_ListModels(t *testing.T) { tester, c, w := setupTest(t) tester.WithQuery("model_id", "gpt"). - WithQuery("public", "true"). WithQuery("per", "2"). WithQuery("page", "3") tester.mocks.openAIComp.EXPECT(). ListModels(mock.Anything, "testuser", types.ListModelsReq{ ModelID: "gpt", - Public: "true", Per: "2", Page: "3", }). @@ -120,6 +118,36 @@ func TestOpenAIHandler_ListModels(t *testing.T) { assert.Equal(t, http.StatusOK, w.Code) }) + t.Run("passes task query param to component", func(t *testing.T) { + tester, c, w := setupTest(t) + + tester.WithQuery("task", "text-generation"). + WithQuery("per", "10"). + WithQuery("page", "1") + + tester.mocks.openAIComp.EXPECT(). + ListModels(mock.Anything, "testuser", types.ListModelsReq{ + Task: "text-generation", + Per: "10", + Page: "1", + }). + Return(types.ModelList{ + Object: "list", + Data: []types.Model{{BaseModel: types.BaseModel{ID: "model1", Task: "text-generation"}}}, + HasMore: false, + TotalCount: 1, + }, nil).Once() + + tester.handler.ListModels(c) + + assert.Equal(t, http.StatusOK, w.Code) + var response types.ModelList + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, 1, response.TotalCount) + assert.Equal(t, "text-generation", response.Data[0].Task) + }) + t.Run("component error", func(t *testing.T) { tester, c, w := setupTest(t) tester.mocks.openAIComp.EXPECT(). @@ -199,7 +227,6 @@ func TestOpenAIHandler_ListModels_OpenaiSDK(t *testing.T) { ID: "gpt-4:svc1", Object: "model", OwnedBy: "testuser", - Public: true, }, }, { @@ -207,7 +234,6 @@ func TestOpenAIHandler_ListModels_OpenaiSDK(t *testing.T) { ID: "gpt-3.5-turbo:svc2", Object: "model", OwnedBy: "testuser", - Public: true, }, }, } @@ -301,6 +327,49 @@ func TestOpenAIHandler_GetModel(t *testing.T) { assert.Equal(t, http.StatusBadRequest, w.Code) }) + + t.Run("model with slash in name - trims leading slash", func(t *testing.T) { + tester, c, w := setupTest(t) + model := &types.Model{ + BaseModel: types.BaseModel{ + ID: "xzgan001/gguf_model:fepjlx3v39xc", + Object: "model", + OwnedBy: "testuser", + }, + } + // Wildcard route adds leading slash + c.Params = []gin.Param{{Key: "model", Value: "/xzgan001/gguf_model:fepjlx3v39xc"}} + tester.mocks.openAIComp.EXPECT().GetModelByID(mock.Anything, "testuser", "xzgan001/gguf_model:fepjlx3v39xc").Return(model, nil) + + tester.handler.GetModel(c) + + assert.Equal(t, http.StatusOK, w.Code) + var response types.Model + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, "xzgan001/gguf_model:fepjlx3v39xc", response.ID) + }) + + t.Run("model without leading slash - no trim needed", func(t *testing.T) { + tester, c, w := setupTest(t) + model := &types.Model{ + BaseModel: types.BaseModel{ + ID: "simple-model:svc1", + Object: "model", + OwnedBy: "testuser", + }, + } + c.Params = []gin.Param{{Key: "model", Value: "simple-model:svc1"}} + tester.mocks.openAIComp.EXPECT().GetModelByID(mock.Anything, "testuser", "simple-model:svc1").Return(model, nil) + + tester.handler.GetModel(c) + + assert.Equal(t, http.StatusOK, w.Code) + var response types.Model + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, "simple-model:svc1", response.ID) + }) } func TestOpenAIHandler_Chat(t *testing.T) { diff --git a/aigateway/router/aigateway.go b/aigateway/router/aigateway.go index c20ed51b..148ccd71 100644 --- a/aigateway/router/aigateway.go +++ b/aigateway/router/aigateway.go @@ -2,7 +2,6 @@ package router import ( "fmt" - "net/http" "opencsg.com/csghub-server/builder/instrumentation" @@ -37,8 +36,6 @@ func NewRouter(config *config.Config) (*gin.Engine, func(), error) { r.Use(middleware.BuildJwtSession(config.JWT.SigningKey)) i18n.InitLocalizersFromEmbedFile() r.Use(middleware.ModifyAcceptLanguageMiddleware(), middleware.LocalizedErrorMiddleware()) - i18n.InitLocalizersFromEmbedFile() - r.Use(middleware.ModifyAcceptLanguageMiddleware(), middleware.LocalizedErrorMiddleware()) r.Use(middleware.Authenticator(config)) middlewareCollection := middleware.MiddlewareCollection{} middlewareCollection.Auth.NeedLogin = middleware.MustLogin() @@ -54,7 +51,7 @@ func NewRouter(config *config.Config) (*gin.Engine, func(), error) { return nil, nil, fmt.Errorf("error creating openai handler :%w", err) } v1Group.GET("/models", openAIhandler.ListModels) - v1Group.GET("/models/:model", middlewareCollection.Auth.NeedLogin, openAIhandler.GetModel) + v1Group.GET("/models/*model", middlewareCollection.Auth.NeedLogin, openAIhandler.GetModel) v1Group.POST("/chat/completions", middlewareCollection.Auth.NeedLogin, openAIhandler.Chat) v1Group.POST("/embeddings", middlewareCollection.Auth.NeedLogin, openAIhandler.Embedding) v1Group.POST("/images/generations", middlewareCollection.Auth.NeedLogin, openAIhandler.GenerateImage) diff --git a/aigateway/types/openai.go b/aigateway/types/openai.go index adda10ef..d7f36ca9 100644 --- a/aigateway/types/openai.go +++ b/aigateway/types/openai.go @@ -6,6 +6,20 @@ import ( "opencsg.com/csghub-server/common/types" ) +// Provider type values for Metadata[MetaKeyLLMType]. +const ( + ProviderTypeServerless = "serverless" + ProviderTypeInference = "inference" + ProviderTypeExternalLLM = "external_llm" +) + +// Metadata key constants used when enriching model metadata. +const ( + MetaKeyLLMType = "llm_type" + MetaKeyPricing = "pricing" + MetaKeyTasks = "tasks" +) + // BaseModel represents the base model fields type BaseModel struct { ID string `json:"id"` @@ -16,7 +30,7 @@ type BaseModel struct { DisplayName string `json:"display_name"` SupportFunctionCall bool `json:"support_function_call,omitempty"` // whether the model supports function calling IsPinned *bool `json:"is_pinned,omitempty"` // whether the model is pinned - Public bool `json:"public"` // whether the model is public (false = private, true = public) + Public bool `json:"public,omitempty"` Metadata map[string]any `json:"metadata"` } @@ -61,7 +75,7 @@ func (m Model) MarshalJSON() ([]byte, error) { Task string `json:"task"` DisplayName string `json:"display_name"` SupportFunctionCall *bool `json:"support_function_call,omitempty"` - Public bool `json:"public"` + Public bool `json:"public,omitempty"` Endpoint string `json:"endpoint"` Metadata map[string]any `json:"metadata"` ClusterID *string `json:"cluster_id,omitempty"` @@ -119,7 +133,7 @@ func (m *Model) UnmarshalJSON(data []byte) error { Task string `json:"task"` DisplayName string `json:"display_name"` SupportFunctionCall bool `json:"support_function_call,omitempty"` - Public bool `json:"public"` + Public bool `json:"public,omitempty"` Endpoint string `json:"endpoint"` Metadata map[string]any `json:"metadata"` ClusterID string `json:"cluster_id,omitempty"` @@ -139,8 +153,8 @@ func (m *Model) UnmarshalJSON(data []byte) error { m.OwnedBy = aux.OwnedBy m.Task = aux.Task m.DisplayName = aux.DisplayName - m.SupportFunctionCall = aux.SupportFunctionCall m.Public = aux.Public + m.SupportFunctionCall = aux.SupportFunctionCall m.Endpoint = aux.Endpoint m.Metadata = aux.Metadata m.ClusterID = aux.ClusterID @@ -193,10 +207,10 @@ type ModelList struct { // filtering, and pagination behavior consistently. type ListModelsReq struct { ModelID string `json:"model_id"` - Public string `json:"public"` Per string `json:"per"` Page string `json:"page"` Source string `json:"source"` // filter by source (csghub for CSGHub models, external for external models) + Task string `json:"task"` // filter by task } // UserPreferenceRequest defines the request parameters for UserPreference method @@ -223,3 +237,16 @@ const ( // ModelSourceExternal represents models from external providers ModelSourceExternal ModelSource = "external" ) + +// ModelTokenPrice is currency plus per-million-token rate (major units, from accounting cents + sku_unit). +type ModelTokenPrice struct { + Currency string `json:"currency,omitempty"` + PricePerMillion float64 `json:"price_per_million,omitempty"` +} + +// ModelScenePrice is Metadata["pricing"]: serverless sets input/output token prices; external_llm sets token_price. +type ModelScenePrice struct { + InputTokenPrice *ModelTokenPrice `json:"input_token_price,omitempty"` + OutputTokenPrice *ModelTokenPrice `json:"output_token_price,omitempty"` + TokenPrice *ModelTokenPrice `json:"token_price,omitempty"` +} diff --git a/aigateway/types/openai_test.go b/aigateway/types/openai_test.go index 8fcae8a9..62d3b5eb 100644 --- a/aigateway/types/openai_test.go +++ b/aigateway/types/openai_test.go @@ -18,7 +18,6 @@ func TestModelSerialization(t *testing.T) { Task: "text-generation", SupportFunctionCall: true, - Public: true, }, InternalModelInfo: InternalModelInfo{ CSGHubModelID: "test/repo/path", @@ -48,8 +47,8 @@ func TestModelSerialization(t *testing.T) { t.Errorf("External response should not contain sensitive fields, got: %s", jsonStr) } - if !contains(jsonStr, "test-model") || !contains(jsonStr, "model") || !contains(jsonStr, "test-owner") || !contains(jsonStr, "public") { - t.Errorf("External response should contain BaseModel fields including public, got: %s", jsonStr) + if !contains(jsonStr, "test-model") || !contains(jsonStr, "model") || !contains(jsonStr, "test-owner") { + t.Errorf("External response should contain BaseModel fields, got: %s", jsonStr) } }) @@ -61,8 +60,8 @@ func TestModelSerialization(t *testing.T) { t.Fatalf("Failed to marshal model in internal use mode: %v", err) } jsonStr := string(jsonData) - if !contains(jsonStr, "endpoint") || !contains(jsonStr, "http://test-endpoint.com") || !contains(jsonStr, "test-model") || !contains(jsonStr, "public") { - t.Errorf("Internal response should contain base fields including public, got: %s", jsonStr) + if !contains(jsonStr, "endpoint") || !contains(jsonStr, "http://test-endpoint.com") || !contains(jsonStr, "test-model") { + t.Errorf("Internal response should contain base fields, got: %s", jsonStr) } if contains(jsonStr, "internal_model_info") { @@ -109,7 +108,6 @@ func TestModelListSerialization(t *testing.T) { BaseModel: BaseModel{ ID: "model-1", Object: "model", - Public: true, }, Endpoint: "http://model-1.com", InternalUse: false, @@ -152,7 +150,6 @@ func TestModelUnmarshal(t *testing.T) { "owned_by": "test-owner", "task": "text-generation", "support_function_call": true, - "public": true, "endpoint": "http://model-1.com", "internal_use": false } diff --git a/builder/accounting/client.go b/builder/accounting/client.go index ed28d135..002aa6ed 100644 --- a/builder/accounting/client.go +++ b/builder/accounting/client.go @@ -62,7 +62,6 @@ func (ac *accountingClientImpl) ListMeteringsByUserIDAndTime(req types.ActStatem // Helper method to execute the actual HTTP request and read the response. func (ac *accountingClientImpl) doRequest(method, subPath string, data any) (*http.Response, error) { urlPath := fmt.Sprintf("%s%s%s", ac.remote, "/api/v1/accounting", subPath) - // slog.Info("call", slog.Any("urlPath", urlPath)) var buf io.Reader if data != nil { jsonData, err := json.Marshal(data) diff --git a/builder/store/database/llm_config.go b/builder/store/database/llm_config.go index 6489924f..099dc8a4 100644 --- a/builder/store/database/llm_config.go +++ b/builder/store/database/llm_config.go @@ -165,5 +165,8 @@ func buildSearchLLMConfigQuery( if search.Type != nil { q.Where("llm_config.type = ?", *search.Type) } - + // Filter by Enabled if provided + if search.Enabled != nil { + q.Where("llm_config.enabled = ?", *search.Enabled) + } } diff --git a/builder/store/database/llm_config_test.go b/builder/store/database/llm_config_test.go index 7aeaf3a4..989b4149 100644 --- a/builder/store/database/llm_config_test.go +++ b/builder/store/database/llm_config_test.go @@ -124,14 +124,14 @@ func TestLLMConfigStore_CRUD(t *testing.T) { Enabled: true, ModelName: "summary1", DisplayName: "summary1", - Metadata: map[string]any{"k": "v"}, + Metadata: map[string]any{"k": "v", "tasks": []interface{}{"text-generation", "text-to-image"}}, } res, err := store.Create(ctx, dbInput) require.Nil(t, err) require.NotNil(t, res) require.Equal(t, "summary1", res.ModelName) require.Equal(t, "summary1", res.DisplayName) - require.Equal(t, map[string]any{"k": "v"}, res.Metadata) + require.Equal(t, map[string]any{"k": "v", "tasks": []interface{}{"text-generation", "text-to-image"}}, res.Metadata) searchType := 5 search := &types.SearchLLMConfig{ @@ -232,3 +232,79 @@ func TestLLMConfigStore_Search(t *testing.T) { } require.True(t, found, "Should find gpt-4 when searching for gpt-4") } + +func TestLLMConfigStore_Index_EnabledFilter(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + config, err := config.LoadConfig() + require.Nil(t, err) + store := database.NewLLMConfigStoreWithDB(db, config) + + searchType := 16 + base := database.LLMConfig{ + Type: searchType, + ApiEndpoint: "https://example.test/v1", + AuthHeader: "{}", + Provider: "test", + } + _, err = store.Create(ctx, database.LLMConfig{ + ModelName: "idx-en-on", + DisplayName: "idx-en-on", + Enabled: true, + Type: base.Type, + ApiEndpoint: base.ApiEndpoint, + AuthHeader: base.AuthHeader, + Provider: base.Provider, + }) + require.Nil(t, err) + _, err = store.Create(ctx, database.LLMConfig{ + ModelName: "idx-en-off", + DisplayName: "idx-en-off", + Enabled: false, + Type: base.Type, + ApiEndpoint: base.ApiEndpoint, + AuthHeader: base.AuthHeader, + Provider: base.Provider, + }) + require.Nil(t, err) + + enabledTrue := true + enabledFalse := false + + cfgsOn, totalOn, err := store.Index(ctx, 20, 1, &types.SearchLLMConfig{ + Type: &searchType, + Enabled: &enabledTrue, + }) + require.Nil(t, err) + require.Equal(t, 1, totalOn) + require.Len(t, cfgsOn, 1) + require.Equal(t, "idx-en-on", cfgsOn[0].ModelName) + require.True(t, cfgsOn[0].Enabled) + + cfgsOff, totalOff, err := store.Index(ctx, 20, 1, &types.SearchLLMConfig{ + Type: &searchType, + Enabled: &enabledFalse, + }) + require.Nil(t, err) + require.Equal(t, 1, totalOff) + require.Len(t, cfgsOff, 1) + require.Equal(t, "idx-en-off", cfgsOff[0].ModelName) + require.False(t, cfgsOff[0].Enabled) + + cfgsBoth, totalBoth, err := store.Index(ctx, 20, 1, &types.SearchLLMConfig{ + Type: &searchType, + }) + require.Nil(t, err) + require.Equal(t, 2, totalBoth) + require.Len(t, cfgsBoth, 2) + + cfgsKeyword, totalKeyword, err := store.Index(ctx, 20, 1, &types.SearchLLMConfig{ + Keyword: "idx-en-", + Enabled: &enabledTrue, + }) + require.Nil(t, err) + require.Equal(t, 1, totalKeyword) + require.Len(t, cfgsKeyword, 1) + require.Equal(t, "idx-en-on", cfgsKeyword[0].ModelName) +} diff --git a/common/types/llm_service.go b/common/types/llm_service.go index 09614f55..e42d3a35 100644 --- a/common/types/llm_service.go +++ b/common/types/llm_service.go @@ -11,7 +11,7 @@ type LLMConfig struct { Type int `json:"type"` // 1: optimization, 2: comparison, 4: summary readme Enabled bool `json:"enabled"` Provider string `json:"provider"` - Metadata map[string]any `json:"metadata"` + Metadata map[string]any `json:"metadata"` // tasks stored as: {"tasks": ["text-generation", "text-to-image"]} CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } @@ -26,6 +26,7 @@ type PromptPrefix struct { type SearchLLMConfig struct { Keyword string `json:"keyword"` // Search keyword Type *int `json:"type"` // Type of search + Enabled *bool `json:"enabled"` // Enabled filter } type SearchPromptPrefix struct { @@ -42,7 +43,7 @@ type UpdateLLMConfigReq struct { Type *int `json:"type"` // 1: optimization, 2: comparison, 4: summary readme Enabled *bool `json:"enabled"` Provider *string `json:"provider"` - Metadata *map[string]any `json:"metadata"` + Metadata *map[string]any `json:"metadata"` // tasks stored as: {"tasks": ["text-generation", "text-to-image"]} } type UpdatePromptPrefixReq struct { @@ -53,15 +54,16 @@ type UpdatePromptPrefixReq struct { } type CreateLLMConfigReq struct { - ModelName string `json:"model_name"` + ModelName string `json:"model_name" binding:"required"` DisplayName string `json:"display_name"` - ApiEndpoint string `json:"api_endpoint"` + ApiEndpoint string `json:"api_endpoint" binding:"required"` AuthHeader string `json:"auth_header"` - Type int `json:"type"` // 1: optimization, 2: comparison, 4: summary readme - Provider string `json:"provider"` + Type int `json:"type" binding:"required,oneof=1 2 4 8 16"` // 1: optimization, 2: comparison, 4: summary readme, 8: mcp scan, 16: for aigateway call external llm + Provider string `json:"provider" binding:"required"` Enabled bool `json:"enabled"` - Metadata map[string]any `json:"metadata"` + Metadata map[string]any `json:"metadata"` // tasks stored as: {"tasks": ["text-generation", "text-to-image"]} } + type CreatePromptPrefixReq struct { ZH string `json:"zh"` EN string `json:"en"` diff --git a/component/llm_service_test.go b/component/llm_service_test.go index e8a89814..d196b156 100644 --- a/component/llm_service_test.go +++ b/component/llm_service_test.go @@ -21,23 +21,29 @@ func TestLLMServiceComponent_CreateLLMConfig(t *testing.T) { ModelName: "new-model", ApiEndpoint: "http://new.endpoint", AuthHeader: "Bearer token", - Type: 666, + Type: 16, Enabled: true, + Provider: "test-provider", + Metadata: map[string]any{"tasks": []any{"text-generation"}}, } dbLLMConfig := &database.LLMConfig{ ID: 123, ModelName: "new-model", ApiEndpoint: "http://new.endpoint", AuthHeader: "Bearer token", - Type: 666, + Type: 16, Enabled: true, + Provider: "test-provider", + Metadata: map[string]any{"tasks": []any{"text-generation"}}, } stores.LLMConfigMock().EXPECT().Create(ctx, database.LLMConfig{ ModelName: "new-model", ApiEndpoint: "http://new.endpoint", AuthHeader: "Bearer token", - Type: 666, + Type: 16, Enabled: true, + Provider: "test-provider", + Metadata: map[string]any{"tasks": []any{"text-generation"}}, }).Return(dbLLMConfig, nil) res, err := mc.CreateLLMConfig(ctx, req) require.Nil(t, err) @@ -132,18 +138,22 @@ func TestLLMServiceComponent_UpdateLLMConfig(t *testing.T) { promptPrefixStore: stores.PromptPrefix, } newName := "new-model" + metadata := map[string]any{"tasks": []any{"text-to-image"}} req := &types.UpdateLLMConfigReq{ - ID: 123, + ID: 123, ModelName: &newName, + Metadata: &metadata, } dbLLMConfig := &database.LLMConfig{ ID: 123, ModelName: newName, + Metadata: metadata, } stores.LLMConfigMock().EXPECT().GetByID(ctx, int64(123)).Return(dbLLMConfig, nil) stores.LLMConfigMock().EXPECT().Update(ctx, database.LLMConfig{ ID: 123, ModelName: newName, + Metadata: metadata, }).Return(dbLLMConfig, nil) res, err := mc.UpdateLLMConfig(ctx, req) require.Nil(t, err)