@@ -43,6 +43,14 @@ type imlBalanceModule struct {
4343 transaction store.ITransaction `autowired:""`
4444}
4545
46+ func (i * imlBalanceModule ) SyncLocalBalances (ctx context.Context , address string ) error {
47+ releases , err := i .getLocalBalances (ctx , address )
48+ if err != nil {
49+ return err
50+ }
51+ return i .syncGateway (ctx , cluster .DefaultClusterID , releases , true )
52+ }
53+
4654func (i * imlBalanceModule ) Create (ctx context.Context , input * ai_balance_dto.Create ) error {
4755 has , err := i .balanceService .Exist (ctx , input .Provider , input .Model )
4856 if err != nil {
@@ -63,6 +71,7 @@ func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Cre
6371 }
6472 providerName := ""
6573 modelName := ""
74+ base := ""
6675 switch input .Type {
6776 case ai_balance_dto .ModelTypeOnline :
6877 p , has := model_runtime .GetProvider (input .Provider )
@@ -71,14 +80,16 @@ func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Cre
7180 }
7281 providerName = p .Name ()
7382 modelName = input .Model
83+ base = fmt .Sprintf ("%s://%s" , p .URI ().Scheme (), p .URI ().Host ())
7484 case ai_balance_dto .ModelTypeLocal :
7585 input .Provider = "ollama"
7686 providerName = "Ollama"
7787 modelName = input .Model
78- }
79- v , has := i .settingService .Get (ctx , "system.ai_model.ollama_address" )
80- if ! has {
81- return fmt .Errorf ("ollama address not found" )
88+ v , has := i .settingService .Get (ctx , "system.ai_model.ollama_address" )
89+ if ! has {
90+ return fmt .Errorf ("ollama address not found" )
91+ }
92+ base = v
8293 }
8394
8495 return i .transaction .Transaction (ctx , func (ctx context.Context ) error {
@@ -98,7 +109,7 @@ func (i *imlBalanceModule) Create(ctx context.Context, input *ai_balance_dto.Cre
98109 if err != nil {
99110 return err
100111 }
101- return i .syncGateway (ctx , cluster .DefaultClusterID , []* gateway.DynamicRelease {newRelease (item , v )}, true )
112+ return i .syncGateway (ctx , cluster .DefaultClusterID , []* gateway.DynamicRelease {newRelease (item , base )}, true )
102113 })
103114
104115}
@@ -143,7 +154,16 @@ func (i *imlBalanceModule) Sort(ctx context.Context, input *ai_balance_dto.Sort)
143154 }
144155 releases := make ([]* gateway.DynamicRelease , 0 , len (list ))
145156 for _ , item := range list {
146- releases = append (releases , newRelease (item , v ))
157+ base := v
158+ if item .Provider != "ollama" {
159+ p , has := model_runtime .GetProvider (item .Provider )
160+ if ! has {
161+ continue
162+ }
163+ base = fmt .Sprintf ("%s://%s" , p .URI ().Scheme (), p .URI ().Host ())
164+ }
165+
166+ releases = append (releases , newRelease (item , base ))
147167 }
148168 err = i .syncGateway (ctx , cluster .DefaultClusterID , releases , true )
149169 if err != nil {
@@ -237,3 +257,73 @@ func (i *imlBalanceModule) syncGateway(ctx context.Context, clusterId string, re
237257
238258 return nil
239259}
260+
261+ func (i * imlBalanceModule ) getLocalBalances (ctx context.Context , v string ) ([]* gateway.DynamicRelease , error ) {
262+ balances , err := i .balanceService .Search (ctx , "" , map [string ]interface {}{"provider" : "ollama" }, "priority asc" )
263+ if err != nil {
264+ return nil , err
265+ }
266+ if v == "" {
267+ var has bool
268+ v , has = i .settingService .Get (ctx , "system.ai_model.ollama_address" )
269+ if ! has {
270+ return nil , fmt .Errorf ("ollama address not found" )
271+ }
272+ }
273+
274+ releases := make ([]* gateway.DynamicRelease , 0 , len (balances ))
275+ for _ , item := range balances {
276+ base := v
277+ if item .Provider != "ollama" {
278+ p , has := model_runtime .GetProvider (item .Provider )
279+ if ! has {
280+ continue
281+ }
282+ base = fmt .Sprintf ("%s://%s" , p .URI ().Scheme (), p .URI ().Host ())
283+ }
284+ releases = append (releases , newRelease (item , base ))
285+ }
286+ return releases , nil
287+ }
288+
289+ func (i * imlBalanceModule ) getBalances (ctx context.Context ) ([]* gateway.DynamicRelease , error ) {
290+ balances , err := i .balanceService .Search (ctx , "" , nil , "priority asc" )
291+ if err != nil {
292+ return nil , err
293+ }
294+ v , has := i .settingService .Get (ctx , "system.ai_model.ollama_address" )
295+ if ! has {
296+ return nil , fmt .Errorf ("ollama address not found" )
297+ }
298+ releases := make ([]* gateway.DynamicRelease , 0 , len (balances ))
299+ for _ , item := range balances {
300+ base := v
301+ if item .Provider != "ollama" {
302+ p , has := model_runtime .GetProvider (item .Provider )
303+ if ! has {
304+ continue
305+ }
306+ base = fmt .Sprintf ("%s://%s" , p .URI ().Scheme (), p .URI ().Host ())
307+ }
308+ releases = append (releases , newRelease (item , base ))
309+ }
310+ return releases , nil
311+ }
312+
313+ func (i * imlBalanceModule ) initGateway (ctx context.Context , clusterId string , clientDriver gateway.IClientDriver ) error {
314+ releases , err := i .getBalances (ctx )
315+ if err != nil {
316+ return err
317+ }
318+ for _ , p := range releases {
319+ client , err := clientDriver .Dynamic (p .Resource )
320+ if err != nil {
321+ return err
322+ }
323+ err = client .Online (ctx , p )
324+ if err != nil {
325+ return err
326+ }
327+ }
328+ return nil
329+ }
0 commit comments