Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion aigateway/component/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ type OpenAIComponent interface {
ListModels(c context.Context, user string, req types.ListModelsReq) (types.ModelList, error)
GetModelByID(c context.Context, username, modelID string) (*types.Model, error)
RecordUsage(c context.Context, userUUID string, model *types.Model, tokenCounter token.Counter, sceneValue string) error
CheckBalance(ctx context.Context, username string) error
CheckBalance(ctx context.Context, username, userUUID string) error
}

type openaiComponentImpl struct {
Expand Down
2 changes: 1 addition & 1 deletion aigateway/component/openai_ce.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,6 @@ func parseScene(sceneValue string) common_types.SceneType {
return common_types.SceneModelServerless
}

func (e *extendOpenai) CheckBalance(ctx context.Context, username string) error {
func (e *extendOpenai) CheckBalance(ctx context.Context, username, userUUID string) error {
return nil
}
6 changes: 3 additions & 3 deletions aigateway/handler/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ func (h *OpenAIHandlerImpl) Chat(c *gin.Context) {

sceneValue := c.Request.Header.Get(commonType.SceneHeaderKey)
// Check balance before processing request
if err := h.openaiComponent.CheckBalance(c.Request.Context(), username); err != nil {
if err := h.openaiComponent.CheckBalance(c.Request.Context(), username, userUUID); err != nil {
h.handleInsufficientBalance(c, chatReq.Stream, username, modelID, err)
return
}
Expand Down Expand Up @@ -510,7 +510,7 @@ func (h *OpenAIHandlerImpl) GenerateImage(c *gin.Context) {
}

sceneValue := c.Request.Header.Get(commonType.SceneHeaderKey)
if err := h.openaiComponent.CheckBalance(ctx, username); err != nil {
if err := h.openaiComponent.CheckBalance(ctx, username, userUUID); err != nil {
h.handleInsufficientBalance(c, false, username, modelID, err)
return
}
Expand Down Expand Up @@ -690,7 +690,7 @@ func (h *OpenAIHandlerImpl) Embedding(c *gin.Context) {

sceneValue := c.Request.Header.Get(commonType.SceneHeaderKey)
// Check balance before processing request
if err := h.openaiComponent.CheckBalance(c.Request.Context(), username); err != nil {
if err := h.openaiComponent.CheckBalance(c.Request.Context(), username, userUUID); err != nil {
h.handleInsufficientBalance(c, false, username, modelID, err)
return
}
Expand Down
16 changes: 8 additions & 8 deletions aigateway/handler/openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ func TestOpenAIHandler_Chat(t *testing.T) {
ClusterID: "test-cls",
}, nil)
tester.mocks.openAIComp.EXPECT().GetModelByID(mock.Anything, "testuser", "model1:svc1").Return(model, nil)
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser").Return(nil)
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser", "testuuid").Return(nil)
expectReq := ChatCompletionRequest{}
_ = json.Unmarshal(body, &expectReq)
tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, expectReq.Messages, "testuuid:"+model.ID).
Expand Down Expand Up @@ -438,7 +438,7 @@ func TestOpenAIHandler_Chat(t *testing.T) {
ClusterID: "test-cls",
}, nil)
tester.mocks.openAIComp.EXPECT().GetModelByID(mock.Anything, "testuser", "model1:svc1").Return(model, nil)
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser").Return(nil)
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser", "testuuid").Return(nil)
expectReq := ChatCompletionRequest{}
_ = json.Unmarshal(body, &expectReq)
tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, expectReq.Messages, "testuuid:"+model.ID).
Expand Down Expand Up @@ -483,7 +483,7 @@ func TestOpenAIHandler_Chat(t *testing.T) {
ClusterID: "test-cls",
}, nil)
tester.mocks.openAIComp.EXPECT().GetModelByID(mock.Anything, "testuser", "model1:svc1").Return(model, nil)
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser").Return(nil)
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser", "testuuid").Return(nil)
expectReq := ChatCompletionRequest{}
_ = json.Unmarshal(body, &expectReq)
tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, expectReq.Messages, "testuuid:"+model.ID).
Expand Down Expand Up @@ -545,7 +545,7 @@ func TestOpenAIHandler_Chat(t *testing.T) {
ClusterID: "test-cls",
}, nil)
tester.mocks.openAIComp.EXPECT().GetModelByID(mock.Anything, "testuser", "model1:svc1").Return(model, nil)
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser").Return(nil)
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser", "testuuid").Return(nil)
expectReq := ChatCompletionRequest{}
_ = json.Unmarshal(body, &expectReq)
tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, expectReq.Messages, "testuuid:"+model.ID).
Expand Down Expand Up @@ -600,7 +600,7 @@ func TestOpenAIHandler_Chat(t *testing.T) {
Endpoint: testServer.URL,
}
tester.mocks.openAIComp.EXPECT().GetModelByID(mock.Anything, "testuser", "external-model-id").Return(model, nil)
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser").Return(nil)
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser", "testuuid").Return(nil)
expectReq := ChatCompletionRequest{}
_ = json.Unmarshal(body, &expectReq)
tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, expectReq.Messages, "testuuid:"+model.ID).
Expand Down Expand Up @@ -802,7 +802,7 @@ func TestOpenAIHandler_Embedding(t *testing.T) {
Return(tokenCounter).Once()
tester.mocks.openAIComp.EXPECT().GetModelByID(mock.Anything, "testuser", "model1").
Return(model, nil)
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser").Return(nil)
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser", "testuuid").Return(nil)
tester.mocks.openAIComp.EXPECT().RecordUsage(mock.Anything, "testuuid", model, mock.Anything, mock.Anything).RunAndReturn(
func(ctx context.Context, userID string, model *types.Model, counter token.Counter, sceneValue string) error {
wg.Done()
Expand Down Expand Up @@ -939,7 +939,7 @@ func TestOpenAIHandler_GenerateImage(t *testing.T) {
Endpoint: "https://api.example.com/images/generations",
}
tester.mocks.openAIComp.EXPECT().GetModelByID(mock.Anything, "testuser", "test-model").Return(model, nil)
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser").Return(nil)
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser", "testuuid").Return(nil)
tester.mocks.moderationComp.EXPECT().CheckImagePrompts(mock.Anything, "sensitive prompt", "testuuid").Return(&rpc.CheckResult{IsSensitive: true}, nil)

tester.handler.GenerateImage(c)
Expand Down Expand Up @@ -972,7 +972,7 @@ func TestOpenAIHandler_GenerateImage(t *testing.T) {
Endpoint: "https://api.example.com/images/generations",
}
tester.mocks.openAIComp.EXPECT().GetModelByID(mock.Anything, "testuser", "test-model").Return(model, nil)
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser").Return(nil)
tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser", "testuuid").Return(nil)
tester.mocks.moderationComp.EXPECT().CheckImagePrompts(mock.Anything, "test prompt", "testuuid").Return(nil, errors.New("moderation service error"))

tester.handler.GenerateImage(c)
Expand Down
2 changes: 2 additions & 0 deletions builder/store/database/agent_knowledge_base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ func TestAgentKnowledgeBaseStore_List_OrderByUpdatedAt(t *testing.T) {

// Create knowledge bases
kb1 := &database.AgentKnowledgeBase{
ID: 1,
UserUUID: userUUID,
Name: "First Knowledge Base",
Description: "First description",
Expand All @@ -398,6 +399,7 @@ func TestAgentKnowledgeBaseStore_List_OrderByUpdatedAt(t *testing.T) {
require.NoError(t, err)

kb2 := &database.AgentKnowledgeBase{
ID: 2,
UserUUID: userUUID,
Name: "Second Knowledge Base",
Description: "Second description",
Expand Down
2 changes: 1 addition & 1 deletion common/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ type Config struct {
ExpiredPresentCronExpression string `env:"OPENCSG_ACCOUNTING_EXPIRED_PRESENT_CRON_EXPRESSION" default:"0 0 * * *"`
ThresholdOfStopDeploy int `env:"OPENCSG_ACCOUNTING_THRESHOLD_OF_STOP_DEPLOY" default:"5000"`
ThresholdOfStopLLMInference int `env:"OPENCSG_ACCOUNTING_THRESHOLD_OF_STOP_LLM_INFERENCE" default:"5000"`
LLMBalanceCheckCacheTTL int `env:"OPENCSG_ACCOUNTING_LLM_BALANCE_CHECK_CACHE_TTL" default:"86400"`
LLMBalanceCheckCacheTTL int `env:"OPENCSG_ACCOUNTING_LLM_BALANCE_CHECK_CACHE_TTL" default:"120"`
}

User struct {
Expand Down
Loading