Skip to content

Commit 8eae3f9

Browse files
csg-pr-botDev Agent
andauthored
feat(aigateway): don't skip any balance check (#963)
Co-authored-by: Dev Agent <dev-agent@example.com>
1 parent 7cb4455 commit 8eae3f9

6 files changed

Lines changed: 23 additions & 25 deletions

File tree

AGENTS.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ Folders relative to the root of the repository for each service:
8888

8989
## Testing
9090

91-
- Use `make mock_gen GO_TAGS={go.buildTags}` to generate mock implementations for the interfaces.
91+
- Don't modify any code file under foler `_mocks`, use `make mock_gen GO_TAGS={go.buildTags}` to generate mock implementations for the interfaces.
9292
- Use `make test GO_TAGS={go.buildTags}` to run all tests in project.
9393
- Mock dependencies (e.g., database, RPC clients) using tools like `mockery`.
9494

_mocks/opencsg.com/csghub-server/aigateway/component/mock_OpenAIComponent.go

Lines changed: 10 additions & 12 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

aigateway/component/openai.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ type OpenAIComponent interface {
3030
ListModels(c context.Context, user string, req types.ListModelsReq) (types.ModelList, error)
3131
GetModelByID(c context.Context, username, modelID string) (*types.Model, error)
3232
RecordUsage(c context.Context, userUUID string, model *types.Model, tokenCounter token.Counter, sceneValue string) error
33-
CheckBalance(ctx context.Context, username string, model *types.Model, sceneValue string) error
33+
CheckBalance(ctx context.Context, username string) error
3434
}
3535

