Skip to content

Commit d17151f

Browse files
csg-pr-botDev Agent
authored andcommitted
use model repo_path as model id when invoking internal llm endpoint (#966)
Co-authored-by: Dev Agent <dev-agent@example.com>
1 parent ff03aae commit d17151f

File tree

2 files changed

+81
-30
lines changed

2 files changed

+81
-30
lines changed

aigateway/handler/openai.go

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ func (h *OpenAIHandlerImpl) Chat(c *gin.Context) {
293293
}
294294
target := ""
295295
host := ""
296+
modelName := ""
296297
if len(model.SvcName) > 0 {
297298
cluster, errCls := h.clusterComp.GetClusterByID(c, targetReq.ClusterID)
298299
if errCls != nil {
@@ -305,9 +306,11 @@ func (h *OpenAIHandlerImpl) Chat(c *gin.Context) {
305306
return
306307
}
307308
target, host, _ = common.ExtractDeployTargetAndHost(c.Request.Context(), cluster, targetReq)
309+
modelName = model.CSGHubModelID
308310
} else {
309311
slog.DebugContext(c.Request.Context(), "external model", slog.Any("model", model))
310312
target = model.Endpoint
313+
modelName = model.ID
311314
}
312315
if err != nil || len(target) < 1 {
313316
slog.ErrorContext(c.Request.Context(), "failed to get model target address", slog.Any("error", err),
@@ -322,12 +325,6 @@ func (h *OpenAIHandlerImpl) Chat(c *gin.Context) {
322325
return
323326
}
324327

325-
modelName, _, err := (component.ModelIDBuilder{}).From(modelID)
326-
if err != nil {
327-
slog.ErrorContext(c.Request.Context(), "failed to process chat request", "error", err, "model_id", modelID)
328-
c.String(http.StatusBadRequest, err.Error())
329-
return
330-
}
331328
chatReq.Model = modelName
332329
if chatReq.Stream {
333330
c.Writer.Header().Set("Content-Type", "text/event-stream")
@@ -478,6 +475,7 @@ func (h *OpenAIHandlerImpl) GenerateImage(c *gin.Context) {
478475
}
479476
target := ""
480477
host := ""
478+
modelName := ""
481479
if len(model.SvcName) > 0 {
482480
cluster, errCls := h.clusterComp.GetClusterByID(c, targetReq.ClusterID)
483481
if errCls != nil {
@@ -490,8 +488,10 @@ func (h *OpenAIHandlerImpl) GenerateImage(c *gin.Context) {
490488
return
491489
}
492490
target, host, _ = common.ExtractDeployTargetAndHost(ctx, cluster, targetReq)
491+
modelName = model.CSGHubModelID
493492
} else {
494493
target = model.Endpoint
494+
modelName = model.ID
495495
}
496496
if err != nil || len(target) < 1 {
497497
slog.ErrorContext(ctx, "failed to get model target address", slog.Any("error", err), slog.String("model_id", modelID))
@@ -501,14 +501,6 @@ func (h *OpenAIHandlerImpl) GenerateImage(c *gin.Context) {
501501
return
502502
}
503503

504-
modelName, _, err := (component.ModelIDBuilder{}).From(modelID)
505-
if err != nil {
506-
c.JSON(http.StatusBadRequest, gin.H{"error": types.Error{
507-
Code: "invalid_model_id", Message: "invalid model ID: " + err.Error(), Type: "invalid_request_error",
508-
}})
509-
return
510-
}
511-
512504
adapter := h.t2iRegistry.GetAdapter(model)
513505
if adapter == nil {
514506
c.JSON(http.StatusBadRequest, gin.H{"error": types.Error{
@@ -665,6 +657,7 @@ func (h *OpenAIHandlerImpl) Embedding(c *gin.Context) {
665657
}
666658
target := ""
667659
host := ""
660+
modelName := ""
668661
if len(model.SvcName) > 0 {
669662
cluster, errCls := h.clusterComp.GetClusterByID(c, targetReq.ClusterID)
670663
if errCls != nil {
@@ -677,8 +670,10 @@ func (h *OpenAIHandlerImpl) Embedding(c *gin.Context) {
677670
return
678671
}
679672
target, host, _ = common.ExtractDeployTargetAndHost(c.Request.Context(), cluster, targetReq)
673+
modelName = model.CSGHubModelID
680674
} else {
681675
target = model.Endpoint
676+
modelName = model.ID
682677
}
683678
if err != nil || len(target) < 1 {
684679
slog.ErrorContext(c, "failed to get embedding target address", slog.Any("error", err),
@@ -700,12 +695,6 @@ func (h *OpenAIHandlerImpl) Embedding(c *gin.Context) {
700695
return
701696
}
702697

703-
modelName, _, err := (component.ModelIDBuilder{}).From(modelID)
704-
if err != nil {
705-
slog.ErrorContext(c, "failed to process chat request", "error", err, "model_id", modelID)
706-
c.String(http.StatusBadRequest, err.Error())
707-
return
708-
}
709698
req.Model = modelName
710699
data, _ := json.Marshal(req)
711700
c.Request.Body = io.NopCloser(bytes.NewReader(data))

aigateway/handler/openai_test.go

Lines changed: 72 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -350,8 +350,9 @@ func TestOpenAIHandler_Chat(t *testing.T) {
350350
OwnedBy: "testuser",
351351
},
352352
InternalModelInfo: types.InternalModelInfo{
353-
ClusterID: "test-cls",
354-
SvcName: "test-svc",
353+
ClusterID: "test-cls",
354+
SvcName: "test-svc",
355+
CSGHubModelID: "model1",
355356
},
356357
}
357358
tester.mocks.mockClsComp.EXPECT().GetClusterByID(mock.Anything, "test-cls").Return(&database.ClusterInfo{
@@ -382,8 +383,9 @@ func TestOpenAIHandler_Chat(t *testing.T) {
382383
OwnedBy: "testuser",
383384
},
384385
InternalModelInfo: types.InternalModelInfo{
385-
ClusterID: "test-cls",
386-
SvcName: "test-svc",
386+
ClusterID: "test-cls",
387+
SvcName: "test-svc",
388+
CSGHubModelID: "model1",
387389
},
388390
Endpoint: "test-endpoint",
389391
}
@@ -426,8 +428,9 @@ func TestOpenAIHandler_Chat(t *testing.T) {
426428
OwnedBy: "testuser",
427429
},
428430
InternalModelInfo: types.InternalModelInfo{
429-
ClusterID: "test-cls",
430-
SvcName: "test-svc",
431+
ClusterID: "test-cls",
432+
SvcName: "test-svc",
433+
CSGHubModelID: "model1",
431434
},
432435
Endpoint: testServer.URL,
433436
}
@@ -470,8 +473,9 @@ func TestOpenAIHandler_Chat(t *testing.T) {
470473
OwnedBy: "testuser",
471474
},
472475
InternalModelInfo: types.InternalModelInfo{
473-
ClusterID: "test-cls",
474-
SvcName: "test-svc",
476+
ClusterID: "test-cls",
477+
SvcName: "test-svc",
478+
CSGHubModelID: "model1",
475479
},
476480
Endpoint: testServer.URL,
477481
}
@@ -531,8 +535,9 @@ func TestOpenAIHandler_Chat(t *testing.T) {
531535
OwnedBy: "testuser",
532536
},
533537
InternalModelInfo: types.InternalModelInfo{
534-
ClusterID: "test-cls",
535-
SvcName: "test-svc",
538+
ClusterID: "test-cls",
539+
SvcName: "test-svc",
540+
CSGHubModelID: "model1",
536541
},
537542
Endpoint: testServer.URL,
538543
}
@@ -566,6 +571,63 @@ func TestOpenAIHandler_Chat(t *testing.T) {
566571
wg.Wait()
567572
assert.Equal(t, http.StatusOK, w.Code)
568573
})
574+
t.Run("external model uses model id as request model", func(t *testing.T) {
575+
tester, c, w := setupTest(t)
576+
chatReq := ChatCompletionRequest{
577+
Model: "external-model-id",
578+
Messages: []openai.ChatCompletionMessageParamUnion{
579+
openai.UserMessage("Hello"),
580+
},
581+
}
582+
body, _ := json.Marshal(chatReq)
583+
c.Request.Method = http.MethodPost
584+
c.Request.Body = io.NopCloser(bytes.NewReader(body))
585+
586+
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
587+
w.WriteHeader(http.StatusOK)
588+
}))
589+
defer testServer.Close()
590+
591+
model := &types.Model{
592+
BaseModel: types.BaseModel{
593+
ID: "external-model-id",
594+
Object: "model",
595+
OwnedBy: "testuser",
596+
},
597+
InternalModelInfo: types.InternalModelInfo{
598+
SvcName: "",
599+
},
600+
Endpoint: testServer.URL,
601+
}
602+
tester.mocks.openAIComp.EXPECT().GetModelByID(mock.Anything, "testuser", "external-model-id").Return(model, nil)
603+
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser").Return(nil)
604+
expectReq := ChatCompletionRequest{}
605+
_ = json.Unmarshal(body, &expectReq)
606+
tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, expectReq.Messages, "testuuid:"+model.ID).
607+
Return(&rpc.CheckResult{IsSensitive: false}, nil)
608+
llmTokenCounter := mocktoken.NewMockChatTokenCounter(t)
609+
tester.mocks.tokenCounterFactory.EXPECT().NewChat(
610+
token.CreateParam{
611+
Endpoint: model.Endpoint,
612+
Host: "",
613+
Model: model.ID,
614+
ImageID: model.ImageID,
615+
}).
616+
Return(llmTokenCounter)
617+
llmTokenCounter.EXPECT().AppendPrompts(expectReq.Messages).Return()
618+
619+
var wg sync.WaitGroup
620+
wg.Add(1)
621+
tester.mocks.openAIComp.EXPECT().RecordUsage(mock.Anything, "testuuid", model, llmTokenCounter, mock.Anything).
622+
RunAndReturn(func(ctx context.Context, uuid string, model *types.Model, counter token.Counter, sceneValue string) error {
623+
wg.Done()
624+
return nil
625+
})
626+
627+
tester.handler.Chat(c)
628+
wg.Wait()
629+
assert.Equal(t, http.StatusOK, w.Code)
630+
})
569631
}
570632

571633
func TestOpenAIHandler_Embedding(t *testing.T) {

0 commit comments

Comments
 (0)