Skip to content

Commit 6eaa946

Browse files
committed
Fix: Failure to update local model configuration to gateway
1 parent dcf1870 commit 6eaa946

8 files changed

Lines changed: 259 additions & 38 deletions

File tree

controller/ai-local/iml.go

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,15 @@ import (
77
"io"
88
"math"
99
"net/http"
10+
"net/url"
1011
"strings"
12+
"time"
1113

12-
system_dto "github.com/APIParkLab/APIPark/module/system/dto"
14+
ai_balance "github.com/APIParkLab/APIPark/module/ai-balance"
1315

1416
"github.com/APIParkLab/APIPark/module/system"
17+
system_dto "github.com/APIParkLab/APIPark/module/system/dto"
18+
ollama_api "github.com/ollama/ollama/api"
1519

1620
"github.com/APIParkLab/APIPark/module/subscribe"
1721
subscribe_dto "github.com/APIParkLab/APIPark/module/subscribe/dto"
@@ -51,6 +55,7 @@ type imlLocalModelController struct {
5155
serviceModule service.IServiceModule `autowired:""`
5256
catalogueModule catalogue.ICatalogueModule `autowired:""`
5357
aiAPIModule ai_api.IAPIModule `autowired:""`
58+
aiBalanceModule ai_balance.IBalanceModule `autowired:""`
5459
appModule service.IAppModule `autowired:""`
5560
routerModule router.IRouterModule `autowired:""`
5661
subscribeModule subscribe.ISubscribeModule `autowired:""`
@@ -66,9 +71,35 @@ func (i *imlLocalModelController) OllamaConfig(ctx *gin.Context) (*ai_local_dto.
6671
}, nil
6772
}
6873

74+
var (
75+
client = &http.Client{
76+
Timeout: 2 * time.Second,
77+
}
78+
)
79+
6980
func (i *imlLocalModelController) OllamaConfigUpdate(ctx *gin.Context, input *ai_local_dto.OllamaConfig) error {
70-
return i.settingModule.Set(ctx, &system_dto.InputSetting{
71-
OllamaAddress: &input.Address,
81+
u, err := url.Parse(input.Address)
82+
if err != nil {
83+
return nil
84+
}
85+
ollamaClient := ollama_api.NewClient(u, client)
86+
_, err = ollamaClient.Version(ctx)
87+
if err != nil {
88+
return err
89+
}
90+
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
91+
err = i.module.SyncLocalModels(ctx, input.Address)
92+
if err != nil {
93+
return err
94+
}
95+
err = i.aiBalanceModule.SyncLocalBalances(ctx, input.Address)
96+
if err != nil {
97+
return err
98+
}
99+
100+
return i.settingModule.Set(ctx, &system_dto.InputSetting{
101+
OllamaAddress: &input.Address,
102+
})
72103
})
73104
}
74105

controller/service/iml.go

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -477,13 +477,13 @@ func (i *imlServiceController) Create(ctx *gin.Context, teamID string, input *se
477477
}
478478
var err error
479479
var info *service_dto.Service
480-
err = i.transaction.Transaction(ctx, func(txCtx context.Context) error {
481-
info, err = i.module.Create(txCtx, teamID, input)
480+
err = i.transaction.Transaction(ctx, func(ctx context.Context) error {
481+
info, err = i.module.Create(ctx, teamID, input)
482482
if err != nil {
483483
return err
484484
}
485485
path := fmt.Sprintf("/%s/", strings.Trim(input.Prefix, "/"))
486-
_, err = i.routerModule.Create(txCtx, info.Id, &router_dto.Create{
486+
_, err = i.routerModule.Create(ctx, info.Id, &router_dto.Create{
487487
Id: uuid.New().String(),
488488
Name: "",
489489
Path: path + "*",
@@ -499,6 +499,15 @@ func (i *imlServiceController) Create(ctx *gin.Context, teamID string, input *se
499499
},
500500
Disable: false,
501501
})
502+
apps, err := i.appModule.Search(ctx, teamID, "")
503+
if err != nil {
504+
return err
505+
}
506+
for _, app := range apps {
507+
i.subscribeModule.AddSubscriber(ctx, info.Id, &subscribe_dto.AddSubscriber{
508+
Application: app.Id,
509+
})
510+
}
502511
return err
503512
})
504513
return info, err

controller/system/iml.go

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import (
1010
"strings"
1111
"time"
1212

13+
subscribe_dto "github.com/APIParkLab/APIPark/module/subscribe/dto"
14+
1315
"github.com/eolinker/eosc/log"
1416

1517
ai_dto "github.com/APIParkLab/APIPark/module/ai/dto"
@@ -222,6 +224,7 @@ type imlInitController struct {
222224
applicationAuthorizationModule application_authorization.IAuthorizationModule `autowired:""`
223225
catalogueModule catalogue.ICatalogueModule `autowired:""`
224226
providerModule ai.IProviderModule `autowired:""`
227+
subscribeModule subscribe.ISubscribeModule `autowired:""`
225228
transaction store.ITransaction `autowired:""`
226229
aiAPIModule ai_api.IAPIModule `autowired:""`
227230
docModule service.IServiceDocModule `autowired:""`
@@ -264,6 +267,13 @@ func (i *imlInitController) OnInit() {
264267
if err != nil {
265268
return fmt.Errorf("create default team error: %v", err)
266269
}
270+
app, err := i.appModule.CreateApp(ctx, info.Id, &service_dto.CreateApp{
271+
Name: "Demo Application",
272+
Description: "Auto created By APIPark",
273+
})
274+
if err != nil {
275+
return fmt.Errorf("create default app error: %v", err)
276+
}
267277
// 创建Rest服务
268278
restPath := "/rest-demo"
269279
serviceInfo, err := i.serviceModule.Create(ctx, info.Id, &service_dto.CreateService{
@@ -298,6 +308,13 @@ func (i *imlInitController) OnInit() {
298308
if err != nil {
299309
return fmt.Errorf("create default router error: %v", err)
300310
}
311+
err = i.subscribeModule.AddSubscriber(ctx, serviceInfo.Id, &subscribe_dto.AddSubscriber{
312+
Application: app.Id,
313+
})
314+
if err != nil {
315+
return err
316+
}
317+
301318
// 创建AI服务
302319
err = i.createAIService(ctx, info.Id, &service_dto.CreateService{
303320
Name: "AI Demo Service",
@@ -307,17 +324,11 @@ func (i *imlInitController) OnInit() {
307324
Catalogue: catalogueId,
308325
ApprovalType: "auto",
309326
Kind: "ai",
310-
})
327+
}, app.Id)
311328
if err != nil {
312329
return err
313330
}
314-
app, err := i.appModule.CreateApp(ctx, info.Id, &service_dto.CreateApp{
315-
Name: "Demo Application",
316-
Description: "Auto created By APIPark",
317-
})
318-
if err != nil {
319-
return fmt.Errorf("create default app error: %v", err)
320-
}
331+
321332
_, err = i.applicationAuthorizationModule.AddAuthorization(ctx, app.Id, &application_authorization_dto.CreateAuthorization{
322333
Name: "Default API Key",
323334
Driver: "apikey",
@@ -338,7 +349,7 @@ func (i *imlInitController) OnInit() {
338349
}
339350
})
340351
}
341-
func (i *imlInitController) createAIService(ctx context.Context, teamID string, input *service_dto.CreateService) error {
352+
func (i *imlInitController) createAIService(ctx context.Context, teamID string, input *service_dto.CreateService, appId string) error {
342353

343354
providerId := "fakegpt"
344355
err := i.providerModule.UpdateProviderConfig(ctx, providerId, &ai_dto.UpdateConfig{
@@ -469,6 +480,12 @@ func (i *imlInitController) createAIService(ctx context.Context, teamID string,
469480
if err != nil {
470481
return err
471482
}
483+
err = i.subscribeModule.AddSubscriber(ctx, info.Id, &subscribe_dto.AddSubscriber{
484+
Application: appId,
485+
})
486+
if err != nil {
487+
return err
488+
}
472489

473490
return i.docModule.SaveServiceDoc(ctx, info.Id, &service_dto.SaveServiceDoc{
474491
Doc: "The Translation API allows developers to translate text from one language to another. It supports multiple languages and enables easy integration of high-quality translation features into applications. With simple API requests, you can quickly translate content into different target languages.",

module/ai-balance/iml.go

Lines changed: 96 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@ type imlBalanceModule struct {
4343
transaction store.ITransaction `autowired:""`
4444
}
4545

46+
func (i *imlBalanceModule) SyncLocalBalances(ctx context.Context, address string) error {
47+
releases, err := i.getLocalBalances(ctx, address)
48+
if err != nil {
49+
return err
50+
}
51+
return i.syncGateway(ctx, cluster.DefaultClusterID, releases, true)
52+
}
53+
4654
func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Create) error {
4755
has, err := i.balanceService.Exist(ctx, input.Provider, input.Model)
4856
if err != nil {
@@ -63,6 +71,7 @@ func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Cre
6371
}
6472
providerName := ""
6573
modelName := ""
74+
base := ""
6675
switch input.Type {
6776
case ai_balance_dto.ModelTypeOnline:
6877
p, has := model_runtime.GetProvider(input.Provider)
@@ -71,14 +80,16 @@ func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Cre
7180
}
7281
providerName = p.Name()
7382
modelName = input.Model
83+
base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
7484
case ai_balance_dto.ModelTypeLocal:
7585
input.Provider = "ollama"
7686
providerName = "Ollama"
7787
modelName = input.Model
78-
}
79-
v, has := i.settingService.Get(ctx, "system.ai_model.ollama_address")
80-
if !has {
81-
return fmt.Errorf("ollama address not found")
88+
v, has := i.settingService.Get(ctx, "system.ai_model.ollama_address")
89+
if !has {
90+
return fmt.Errorf("ollama address not found")
91+
}
92+
base = v
8293
}
8394

8495
return i.transaction.Transaction(ctx, func(ctx context.Context) error {
@@ -98,7 +109,7 @@ func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Cre
98109
if err != nil {
99110
return err
100111
}
101-
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{newRelease(item, v)}, true)
112+
return i.syncGateway(ctx, cluster.DefaultClusterID, []*gateway.DynamicRelease{newRelease(item, base)}, true)
102113
})
103114

104115
}
@@ -143,7 +154,16 @@ func (i *imlBalanceModule) Sort(ctx context.Context, input *ai_balance_dto.Sort)
143154
}
144155
releases := make([]*gateway.DynamicRelease, 0, len(list))
145156
for _, item := range list {
146-
releases = append(releases, newRelease(item, v))
157+
base := v
158+
if item.Provider != "ollama" {
159+
p, has := model_runtime.GetProvider(item.Provider)
160+
if !has {
161+
continue
162+
}
163+
base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
164+
}
165+
166+
releases = append(releases, newRelease(item, base))
147167
}
148168
err = i.syncGateway(ctx, cluster.DefaultClusterID, releases, true)
149169
if err != nil {
@@ -237,3 +257,73 @@ func (i *imlBalanceModule) syncGateway(ctx context.Context, clusterId string, re
237257

238258
return nil
239259
}
260+
261+
func (i *imlBalanceModule) getLocalBalances(ctx context.Context, v string) ([]*gateway.DynamicRelease, error) {
262+
balances, err := i.balanceService.Search(ctx, "", map[string]interface{}{"provider": "ollama"}, "priority asc")
263+
if err != nil {
264+
return nil, err
265+
}
266+
if v == "" {
267+
var has bool
268+
v, has = i.settingService.Get(ctx, "system.ai_model.ollama_address")
269+
if !has {
270+
return nil, fmt.Errorf("ollama address not found")
271+
}
272+
}
273+
274+
releases := make([]*gateway.DynamicRelease, 0, len(balances))
275+
for _, item := range balances {
276+
base := v
277+
if item.Provider != "ollama" {
278+
p, has := model_runtime.GetProvider(item.Provider)
279+
if !has {
280+
continue
281+
}
282+
base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
283+
}
284+
releases = append(releases, newRelease(item, base))
285+
}
286+
return releases, nil
287+
}
288+
289+
func (i *imlBalanceModule) getBalances(ctx context.Context) ([]*gateway.DynamicRelease, error) {
290+
balances, err := i.balanceService.Search(ctx, "", nil, "priority asc")
291+
if err != nil {
292+
return nil, err
293+
}
294+
v, has := i.settingService.Get(ctx, "system.ai_model.ollama_address")
295+
if !has {
296+
return nil, fmt.Errorf("ollama address not found")
297+
}
298+
releases := make([]*gateway.DynamicRelease, 0, len(balances))
299+
for _, item := range balances {
300+
base := v
301+
if item.Provider != "ollama" {
302+
p, has := model_runtime.GetProvider(item.Provider)
303+
if !has {
304+
continue
305+
}
306+
base = fmt.Sprintf("%s://%s", p.URI().Scheme(), p.URI().Host())
307+
}
308+
releases = append(releases, newRelease(item, base))
309+
}
310+
return releases, nil
311+
}
312+
313+
func (i *imlBalanceModule) initGateway(ctx context.Context, clusterId string, clientDriver gateway.IClientDriver) error {
314+
releases, err := i.getBalances(ctx)
315+
if err != nil {
316+
return err
317+
}
318+
for _, p := range releases {
319+
client, err := clientDriver.Dynamic(p.Resource)
320+
if err != nil {
321+
return err
322+
}
323+
err = client.Online(ctx, p)
324+
if err != nil {
325+
return err
326+
}
327+
}
328+
return nil
329+
}

module/ai-balance/module.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"context"
55
"reflect"
66

7+
"github.com/APIParkLab/APIPark/gateway"
8+
79
"github.com/eolinker/go-common/autowire"
810

911
ai_balance_dto "github.com/APIParkLab/APIPark/module/ai-balance/dto"
@@ -14,10 +16,13 @@ type IBalanceModule interface {
1416
Sort(ctx context.Context, input *ai_balance_dto.Sort) error
1517
List(ctx context.Context, keyword string) ([]*ai_balance_dto.Item, error)
1618
Delete(ctx context.Context, id string) error
19+
SyncLocalBalances(ctx context.Context, address string) error
1720
}
1821

1922
func init() {
23+
balanceModule := new(imlBalanceModule)
2024
autowire.Auto[IBalanceModule](func() reflect.Value {
21-
return reflect.ValueOf(new(imlBalanceModule))
25+
gateway.RegisterInitHandleFunc(balanceModule.initGateway)
26+
return reflect.ValueOf(balanceModule)
2227
})
2328
}

0 commit comments

Comments
 (0)