3636
type openaiComponentImpl struct {

aigateway/component/openai_ce.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,6 @@ func parseScene(sceneValue string) common_types.SceneType {
4545
return common_types.SceneModelServerless
4646
}
4747

48-
func (e *extendOpenai) CheckBalance(ctx context.Context, username string, model *types.Model, sceneValue string) error {
48+
func (e *extendOpenai) CheckBalance(ctx context.Context, username string) error {
4949
return nil
5050
}

aigateway/handler/openai.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ func (h *OpenAIHandlerImpl) Chat(c *gin.Context) {
340340

341341
sceneValue := c.Request.Header.Get(commonType.SceneHeaderKey)
342342
// Check balance before processing request
343-
if err := h.openaiComponent.CheckBalance(c.Request.Context(), username, model, sceneValue); err != nil {
343+
if err := h.openaiComponent.CheckBalance(c.Request.Context(), username); err != nil {
344344
h.handleInsufficientBalance(c, chatReq.Stream, username, modelID, err)
345345
return
346346
}
@@ -518,7 +518,7 @@ func (h *OpenAIHandlerImpl) GenerateImage(c *gin.Context) {
518518
}
519519

520520
sceneValue := c.Request.Header.Get(commonType.SceneHeaderKey)
521-
if err := h.openaiComponent.CheckBalance(ctx, username, model, sceneValue); err != nil {
521+
if err := h.openaiComponent.CheckBalance(ctx, username); err != nil {
522522
h.handleInsufficientBalance(c, false, username, modelID, err)
523523
return
524524
}
@@ -695,7 +695,7 @@ func (h *OpenAIHandlerImpl) Embedding(c *gin.Context) {
695695

696696
sceneValue := c.Request.Header.Get(commonType.SceneHeaderKey)
697697
// Check balance before processing request
698-
if err := h.openaiComponent.CheckBalance(c.Request.Context(), username, model, sceneValue); err != nil {
698+
if err := h.openaiComponent.CheckBalance(c.Request.Context(), username); err != nil {
699699
h.handleInsufficientBalance(c, false, username, modelID, err)
700700
return
701701
}

aigateway/handler/openai_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -391,7 +391,7 @@ func TestOpenAIHandler_Chat(t *testing.T) {
391391
ClusterID: "test-cls",
392392
}, nil)
393393
tester.mocks.openAIComp.EXPECT().GetModelByID(mock.Anything, "testuser", "model1:svc1").Return(model, nil)
394-
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser", model, "").Return(nil)
394+
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser").Return(nil)
395395
expectReq := ChatCompletionRequest{}
396396
_ = json.Unmarshal(body, &expectReq)
397397
tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, expectReq.Messages, "testuuid:"+model.ID).
@@ -435,7 +435,7 @@ func TestOpenAIHandler_Chat(t *testing.T) {
435435
ClusterID: "test-cls",
436436
}, nil)
437437
tester.mocks.openAIComp.EXPECT().GetModelByID(mock.Anything, "testuser", "model1:svc1").Return(model, nil)
438-
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser", model, "").Return(nil)
438+
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser").Return(nil)
439439
expectReq := ChatCompletionRequest{}
440440
_ = json.Unmarshal(body, &expectReq)
441441
tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, expectReq.Messages, "testuuid:"+model.ID).
@@ -479,7 +479,7 @@ func TestOpenAIHandler_Chat(t *testing.T) {
479479
ClusterID: "test-cls",
480480
}, nil)
481481
tester.mocks.openAIComp.EXPECT().GetModelByID(mock.Anything, "testuser", "model1:svc1").Return(model, nil)
482-
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser", model, "").Return(nil)
482+
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser").Return(nil)
483483
expectReq := ChatCompletionRequest{}
484484
_ = json.Unmarshal(body, &expectReq)
485485
tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, expectReq.Messages, "testuuid:"+model.ID).
@@ -540,7 +540,7 @@ func TestOpenAIHandler_Chat(t *testing.T) {
540540
ClusterID: "test-cls",
541541
}, nil)
542542
tester.mocks.openAIComp.EXPECT().GetModelByID(mock.Anything, "testuser", "model1:svc1").Return(model, nil)
543-
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser", model, "").Return(nil)
543+
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser").Return(nil)
544544
expectReq := ChatCompletionRequest{}
545545
_ = json.Unmarshal(body, &expectReq)
546546
tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, expectReq.Messages, "testuuid:"+model.ID).
@@ -740,7 +740,7 @@ func TestOpenAIHandler_Embedding(t *testing.T) {
740740
Return(tokenCounter).Once()
741741
tester.mocks.openAIComp.EXPECT().GetModelByID(mock.Anything, "testuser", "model1").
742742
Return(model, nil)
743-
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser", model, "").Return(nil)
743+
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser").Return(nil)
744744
tester.mocks.openAIComp.EXPECT().RecordUsage(mock.Anything, "testuuid", model, mock.Anything, mock.Anything).RunAndReturn(
745745
func(ctx context.Context, userID string, model *types.Model, counter token.Counter, sceneValue string) error {
746746
wg.Done()
@@ -877,7 +877,7 @@ func TestOpenAIHandler_GenerateImage(t *testing.T) {
877877
Endpoint: "https://api.example.com/images/generations",
878878
}
879879
tester.mocks.openAIComp.EXPECT().GetModelByID(mock.Anything, "testuser", "test-model").Return(model, nil)
880-
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser", model, mock.Anything).Return(nil)
880+
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser").Return(nil)
881881
tester.mocks.moderationComp.EXPECT().CheckImagePrompts(mock.Anything, "sensitive prompt", "testuuid").Return(&rpc.CheckResult{IsSensitive: true}, nil)
882882

883883
tester.handler.GenerateImage(c)
@@ -910,7 +910,7 @@ func TestOpenAIHandler_GenerateImage(t *testing.T) {
910910
Endpoint: "https://api.example.com/images/generations",
911911
}
912912
tester.mocks.openAIComp.EXPECT().GetModelByID(mock.Anything, "testuser", "test-model").Return(model, nil)
913-
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser", model, mock.Anything).Return(nil)
913+
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser").Return(nil)
914914
tester.mocks.moderationComp.EXPECT().CheckImagePrompts(mock.Anything, "test prompt", "testuuid").Return(nil, errors.New("moderation service error"))
915915

916916
tester.handler.GenerateImage(c)

0 commit comments

Comments
 (0)