diff --git a/_mocks/opencsg.com/csghub-server/aigateway/component/mock_Moderation.go b/_mocks/opencsg.com/csghub-server/aigateway/component/mock_Moderation.go index 0e71cd7da..e263d769d 100644 --- a/_mocks/opencsg.com/csghub-server/aigateway/component/mock_Moderation.go +++ b/_mocks/opencsg.com/csghub-server/aigateway/component/mock_Moderation.go @@ -85,9 +85,9 @@ func (_c *MockModeration_CheckChatNonStreamResponse_Call) RunAndReturn(run func( return _c } -// CheckChatPrompts provides a mock function with given fields: ctx, messages, uuid -func (_m *MockModeration) CheckChatPrompts(ctx context.Context, messages []openai.ChatCompletionMessageParamUnion, uuid string) (*rpc.CheckResult, error) { - ret := _m.Called(ctx, messages, uuid) +// CheckChatPrompts provides a mock function with given fields: ctx, messages, uuid, isStream +func (_m *MockModeration) CheckChatPrompts(ctx context.Context, messages []openai.ChatCompletionMessageParamUnion, uuid string, isStream bool) (*rpc.CheckResult, error) { + ret := _m.Called(ctx, messages, uuid, isStream) if len(ret) == 0 { panic("no return value specified for CheckChatPrompts") @@ -95,19 +95,19 @@ func (_m *MockModeration) CheckChatPrompts(ctx context.Context, messages []opena var r0 *rpc.CheckResult var r1 error - if rf, ok := ret.Get(0).(func(context.Context, []openai.ChatCompletionMessageParamUnion, string) (*rpc.CheckResult, error)); ok { - return rf(ctx, messages, uuid) + if rf, ok := ret.Get(0).(func(context.Context, []openai.ChatCompletionMessageParamUnion, string, bool) (*rpc.CheckResult, error)); ok { + return rf(ctx, messages, uuid, isStream) } - if rf, ok := ret.Get(0).(func(context.Context, []openai.ChatCompletionMessageParamUnion, string) *rpc.CheckResult); ok { - r0 = rf(ctx, messages, uuid) + if rf, ok := ret.Get(0).(func(context.Context, []openai.ChatCompletionMessageParamUnion, string, bool) *rpc.CheckResult); ok { + r0 = rf(ctx, messages, uuid, isStream) } else { if ret.Get(0) != nil { r0 = 
ret.Get(0).(*rpc.CheckResult) } } - if rf, ok := ret.Get(1).(func(context.Context, []openai.ChatCompletionMessageParamUnion, string) error); ok { - r1 = rf(ctx, messages, uuid) + if rf, ok := ret.Get(1).(func(context.Context, []openai.ChatCompletionMessageParamUnion, string, bool) error); ok { + r1 = rf(ctx, messages, uuid, isStream) } else { r1 = ret.Error(1) } @@ -124,13 +124,14 @@ type MockModeration_CheckChatPrompts_Call struct { // - ctx context.Context // - messages []openai.ChatCompletionMessageParamUnion // - uuid string -func (_e *MockModeration_Expecter) CheckChatPrompts(ctx interface{}, messages interface{}, uuid interface{}) *MockModeration_CheckChatPrompts_Call { - return &MockModeration_CheckChatPrompts_Call{Call: _e.mock.On("CheckChatPrompts", ctx, messages, uuid)} +// - isStream bool +func (_e *MockModeration_Expecter) CheckChatPrompts(ctx interface{}, messages interface{}, uuid interface{}, isStream interface{}) *MockModeration_CheckChatPrompts_Call { + return &MockModeration_CheckChatPrompts_Call{Call: _e.mock.On("CheckChatPrompts", ctx, messages, uuid, isStream)} } -func (_c *MockModeration_CheckChatPrompts_Call) Run(run func(ctx context.Context, messages []openai.ChatCompletionMessageParamUnion, uuid string)) *MockModeration_CheckChatPrompts_Call { +func (_c *MockModeration_CheckChatPrompts_Call) Run(run func(ctx context.Context, messages []openai.ChatCompletionMessageParamUnion, uuid string, isStream bool)) *MockModeration_CheckChatPrompts_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].([]openai.ChatCompletionMessageParamUnion), args[2].(string)) + run(args[0].(context.Context), args[1].([]openai.ChatCompletionMessageParamUnion), args[2].(string), args[3].(bool)) }) return _c } @@ -140,7 +141,7 @@ func (_c *MockModeration_CheckChatPrompts_Call) Return(_a0 *rpc.CheckResult, _a1 return _c } -func (_c *MockModeration_CheckChatPrompts_Call) RunAndReturn(run func(context.Context, 
[]openai.ChatCompletionMessageParamUnion, string) (*rpc.CheckResult, error)) *MockModeration_CheckChatPrompts_Call { +func (_c *MockModeration_CheckChatPrompts_Call) RunAndReturn(run func(context.Context, []openai.ChatCompletionMessageParamUnion, string, bool) (*rpc.CheckResult, error)) *MockModeration_CheckChatPrompts_Call { _c.Call.Return(run) return _c } @@ -324,6 +325,65 @@ func (_c *MockModeration_CheckImagePrompts_Call) RunAndReturn(run func(context.C return _c } +// CloseStreamCheck provides a mock function with given fields: ctx, uuid +func (_m *MockModeration) CloseStreamCheck(ctx context.Context, uuid string) (*rpc.CheckResult, error) { + ret := _m.Called(ctx, uuid) + + if len(ret) == 0 { + panic("no return value specified for CloseStreamCheck") + } + + var r0 *rpc.CheckResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*rpc.CheckResult, error)); ok { + return rf(ctx, uuid) + } + if rf, ok := ret.Get(0).(func(context.Context, string) *rpc.CheckResult); ok { + r0 = rf(ctx, uuid) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*rpc.CheckResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, uuid) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockModeration_CloseStreamCheck_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CloseStreamCheck' +type MockModeration_CloseStreamCheck_Call struct { + *mock.Call +} + +// CloseStreamCheck is a helper method to define mock.On call +// - ctx context.Context +// - uuid string +func (_e *MockModeration_Expecter) CloseStreamCheck(ctx interface{}, uuid interface{}) *MockModeration_CloseStreamCheck_Call { + return &MockModeration_CloseStreamCheck_Call{Call: _e.mock.On("CloseStreamCheck", ctx, uuid)} +} + +func (_c *MockModeration_CloseStreamCheck_Call) Run(run func(ctx context.Context, uuid string)) *MockModeration_CloseStreamCheck_Call { + _c.Call.Run(func(args mock.Arguments) { + 
run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockModeration_CloseStreamCheck_Call) Return(_a0 *rpc.CheckResult, _a1 error) *MockModeration_CloseStreamCheck_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockModeration_CloseStreamCheck_Call) RunAndReturn(run func(context.Context, string) (*rpc.CheckResult, error)) *MockModeration_CloseStreamCheck_Call { + _c.Call.Return(run) + return _c +} + // NewMockModeration creates a new instance of MockModeration. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockModeration(t interface { diff --git a/_mocks/opencsg.com/csghub-server/aigateway/component/mock_StreamChecker.go b/_mocks/opencsg.com/csghub-server/aigateway/component/mock_StreamChecker.go new file mode 100644 index 000000000..473c43e93 --- /dev/null +++ b/_mocks/opencsg.com/csghub-server/aigateway/component/mock_StreamChecker.go @@ -0,0 +1,158 @@ +// Code generated by mockery v2.53.5. DO NOT EDIT. 
+ +package component + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + rpc "opencsg.com/csghub-server/builder/rpc" + + types "opencsg.com/csghub-server/aigateway/types" +) + +// MockStreamChecker is an autogenerated mock type for the StreamChecker type +type MockStreamChecker struct { + mock.Mock +} + +type MockStreamChecker_Expecter struct { + mock *mock.Mock +} + +func (_m *MockStreamChecker) EXPECT() *MockStreamChecker_Expecter { + return &MockStreamChecker_Expecter{mock: &_m.Mock} +} + +// CheckChatStreamResponse provides a mock function with given fields: ctx, chunk, uuid +func (_m *MockStreamChecker) CheckChatStreamResponse(ctx context.Context, chunk types.ChatCompletionChunk, uuid string) (*rpc.CheckResult, error) { + ret := _m.Called(ctx, chunk, uuid) + + if len(ret) == 0 { + panic("no return value specified for CheckChatStreamResponse") + } + + var r0 *rpc.CheckResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, types.ChatCompletionChunk, string) (*rpc.CheckResult, error)); ok { + return rf(ctx, chunk, uuid) + } + if rf, ok := ret.Get(0).(func(context.Context, types.ChatCompletionChunk, string) *rpc.CheckResult); ok { + r0 = rf(ctx, chunk, uuid) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*rpc.CheckResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, types.ChatCompletionChunk, string) error); ok { + r1 = rf(ctx, chunk, uuid) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockStreamChecker_CheckChatStreamResponse_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CheckChatStreamResponse' +type MockStreamChecker_CheckChatStreamResponse_Call struct { + *mock.Call +} + +// CheckChatStreamResponse is a helper method to define mock.On call +// - ctx context.Context +// - chunk types.ChatCompletionChunk +// - uuid string +func (_e *MockStreamChecker_Expecter) CheckChatStreamResponse(ctx interface{}, chunk interface{}, uuid interface{}) 
*MockStreamChecker_CheckChatStreamResponse_Call { + return &MockStreamChecker_CheckChatStreamResponse_Call{Call: _e.mock.On("CheckChatStreamResponse", ctx, chunk, uuid)} +} + +func (_c *MockStreamChecker_CheckChatStreamResponse_Call) Run(run func(ctx context.Context, chunk types.ChatCompletionChunk, uuid string)) *MockStreamChecker_CheckChatStreamResponse_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(types.ChatCompletionChunk), args[2].(string)) + }) + return _c +} + +func (_c *MockStreamChecker_CheckChatStreamResponse_Call) Return(_a0 *rpc.CheckResult, _a1 error) *MockStreamChecker_CheckChatStreamResponse_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockStreamChecker_CheckChatStreamResponse_Call) RunAndReturn(run func(context.Context, types.ChatCompletionChunk, string) (*rpc.CheckResult, error)) *MockStreamChecker_CheckChatStreamResponse_Call { + _c.Call.Return(run) + return _c +} + +// CloseStreamCheck provides a mock function with given fields: ctx, uuid +func (_m *MockStreamChecker) CloseStreamCheck(ctx context.Context, uuid string) (*rpc.CheckResult, error) { + ret := _m.Called(ctx, uuid) + + if len(ret) == 0 { + panic("no return value specified for CloseStreamCheck") + } + + var r0 *rpc.CheckResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*rpc.CheckResult, error)); ok { + return rf(ctx, uuid) + } + if rf, ok := ret.Get(0).(func(context.Context, string) *rpc.CheckResult); ok { + r0 = rf(ctx, uuid) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*rpc.CheckResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, uuid) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockStreamChecker_CloseStreamCheck_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CloseStreamCheck' +type MockStreamChecker_CloseStreamCheck_Call struct { + *mock.Call +} + +// CloseStreamCheck is a 
helper method to define mock.On call +// - ctx context.Context +// - uuid string +func (_e *MockStreamChecker_Expecter) CloseStreamCheck(ctx interface{}, uuid interface{}) *MockStreamChecker_CloseStreamCheck_Call { + return &MockStreamChecker_CloseStreamCheck_Call{Call: _e.mock.On("CloseStreamCheck", ctx, uuid)} +} + +func (_c *MockStreamChecker_CloseStreamCheck_Call) Run(run func(ctx context.Context, uuid string)) *MockStreamChecker_CloseStreamCheck_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockStreamChecker_CloseStreamCheck_Call) Return(_a0 *rpc.CheckResult, _a1 error) *MockStreamChecker_CloseStreamCheck_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockStreamChecker_CloseStreamCheck_Call) RunAndReturn(run func(context.Context, string) (*rpc.CheckResult, error)) *MockStreamChecker_CloseStreamCheck_Call { + _c.Call.Return(run) + return _c +} + +// NewMockStreamChecker creates a new instance of MockStreamChecker. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func NewMockStreamChecker(t interface { + mock.TestingT + Cleanup(func()) +}) *MockStreamChecker { + mock := &MockStreamChecker{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/_mocks/opencsg.com/csghub-server/builder/rpc/mock_ModerationSvcClient.go b/_mocks/opencsg.com/csghub-server/builder/rpc/mock_ModerationSvcClient.go index a5df6af70..9a85aad3b 100644 --- a/_mocks/opencsg.com/csghub-server/builder/rpc/mock_ModerationSvcClient.go +++ b/_mocks/opencsg.com/csghub-server/builder/rpc/mock_ModerationSvcClient.go @@ -145,9 +145,9 @@ func (_c *MockModerationSvcClient_PassImageURLCheck_Call) RunAndReturn(run func( return _c } -// PassLLMPromptCheck provides a mock function with given fields: ctx, text, accountId -func (_m *MockModerationSvcClient) PassLLMPromptCheck(ctx context.Context, text string, accountId string) (*rpc.CheckResult, error) { - ret := _m.Called(ctx, text, accountId) +// PassLLMPromptCheck provides a mock function with given fields: ctx, req +func (_m *MockModerationSvcClient) PassLLMPromptCheck(ctx context.Context, req types.LLMCheckRequest) (*rpc.CheckResult, error) { + ret := _m.Called(ctx, req) if len(ret) == 0 { panic("no return value specified for PassLLMPromptCheck") @@ -155,19 +155,19 @@ func (_m *MockModerationSvcClient) PassLLMPromptCheck(ctx context.Context, text var r0 *rpc.CheckResult var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) (*rpc.CheckResult, error)); ok { - return rf(ctx, text, accountId) + if rf, ok := ret.Get(0).(func(context.Context, types.LLMCheckRequest) (*rpc.CheckResult, error)); ok { + return rf(ctx, req) } - if rf, ok := ret.Get(0).(func(context.Context, string, string) *rpc.CheckResult); ok { - r0 = rf(ctx, text, accountId) + if rf, ok := ret.Get(0).(func(context.Context, types.LLMCheckRequest) *rpc.CheckResult); ok { + r0 = rf(ctx, req) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*rpc.CheckResult) } } - if rf, ok := 
ret.Get(1).(func(context.Context, string, string) error); ok { - r1 = rf(ctx, text, accountId) + if rf, ok := ret.Get(1).(func(context.Context, types.LLMCheckRequest) error); ok { + r1 = rf(ctx, req) } else { r1 = ret.Error(1) } @@ -182,15 +182,14 @@ type MockModerationSvcClient_PassLLMPromptCheck_Call struct { // PassLLMPromptCheck is a helper method to define mock.On call // - ctx context.Context -// - text string -// - accountId string -func (_e *MockModerationSvcClient_Expecter) PassLLMPromptCheck(ctx interface{}, text interface{}, accountId interface{}) *MockModerationSvcClient_PassLLMPromptCheck_Call { - return &MockModerationSvcClient_PassLLMPromptCheck_Call{Call: _e.mock.On("PassLLMPromptCheck", ctx, text, accountId)} +// - req types.LLMCheckRequest +func (_e *MockModerationSvcClient_Expecter) PassLLMPromptCheck(ctx interface{}, req interface{}) *MockModerationSvcClient_PassLLMPromptCheck_Call { + return &MockModerationSvcClient_PassLLMPromptCheck_Call{Call: _e.mock.On("PassLLMPromptCheck", ctx, req)} } -func (_c *MockModerationSvcClient_PassLLMPromptCheck_Call) Run(run func(ctx context.Context, text string, accountId string)) *MockModerationSvcClient_PassLLMPromptCheck_Call { +func (_c *MockModerationSvcClient_PassLLMPromptCheck_Call) Run(run func(ctx context.Context, req types.LLMCheckRequest)) *MockModerationSvcClient_PassLLMPromptCheck_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string), args[2].(string)) + run(args[0].(context.Context), args[1].(types.LLMCheckRequest)) }) return _c } @@ -200,14 +199,14 @@ func (_c *MockModerationSvcClient_PassLLMPromptCheck_Call) Return(_a0 *rpc.Check return _c } -func (_c *MockModerationSvcClient_PassLLMPromptCheck_Call) RunAndReturn(run func(context.Context, string, string) (*rpc.CheckResult, error)) *MockModerationSvcClient_PassLLMPromptCheck_Call { +func (_c *MockModerationSvcClient_PassLLMPromptCheck_Call) RunAndReturn(run func(context.Context, types.LLMCheckRequest) 
(*rpc.CheckResult, error)) *MockModerationSvcClient_PassLLMPromptCheck_Call { _c.Call.Return(run) return _c } -// PassLLMRespCheck provides a mock function with given fields: ctx, text, sessionId -func (_m *MockModerationSvcClient) PassLLMRespCheck(ctx context.Context, text string, sessionId string) (*rpc.CheckResult, error) { - ret := _m.Called(ctx, text, sessionId) +// PassLLMRespCheck provides a mock function with given fields: ctx, req +func (_m *MockModerationSvcClient) PassLLMRespCheck(ctx context.Context, req types.LLMCheckRequest) (*rpc.CheckResult, error) { + ret := _m.Called(ctx, req) if len(ret) == 0 { panic("no return value specified for PassLLMRespCheck") @@ -215,19 +214,19 @@ func (_m *MockModerationSvcClient) PassLLMRespCheck(ctx context.Context, text st var r0 *rpc.CheckResult var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) (*rpc.CheckResult, error)); ok { - return rf(ctx, text, sessionId) + if rf, ok := ret.Get(0).(func(context.Context, types.LLMCheckRequest) (*rpc.CheckResult, error)); ok { + return rf(ctx, req) } - if rf, ok := ret.Get(0).(func(context.Context, string, string) *rpc.CheckResult); ok { - r0 = rf(ctx, text, sessionId) + if rf, ok := ret.Get(0).(func(context.Context, types.LLMCheckRequest) *rpc.CheckResult); ok { + r0 = rf(ctx, req) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*rpc.CheckResult) } } - if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { - r1 = rf(ctx, text, sessionId) + if rf, ok := ret.Get(1).(func(context.Context, types.LLMCheckRequest) error); ok { + r1 = rf(ctx, req) } else { r1 = ret.Error(1) } @@ -242,15 +241,14 @@ type MockModerationSvcClient_PassLLMRespCheck_Call struct { // PassLLMRespCheck is a helper method to define mock.On call // - ctx context.Context -// - text string -// - sessionId string -func (_e *MockModerationSvcClient_Expecter) PassLLMRespCheck(ctx interface{}, text interface{}, sessionId interface{}) 
*MockModerationSvcClient_PassLLMRespCheck_Call { - return &MockModerationSvcClient_PassLLMRespCheck_Call{Call: _e.mock.On("PassLLMRespCheck", ctx, text, sessionId)} +// - req types.LLMCheckRequest +func (_e *MockModerationSvcClient_Expecter) PassLLMRespCheck(ctx interface{}, req interface{}) *MockModerationSvcClient_PassLLMRespCheck_Call { + return &MockModerationSvcClient_PassLLMRespCheck_Call{Call: _e.mock.On("PassLLMRespCheck", ctx, req)} } -func (_c *MockModerationSvcClient_PassLLMRespCheck_Call) Run(run func(ctx context.Context, text string, sessionId string)) *MockModerationSvcClient_PassLLMRespCheck_Call { +func (_c *MockModerationSvcClient_PassLLMRespCheck_Call) Run(run func(ctx context.Context, req types.LLMCheckRequest)) *MockModerationSvcClient_PassLLMRespCheck_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string), args[2].(string)) + run(args[0].(context.Context), args[1].(types.LLMCheckRequest)) }) return _c } @@ -260,7 +258,7 @@ func (_c *MockModerationSvcClient_PassLLMRespCheck_Call) Return(_a0 *rpc.CheckRe return _c } -func (_c *MockModerationSvcClient_PassLLMRespCheck_Call) RunAndReturn(run func(context.Context, string, string) (*rpc.CheckResult, error)) *MockModerationSvcClient_PassLLMRespCheck_Call { +func (_c *MockModerationSvcClient_PassLLMRespCheck_Call) RunAndReturn(run func(context.Context, types.LLMCheckRequest) (*rpc.CheckResult, error)) *MockModerationSvcClient_PassLLMRespCheck_Call { _c.Call.Return(run) return _c } diff --git a/_mocks/opencsg.com/csghub-server/builder/sensitive/mock_SensitiveChecker.go b/_mocks/opencsg.com/csghub-server/builder/sensitive/mock_SensitiveChecker.go index 6ce20f059..98766c4ae 100644 --- a/_mocks/opencsg.com/csghub-server/builder/sensitive/mock_SensitiveChecker.go +++ b/_mocks/opencsg.com/csghub-server/builder/sensitive/mock_SensitiveChecker.go @@ -145,9 +145,9 @@ func (_c *MockSensitiveChecker_PassImageURLCheck_Call) RunAndReturn(run func(con return _c } -// 
PassLLMCheck provides a mock function with given fields: ctx, scenario, text, sessionId, accountId -func (_m *MockSensitiveChecker) PassLLMCheck(ctx context.Context, scenario types.SensitiveScenario, text string, sessionId string, accountId string) (*sensitive.CheckResult, error) { - ret := _m.Called(ctx, scenario, text, sessionId, accountId) +// PassLLMCheck provides a mock function with given fields: ctx, req +func (_m *MockSensitiveChecker) PassLLMCheck(ctx context.Context, req *types.LLMCheckRequest) (*sensitive.CheckResult, error) { + ret := _m.Called(ctx, req) if len(ret) == 0 { panic("no return value specified for PassLLMCheck") @@ -155,19 +155,19 @@ func (_m *MockSensitiveChecker) PassLLMCheck(ctx context.Context, scenario types var r0 *sensitive.CheckResult var r1 error - if rf, ok := ret.Get(0).(func(context.Context, types.SensitiveScenario, string, string, string) (*sensitive.CheckResult, error)); ok { - return rf(ctx, scenario, text, sessionId, accountId) + if rf, ok := ret.Get(0).(func(context.Context, *types.LLMCheckRequest) (*sensitive.CheckResult, error)); ok { + return rf(ctx, req) } - if rf, ok := ret.Get(0).(func(context.Context, types.SensitiveScenario, string, string, string) *sensitive.CheckResult); ok { - r0 = rf(ctx, scenario, text, sessionId, accountId) + if rf, ok := ret.Get(0).(func(context.Context, *types.LLMCheckRequest) *sensitive.CheckResult); ok { + r0 = rf(ctx, req) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*sensitive.CheckResult) } } - if rf, ok := ret.Get(1).(func(context.Context, types.SensitiveScenario, string, string, string) error); ok { - r1 = rf(ctx, scenario, text, sessionId, accountId) + if rf, ok := ret.Get(1).(func(context.Context, *types.LLMCheckRequest) error); ok { + r1 = rf(ctx, req) } else { r1 = ret.Error(1) } @@ -182,17 +182,14 @@ type MockSensitiveChecker_PassLLMCheck_Call struct { // PassLLMCheck is a helper method to define mock.On call // - ctx context.Context -// - scenario types.SensitiveScenario -// 
- text string -// - sessionId string -// - accountId string -func (_e *MockSensitiveChecker_Expecter) PassLLMCheck(ctx interface{}, scenario interface{}, text interface{}, sessionId interface{}, accountId interface{}) *MockSensitiveChecker_PassLLMCheck_Call { - return &MockSensitiveChecker_PassLLMCheck_Call{Call: _e.mock.On("PassLLMCheck", ctx, scenario, text, sessionId, accountId)} +// - req *types.LLMCheckRequest +func (_e *MockSensitiveChecker_Expecter) PassLLMCheck(ctx interface{}, req interface{}) *MockSensitiveChecker_PassLLMCheck_Call { + return &MockSensitiveChecker_PassLLMCheck_Call{Call: _e.mock.On("PassLLMCheck", ctx, req)} } -func (_c *MockSensitiveChecker_PassLLMCheck_Call) Run(run func(ctx context.Context, scenario types.SensitiveScenario, text string, sessionId string, accountId string)) *MockSensitiveChecker_PassLLMCheck_Call { +func (_c *MockSensitiveChecker_PassLLMCheck_Call) Run(run func(ctx context.Context, req *types.LLMCheckRequest)) *MockSensitiveChecker_PassLLMCheck_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(types.SensitiveScenario), args[2].(string), args[3].(string), args[4].(string)) + run(args[0].(context.Context), args[1].(*types.LLMCheckRequest)) }) return _c } @@ -202,7 +199,7 @@ func (_c *MockSensitiveChecker_PassLLMCheck_Call) Return(_a0 *sensitive.CheckRes return _c } -func (_c *MockSensitiveChecker_PassLLMCheck_Call) RunAndReturn(run func(context.Context, types.SensitiveScenario, string, string, string) (*sensitive.CheckResult, error)) *MockSensitiveChecker_PassLLMCheck_Call { +func (_c *MockSensitiveChecker_PassLLMCheck_Call) RunAndReturn(run func(context.Context, *types.LLMCheckRequest) (*sensitive.CheckResult, error)) *MockSensitiveChecker_PassLLMCheck_Call { _c.Call.Return(run) return _c } diff --git a/_mocks/opencsg.com/csghub-server/moderation/component/mock_SensitiveComponent.go b/_mocks/opencsg.com/csghub-server/moderation/component/mock_SensitiveComponent.go index 
aeceb991e..15fee0423 100644 --- a/_mocks/opencsg.com/csghub-server/moderation/component/mock_SensitiveComponent.go +++ b/_mocks/opencsg.com/csghub-server/moderation/component/mock_SensitiveComponent.go @@ -145,9 +145,9 @@ func (_c *MockSensitiveComponent_PassImageURLCheck_Call) RunAndReturn(run func(c return _c } -// PassLLMQueryCheck provides a mock function with given fields: ctx, scenario, text, id -func (_m *MockSensitiveComponent) PassLLMQueryCheck(ctx context.Context, scenario types.SensitiveScenario, text string, id string) (*sensitive.CheckResult, error) { - ret := _m.Called(ctx, scenario, text, id) +// PassLLMQueryCheck provides a mock function with given fields: ctx, req +func (_m *MockSensitiveComponent) PassLLMQueryCheck(ctx context.Context, req *types.LLMCheckRequest) (*sensitive.CheckResult, error) { + ret := _m.Called(ctx, req) if len(ret) == 0 { panic("no return value specified for PassLLMQueryCheck") @@ -155,19 +155,19 @@ func (_m *MockSensitiveComponent) PassLLMQueryCheck(ctx context.Context, scenari var r0 *sensitive.CheckResult var r1 error - if rf, ok := ret.Get(0).(func(context.Context, types.SensitiveScenario, string, string) (*sensitive.CheckResult, error)); ok { - return rf(ctx, scenario, text, id) + if rf, ok := ret.Get(0).(func(context.Context, *types.LLMCheckRequest) (*sensitive.CheckResult, error)); ok { + return rf(ctx, req) } - if rf, ok := ret.Get(0).(func(context.Context, types.SensitiveScenario, string, string) *sensitive.CheckResult); ok { - r0 = rf(ctx, scenario, text, id) + if rf, ok := ret.Get(0).(func(context.Context, *types.LLMCheckRequest) *sensitive.CheckResult); ok { + r0 = rf(ctx, req) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*sensitive.CheckResult) } } - if rf, ok := ret.Get(1).(func(context.Context, types.SensitiveScenario, string, string) error); ok { - r1 = rf(ctx, scenario, text, id) + if rf, ok := ret.Get(1).(func(context.Context, *types.LLMCheckRequest) error); ok { + r1 = rf(ctx, req) } else { r1 = 
ret.Error(1) } @@ -182,16 +182,14 @@ type MockSensitiveComponent_PassLLMQueryCheck_Call struct { // PassLLMQueryCheck is a helper method to define mock.On call // - ctx context.Context -// - scenario types.SensitiveScenario -// - text string -// - id string -func (_e *MockSensitiveComponent_Expecter) PassLLMQueryCheck(ctx interface{}, scenario interface{}, text interface{}, id interface{}) *MockSensitiveComponent_PassLLMQueryCheck_Call { - return &MockSensitiveComponent_PassLLMQueryCheck_Call{Call: _e.mock.On("PassLLMQueryCheck", ctx, scenario, text, id)} +// - req *types.LLMCheckRequest +func (_e *MockSensitiveComponent_Expecter) PassLLMQueryCheck(ctx interface{}, req interface{}) *MockSensitiveComponent_PassLLMQueryCheck_Call { + return &MockSensitiveComponent_PassLLMQueryCheck_Call{Call: _e.mock.On("PassLLMQueryCheck", ctx, req)} } -func (_c *MockSensitiveComponent_PassLLMQueryCheck_Call) Run(run func(ctx context.Context, scenario types.SensitiveScenario, text string, id string)) *MockSensitiveComponent_PassLLMQueryCheck_Call { +func (_c *MockSensitiveComponent_PassLLMQueryCheck_Call) Run(run func(ctx context.Context, req *types.LLMCheckRequest)) *MockSensitiveComponent_PassLLMQueryCheck_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(types.SensitiveScenario), args[2].(string), args[3].(string)) + run(args[0].(context.Context), args[1].(*types.LLMCheckRequest)) }) return _c } @@ -201,14 +199,14 @@ func (_c *MockSensitiveComponent_PassLLMQueryCheck_Call) Return(_a0 *sensitive.C return _c } -func (_c *MockSensitiveComponent_PassLLMQueryCheck_Call) RunAndReturn(run func(context.Context, types.SensitiveScenario, string, string) (*sensitive.CheckResult, error)) *MockSensitiveComponent_PassLLMQueryCheck_Call { +func (_c *MockSensitiveComponent_PassLLMQueryCheck_Call) RunAndReturn(run func(context.Context, *types.LLMCheckRequest) (*sensitive.CheckResult, error)) *MockSensitiveComponent_PassLLMQueryCheck_Call { 
_c.Call.Return(run) return _c } -// PassStreamCheck provides a mock function with given fields: ctx, scenario, text, id -func (_m *MockSensitiveComponent) PassStreamCheck(ctx context.Context, scenario types.SensitiveScenario, text string, id string) (*sensitive.CheckResult, error) { - ret := _m.Called(ctx, scenario, text, id) +// PassStreamCheck provides a mock function with given fields: ctx, req +func (_m *MockSensitiveComponent) PassStreamCheck(ctx context.Context, req *types.LLMCheckRequest) (*sensitive.CheckResult, error) { + ret := _m.Called(ctx, req) if len(ret) == 0 { panic("no return value specified for PassStreamCheck") @@ -216,19 +214,19 @@ func (_m *MockSensitiveComponent) PassStreamCheck(ctx context.Context, scenario var r0 *sensitive.CheckResult var r1 error - if rf, ok := ret.Get(0).(func(context.Context, types.SensitiveScenario, string, string) (*sensitive.CheckResult, error)); ok { - return rf(ctx, scenario, text, id) + if rf, ok := ret.Get(0).(func(context.Context, *types.LLMCheckRequest) (*sensitive.CheckResult, error)); ok { + return rf(ctx, req) } - if rf, ok := ret.Get(0).(func(context.Context, types.SensitiveScenario, string, string) *sensitive.CheckResult); ok { - r0 = rf(ctx, scenario, text, id) + if rf, ok := ret.Get(0).(func(context.Context, *types.LLMCheckRequest) *sensitive.CheckResult); ok { + r0 = rf(ctx, req) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*sensitive.CheckResult) } } - if rf, ok := ret.Get(1).(func(context.Context, types.SensitiveScenario, string, string) error); ok { - r1 = rf(ctx, scenario, text, id) + if rf, ok := ret.Get(1).(func(context.Context, *types.LLMCheckRequest) error); ok { + r1 = rf(ctx, req) } else { r1 = ret.Error(1) } @@ -243,16 +241,14 @@ type MockSensitiveComponent_PassStreamCheck_Call struct { // PassStreamCheck is a helper method to define mock.On call // - ctx context.Context -// - scenario types.SensitiveScenario -// - text string -// - id string -func (_e *MockSensitiveComponent_Expecter) 
PassStreamCheck(ctx interface{}, scenario interface{}, text interface{}, id interface{}) *MockSensitiveComponent_PassStreamCheck_Call { - return &MockSensitiveComponent_PassStreamCheck_Call{Call: _e.mock.On("PassStreamCheck", ctx, scenario, text, id)} +// - req *types.LLMCheckRequest +func (_e *MockSensitiveComponent_Expecter) PassStreamCheck(ctx interface{}, req interface{}) *MockSensitiveComponent_PassStreamCheck_Call { + return &MockSensitiveComponent_PassStreamCheck_Call{Call: _e.mock.On("PassStreamCheck", ctx, req)} } -func (_c *MockSensitiveComponent_PassStreamCheck_Call) Run(run func(ctx context.Context, scenario types.SensitiveScenario, text string, id string)) *MockSensitiveComponent_PassStreamCheck_Call { +func (_c *MockSensitiveComponent_PassStreamCheck_Call) Run(run func(ctx context.Context, req *types.LLMCheckRequest)) *MockSensitiveComponent_PassStreamCheck_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(types.SensitiveScenario), args[2].(string), args[3].(string)) + run(args[0].(context.Context), args[1].(*types.LLMCheckRequest)) }) return _c } @@ -262,7 +258,7 @@ func (_c *MockSensitiveComponent_PassStreamCheck_Call) Return(_a0 *sensitive.Che return _c } -func (_c *MockSensitiveComponent_PassStreamCheck_Call) RunAndReturn(run func(context.Context, types.SensitiveScenario, string, string) (*sensitive.CheckResult, error)) *MockSensitiveComponent_PassStreamCheck_Call { +func (_c *MockSensitiveComponent_PassStreamCheck_Call) RunAndReturn(run func(context.Context, *types.LLMCheckRequest) (*sensitive.CheckResult, error)) *MockSensitiveComponent_PassStreamCheck_Call { _c.Call.Return(run) return _c } diff --git a/aigateway/component/moderation.go b/aigateway/component/moderation.go index 2cf66102c..acd4d9b21 100644 --- a/aigateway/component/moderation.go +++ b/aigateway/component/moderation.go @@ -8,8 +8,10 @@ import ( "log/slog" "regexp" "strings" + "sync" "time" + lru "github.com/hashicorp/golang-lru/v2" 
"github.com/openai/openai-go/v3" "opencsg.com/csghub-server/aigateway/types" "opencsg.com/csghub-server/builder/rpc" @@ -19,28 +21,224 @@ import ( ) const ( - // max content length for moderation - maxContentLength = 2000 + // max content length + maxContentLength = 6144 // sliding window size slidingWindowSize = 2000 // cache ttl cacheTTL = 24 * time.Hour // moderation cache prefix moderationCachePrpmptPrefix = "moderation:prompt:" + // default session cache size + defaultSessionCacheSize = 10000 + + StreamCheckModeAsync = "async" + StreamCheckModeSync = "sync" + DefaultAsyncBufferMaxChars = 50 ) type Moderation interface { - CheckChatPrompts(ctx context.Context, messages []openai.ChatCompletionMessageParamUnion, uuid string) (*rpc.CheckResult, error) + CheckChatPrompts(ctx context.Context, messages []openai.ChatCompletionMessageParamUnion, uuid string, isStream bool) (*rpc.CheckResult, error) CheckChatStreamResponse(ctx context.Context, chunk types.ChatCompletionChunk, uuid string) (*rpc.CheckResult, error) CheckChatNonStreamResponse(ctx context.Context, completion types.ChatCompletion) (*rpc.CheckResult, error) CheckImagePrompts(ctx context.Context, prompt string, uuid string) (*rpc.CheckResult, error) CheckImage(ctx context.Context, completion types.ImageGenerationResponse) (*rpc.CheckResult, error) + CloseStreamCheck(ctx context.Context, uuid string) (*rpc.CheckResult, error) +} + +type sessionState struct { + sync.Mutex + buffer strings.Builder + sensitive bool + reason string +} + +type StreamChecker interface { + CheckChatStreamResponse(ctx context.Context, chunk types.ChatCompletionChunk, uuid string) (*rpc.CheckResult, error) + CloseStreamCheck(ctx context.Context, uuid string) (*rpc.CheckResult, error) } type moderationImpl struct { - modSvcClient rpc.ModerationSvcClient - cacheClient cache.RedisClient - config *config.Config + modSvcClient rpc.ModerationSvcClient + cacheClient cache.RedisClient + config *config.Config + streamChecker StreamChecker +} + 
+type syncStreamChecker struct { + modImpl *moderationImpl +} + +func (s *syncStreamChecker) CheckChatStreamResponse(ctx context.Context, chunk types.ChatCompletionChunk, uuid string) (*rpc.CheckResult, error) { + if s.modImpl.modSvcClient == nil { + return &rpc.CheckResult{IsSensitive: false}, nil + } + if len(chunk.Choices) == 0 { + return &rpc.CheckResult{IsSensitive: false}, nil + } + content := chunk.Choices[0].Delta.Content + if strings.TrimSpace(content) == "" { + content = chunk.Choices[0].Delta.ReasoningContent + } + if strings.TrimSpace(content) == "" { + return &rpc.CheckResult{IsSensitive: false}, nil + } + + req := commontypes.LLMCheckRequest{ + Scenario: commontypes.ScenarioLLMResModeration, + Text: content, + SessionId: uuid, + Resumable: true, + Stream: true, + } + + result, err := s.modImpl.modSvcClient.PassLLMRespCheck(ctx, req) + s.modImpl.postCheck(ctx, result) + return result, err +} + +func (s *syncStreamChecker) CloseStreamCheck(ctx context.Context, uuid string) (*rpc.CheckResult, error) { + return &rpc.CheckResult{IsSensitive: false}, nil +} + +type asyncStreamChecker struct { + modImpl *moderationImpl + sessionCache *lru.Cache[string, *sessionState] + maxChars int +} + +func (a *asyncStreamChecker) CheckChatStreamResponse(ctx context.Context, chunk types.ChatCompletionChunk, uuid string) (*rpc.CheckResult, error) { + if a.modImpl.modSvcClient == nil { + return &rpc.CheckResult{IsSensitive: false}, nil + } + if len(chunk.Choices) == 0 { + return &rpc.CheckResult{IsSensitive: false}, nil + } + content := chunk.Choices[0].Delta.Content + if strings.TrimSpace(content) == "" { + content = chunk.Choices[0].Delta.ReasoningContent + } + if strings.TrimSpace(content) == "" { + return &rpc.CheckResult{IsSensitive: false}, nil + } + + req := commontypes.LLMCheckRequest{ + Scenario: commontypes.ScenarioLLMResModeration, + Text: content, + SessionId: uuid, + Resumable: true, + Stream: true, + } + if a.sessionCache == nil { + slog.Warn("moderation 
session cache is nil, fallback to sync mode") + result, err := a.modImpl.modSvcClient.PassLLMRespCheck(ctx, req) + a.modImpl.postCheck(ctx, result) + return result, err + } + + state, ok := a.sessionCache.Get(uuid) + if !ok { + state = &sessionState{} + a.sessionCache.Add(uuid, state) + } + + state.Lock() + if state.sensitive { + state.Unlock() + return &rpc.CheckResult{IsSensitive: true, Reason: state.reason}, nil + } + + state.buffer.WriteString(content) + currentLen := state.buffer.Len() + + var textToCheck string + if currentLen >= a.maxChars { + textToCheck = state.buffer.String() + state.buffer.Reset() + } + state.Unlock() + + if textToCheck != "" { + go a.executeAsyncCheck(textToCheck, uuid) + } + + return &rpc.CheckResult{IsSensitive: false}, nil +} + +func (a *asyncStreamChecker) executeAsyncCheck(text string, sessionId string) { + bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + req := commontypes.LLMCheckRequest{ + Scenario: commontypes.ScenarioLLMResModeration, + Text: text, + SessionId: sessionId, + Resumable: true, + Stream: true, + } + result, err := a.modImpl.modSvcClient.PassLLMRespCheck(bgCtx, req) + if err != nil { + slog.Warn("async moderation check failed", slog.Any("error", err)) + return + } + + if result.IsSensitive { + if a.modImpl.config != nil && a.modImpl.config.AIGateway.ModerationBypassSensitiveCheck { + return + } + + slog.ErrorContext(bgCtx, "sensitive content found asynchronously", slog.Any("reason", result.Reason)) + + if s, ok := a.sessionCache.Get(sessionId); ok { + s.Lock() + s.sensitive = true + s.reason = result.Reason + s.Unlock() + } + } +} + +func (a *asyncStreamChecker) CloseStreamCheck(ctx context.Context, uuid string) (*rpc.CheckResult, error) { + if a.sessionCache == nil { + return &rpc.CheckResult{IsSensitive: false}, nil + } + + state, ok := a.sessionCache.Get(uuid) + if !ok { + return &rpc.CheckResult{IsSensitive: false}, nil + } + + state.Lock() + defer func() { + 
state.Unlock() + a.sessionCache.Remove(uuid) + }() + + if state.sensitive { + return &rpc.CheckResult{IsSensitive: true, Reason: state.reason}, nil + } + + textToCheck := state.buffer.String() + req := commontypes.LLMCheckRequest{ + Scenario: commontypes.ScenarioLLMResModeration, + Text: textToCheck, + SessionId: uuid, + Resumable: false, + Stream: true, + } + if textToCheck == "" { + // set end text to trigger check of the end of the session stream + go func() { + req.Text = "[Done]" + cancelCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _, _ = a.modImpl.modSvcClient.PassLLMRespCheck(cancelCtx, req) + }() + return &rpc.CheckResult{IsSensitive: false}, nil + } + result, err := a.modImpl.modSvcClient.PassLLMRespCheck(ctx, req) + a.modImpl.postCheck(ctx, result) + return result, err } func NewModerationImpl(config *config.Config) Moderation { @@ -52,19 +250,52 @@ func NewModerationImpl(config *config.Config) Moderation { if err != nil { return nil } - return &moderationImpl{ + + modImpl := &moderationImpl{ modSvcClient: rpc.NewModerationSvcHttpClient(fmt.Sprintf("%s:%d", config.Moderation.Host, config.Moderation.Port)), cacheClient: cacheClient, config: config, } + + initStreamChecker(modImpl) + return modImpl } func NewModerationImplWithClient(config *config.Config, modSvcClient rpc.ModerationSvcClient, cacheClient cache.RedisClient) Moderation { - return &moderationImpl{ + modImpl := &moderationImpl{ modSvcClient: modSvcClient, cacheClient: cacheClient, config: config, } + + initStreamChecker(modImpl) + return modImpl +} + +func initStreamChecker(modImpl *moderationImpl) { + isAsync := modImpl.config != nil && modImpl.config.SensitiveCheck.StreamCheckMode == StreamCheckModeAsync + + if isAsync { + sessionCache, err := lru.New[string, *sessionState](defaultSessionCacheSize) + if err != nil { + slog.Error("failed to init moderation session cache, fallback to sync mode", slog.Any("error", err)) + modImpl.streamChecker = 
&syncStreamChecker{modImpl: modImpl} + return + } + + maxChars := DefaultAsyncBufferMaxChars + if modImpl.config.SensitiveCheck.AsyncBufferMaxChars > 0 { + maxChars = modImpl.config.SensitiveCheck.AsyncBufferMaxChars + } + + modImpl.streamChecker = &asyncStreamChecker{ + modImpl: modImpl, + sessionCache: sessionCache, + maxChars: maxChars, + } + } else { + modImpl.streamChecker = &syncStreamChecker{modImpl: modImpl} + } } func splitContentIntoChunksByWindow(content string) []string { @@ -96,7 +327,7 @@ func splitContentIntoChunksByWindow(content string) []string { //TODO: use cdc to get chunk // used for single chunk or short content -func (modImpl *moderationImpl) checkSingleChunk(ctx context.Context, content, key string) (*rpc.CheckResult, error) { +func (modImpl *moderationImpl) checkSingleChunk(ctx context.Context, content, key string, isStream bool) (*rpc.CheckResult, error) { if modImpl.cacheClient != nil { chunkHash := md5.Sum([]byte(content)) cacheKey := moderationCachePrpmptPrefix + fmt.Sprintf("%x", chunkHash) @@ -110,7 +341,14 @@ func (modImpl *moderationImpl) checkSingleChunk(ctx context.Context, content, ke } } - result, err := modImpl.modSvcClient.PassLLMPromptCheck(ctx, content, key) + req := commontypes.LLMCheckRequest{ + Scenario: commontypes.ScenarioLLMQueryModeration, + Text: content, + AccountId: key, + Resumable: true, + Stream: isStream, + } + result, err := modImpl.modSvcClient.PassLLMPromptCheck(ctx, req) if err != nil { return nil, err } @@ -135,8 +373,16 @@ func (modImpl *moderationImpl) checkBuffer( content string, currentBufferChunks []string, key string, + isStream bool, ) (*rpc.CheckResult, error) { - result, err := modImpl.modSvcClient.PassLLMPromptCheck(ctx, content, key) + req := commontypes.LLMCheckRequest{ + Scenario: commontypes.ScenarioLLMQueryModeration, + Text: content, + AccountId: key, + Resumable: true, + Stream: isStream, + } + result, err := modImpl.modSvcClient.PassLLMPromptCheck(ctx, req) if err != nil { return nil, err 
} @@ -165,7 +411,7 @@ func (modImpl *moderationImpl) checkBuffer( // CheckChatPrompts checks if any of the chat messages contain sensitive content. // It processes each message, extracts text content, and uses CheckLLMPrompt for validation. -func (modImpl *moderationImpl) CheckChatPrompts(ctx context.Context, messages []openai.ChatCompletionMessageParamUnion, uuid string) (*rpc.CheckResult, error) { +func (modImpl *moderationImpl) CheckChatPrompts(ctx context.Context, messages []openai.ChatCompletionMessageParamUnion, uuid string, isStream bool) (*rpc.CheckResult, error) { if modImpl.modSvcClient == nil { return &rpc.CheckResult{IsSensitive: false}, nil } @@ -207,7 +453,7 @@ func (modImpl *moderationImpl) CheckChatPrompts(ctx context.Context, messages [] } // Check if content is sensitive using existing method - result, err := modImpl.checkLLMPrompt(ctx, content, uuid) + result, err := modImpl.checkLLMPrompt(ctx, content, uuid, isStream) if err != nil { return nil, fmt.Errorf("failed to check message content: %w", err) } @@ -229,11 +475,11 @@ func (modImpl *moderationImpl) CheckChatPrompts(ctx context.Context, messages [] // CheckLLMPrompt checks if the prompt is sensitive. // For long content, it first checks each chunk individually (with caching). // Then, it uses a sliding window to check for sensitive combinations of chunks. 
-func (modImpl *moderationImpl) checkLLMPrompt(ctx context.Context, content, key string) (*rpc.CheckResult, error) { +func (modImpl *moderationImpl) checkLLMPrompt(ctx context.Context, content, key string, isStream bool) (*rpc.CheckResult, error) { content = strings.ReplaceAll(content, `\\n`, "\n") content = strings.ReplaceAll(content, `\n`, "") if len(content) < maxContentLength { - return modImpl.checkSingleChunk(ctx, content, key) + return modImpl.checkSingleChunk(ctx, content, key, isStream) } chunks := splitContentIntoChunksByWindow(content) @@ -299,7 +545,7 @@ func (modImpl *moderationImpl) checkLLMPrompt(ctx context.Context, content, key } if buffer.Len()+separatorLen+len(chunk) > maxContentLength && buffer.Len() > 0 { - result, err := modImpl.checkBuffer(ctx, buffer.String(), currentBufferChunks, key) + result, err := modImpl.checkBuffer(ctx, buffer.String(), currentBufferChunks, key, isStream) if err != nil { return nil, fmt.Errorf("failed to call moderation on buffer: %w", err) } @@ -322,7 +568,7 @@ func (modImpl *moderationImpl) checkLLMPrompt(ctx context.Context, content, key // Check any remaining content in the buffer. 
if buffer.Len() > 0 { - result, err := modImpl.checkBuffer(ctx, buffer.String(), currentBufferChunks, key) + result, err := modImpl.checkBuffer(ctx, buffer.String(), currentBufferChunks, key, isStream) if err != nil { return nil, fmt.Errorf("failed to call moderation on remaining buffer: %w", err) } @@ -337,32 +583,10 @@ func (modImpl *moderationImpl) checkLLMPrompt(ctx context.Context, content, key } func (modImpl *moderationImpl) CheckChatStreamResponse(ctx context.Context, chunk types.ChatCompletionChunk, uuid string) (*rpc.CheckResult, error) { - if modImpl.modSvcClient == nil { - return &rpc.CheckResult{IsSensitive: false}, nil - } - if len(chunk.Choices) == 0 { - return &rpc.CheckResult{IsSensitive: false}, nil - } - if chunk.Choices[0].Delta.Content == "" && chunk.Choices[0].Delta.ReasoningContent == "" { - return &rpc.CheckResult{IsSensitive: false}, nil + if modImpl.streamChecker != nil { + return modImpl.streamChecker.CheckChatStreamResponse(ctx, chunk, uuid) } - - var result = &rpc.CheckResult{IsSensitive: false} - var err error - if strings.TrimSpace(chunk.Choices[0].Delta.Content) != "" { - // moderate on content - result, err = modImpl.modSvcClient.PassLLMRespCheck(ctx, chunk.Choices[0].Delta.Content, uuid) - } else if strings.TrimSpace(chunk.Choices[0].Delta.ReasoningContent) != "" { - // moderate on reasoning content - result, err = modImpl.modSvcClient.PassLLMRespCheck(ctx, chunk.Choices[0].Delta.ReasoningContent, uuid) - } else { - slog.Error("Unknown data struct", - slog.Any("raw data", chunk), - slog.Any("unmarshal chunk", chunk)) - } - - modImpl.postCheck(ctx, result) - return result, err + return &rpc.CheckResult{IsSensitive: false}, nil } func (modImpl *moderationImpl) CheckChatNonStreamResponse(ctx context.Context, completion types.ChatCompletion) (*rpc.CheckResult, error) { @@ -394,11 +618,18 @@ func (modImpl *moderationImpl) postCheck(ctx context.Context, result *rpc.CheckR } } +func (modImpl *moderationImpl) CloseStreamCheck(ctx 
context.Context, uuid string) (*rpc.CheckResult, error) { + if modImpl.streamChecker != nil { + return modImpl.streamChecker.CloseStreamCheck(ctx, uuid) + } + return &rpc.CheckResult{IsSensitive: false}, nil +} + func (modImpl *moderationImpl) CheckImagePrompts(ctx context.Context, prompt string, uuid string) (*rpc.CheckResult, error) { if modImpl.modSvcClient == nil { return &rpc.CheckResult{IsSensitive: false}, nil } - return modImpl.checkLLMPrompt(ctx, prompt, uuid) + return modImpl.checkLLMPrompt(ctx, prompt, uuid, false) } func (modImpl *moderationImpl) CheckImage(ctx context.Context, completion types.ImageGenerationResponse) (*rpc.CheckResult, error) { diff --git a/aigateway/component/moderation_test.go b/aigateway/component/moderation_test.go index cbb3929a0..1692f5093 100644 --- a/aigateway/component/moderation_test.go +++ b/aigateway/component/moderation_test.go @@ -2,703 +2,396 @@ package component import ( "context" - "crypto/md5" - "encoding/json" - "errors" - "fmt" - "reflect" - "strings" "testing" + "time" - "github.com/openai/openai-go/v3" - "github.com/openai/openai-go/v3/packages/param" + lru "github.com/hashicorp/golang-lru/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - mock_rpc "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/rpc" - mock_cache "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/builder/store/cache" "opencsg.com/csghub-server/aigateway/types" "opencsg.com/csghub-server/builder/rpc" "opencsg.com/csghub-server/common/config" - common_types "opencsg.com/csghub-server/common/types" + commontypes "opencsg.com/csghub-server/common/types" ) -func TestSplitContentIntoChunksByWindow_Table(t *testing.T) { - longLen := slidingWindowSize*2 + 10 - longStr := strings.Repeat("a", longLen) +// MockModerationSvcClient is a mock of rpc.ModerationSvcClient +type MockModerationSvcClient struct { + mock.Mock +} - // build expected chunks for longStr - var expectedLong []string - for i := 0; i < 
longLen; i += slidingWindowSize { - end := i + slidingWindowSize - if end > longLen { - end = longLen - } - expectedLong = append(expectedLong, longStr[i:end]) +func (m *MockModerationSvcClient) PassLLMRespCheck(ctx context.Context, req commontypes.LLMCheckRequest) (*rpc.CheckResult, error) { + args := m.Called(ctx, req) + if args.Get(0) != nil { + return args.Get(0).(*rpc.CheckResult), args.Error(1) } + return nil, args.Error(1) +} - tests := []struct { - name string - in string - want []string - }{ - {name: "empty", in: "", want: []string{}}, - {name: "simple sentences", in: "Hello world. How are you? I'm fine!", want: []string{"Hello world", "How are you", "I'm fine"}}, - {name: "leading/trailing/extra punctuation", in: " .hello.. world! ", want: []string{"hello", "world"}}, - {name: "long single sentence", in: longStr, want: expectedLong}, - {name: "mixed with long sentence and short", in: "short. " + longStr + "! tail?", want: append(append([]string{"short"}, expectedLong...), "tail")}, +func (m *MockModerationSvcClient) PassLLMPromptCheck(ctx context.Context, req commontypes.LLMCheckRequest) (*rpc.CheckResult, error) { + args := m.Called(ctx, req) + if args.Get(0) != nil { + return args.Get(0).(*rpc.CheckResult), args.Error(1) } + return nil, args.Error(1) +} - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := splitContentIntoChunksByWindow(tt.in) - if !reflect.DeepEqual(got, tt.want) { - t.Fatalf("unexpected result for %q:\ngot: %#v\nwant: %#v", tt.name, got, tt.want) - } - }) +func (m *MockModerationSvcClient) PassTextCheck(ctx context.Context, scenario commontypes.SensitiveScenario, text string) (*rpc.CheckResult, error) { + args := m.Called(ctx, scenario, text) + if args.Get(0) != nil { + return args.Get(0).(*rpc.CheckResult), args.Error(1) } + return nil, args.Error(1) } -func TestModerationImpl_CheckLLMPrompt_WithoutCache(t *testing.T) { - ctx := context.Background() - key := "test-key" - - t.Run("short and not sensitive", func(t 
*testing.T) { - mockClient := mock_rpc.NewMockModerationSvcClient(t) - moderation := NewModerationImplWithClient(&config.Config{}, mockClient, nil) - content := "this is a short and safe text" - - mockClient.EXPECT().PassLLMPromptCheck(ctx, content, key).Return(&rpc.CheckResult{IsSensitive: false}, nil).Once() - - result, err := moderation.CheckChatPrompts(ctx, []openai.ChatCompletionMessageParamUnion{ - { - OfSystem: &openai.ChatCompletionSystemMessageParam{ - Content: openai.ChatCompletionSystemMessageParamContentUnion{ - OfString: param.Opt[string]{Value: content}, - }, - }, - }, - }, key) +func (m *MockModerationSvcClient) PassImageCheck(ctx context.Context, scenario commontypes.SensitiveScenario, ossBucketName, ossObjectName string) (*rpc.CheckResult, error) { + args := m.Called(ctx, scenario, ossBucketName, ossObjectName) + if args.Get(0) != nil { + return args.Get(0).(*rpc.CheckResult), args.Error(1) + } + return nil, args.Error(1) +} - assert.NoError(t, err) - assert.NotNil(t, result) - assert.False(t, result.IsSensitive) - mockClient.AssertExpectations(t) - }) +func (m *MockModerationSvcClient) PassImageURLCheck(ctx context.Context, scenario commontypes.SensitiveScenario, imageURL string) (*rpc.CheckResult, error) { + args := m.Called(ctx, scenario, imageURL) + if args.Get(0) != nil { + return args.Get(0).(*rpc.CheckResult), args.Error(1) + } + return nil, args.Error(1) +} - t.Run("short and sensitive", func(t *testing.T) { - mockClient := mock_rpc.NewMockModerationSvcClient(t) - moderation := NewModerationImplWithClient(&config.Config{}, mockClient, nil) - content := "this is a short and sensitive text" +func (m *MockModerationSvcClient) SubmitRepoCheck(ctx context.Context, repoType commontypes.RepositoryType, namespace, name string) error { + args := m.Called(ctx, repoType, namespace, name) + return args.Error(0) +} - mockClient.On("PassLLMPromptCheck", ctx, content, key).Return(&rpc.CheckResult{IsSensitive: true, Reason: "sensitive"}, nil).Once() +// 
MockStreamChecker is a mock of StreamChecker +type MockStreamChecker struct { + mock.Mock +} - result, err := moderation.CheckChatPrompts(ctx, []openai.ChatCompletionMessageParamUnion{ - { - OfSystem: &openai.ChatCompletionSystemMessageParam{ - Content: openai.ChatCompletionSystemMessageParamContentUnion{ - OfString: param.Opt[string]{Value: content}, - }, - }, - }, - }, key) +func (m *MockStreamChecker) CheckChatStreamResponse(ctx context.Context, chunk types.ChatCompletionChunk, uuid string) (*rpc.CheckResult, error) { + args := m.Called(ctx, chunk, uuid) + if args.Get(0) != nil { + return args.Get(0).(*rpc.CheckResult), args.Error(1) + } + return nil, args.Error(1) +} - assert.NoError(t, err) - assert.NotNil(t, result) - assert.True(t, result.IsSensitive) - mockClient.AssertExpectations(t) - }) +func (m *MockStreamChecker) CloseStreamCheck(ctx context.Context, uuid string) (*rpc.CheckResult, error) { + args := m.Called(ctx, uuid) + if args.Get(0) != nil { + return args.Get(0).(*rpc.CheckResult), args.Error(1) + } + return nil, args.Error(1) } -func TestModerationImpl_CheckChatStreamResponse(t *testing.T) { +func TestSyncStreamChecker_CheckChatStreamResponse(t *testing.T) { ctx := context.Background() - uuid := "test-uuid" + mockSvcClient := new(MockModerationSvcClient) - t.Run("should_return_non_sensitive_when_modSvcClient_is_nil", func(t *testing.T) { - modImpl := &moderationImpl{ - modSvcClient: nil, - cacheClient: nil, - } + modImpl := &moderationImpl{ + modSvcClient: mockSvcClient, + } - chunk := types.ChatCompletionChunk{ - Choices: []types.ChatCompletionChunkChoice{{ - Delta: types.ChatCompletionChunkChoiceDelta{ - Content: "test content", - }, - }}, - } + checker := &syncStreamChecker{ + modImpl: modImpl, + } - result, err := modImpl.CheckChatStreamResponse(ctx, chunk, uuid) + t.Run("empty chunk", func(t *testing.T) { + res, err := checker.CheckChatStreamResponse(ctx, types.ChatCompletionChunk{}, "uuid-1") assert.NoError(t, err) - assert.NotNil(t, result) 
- assert.Equal(t, false, result.IsSensitive) + assert.False(t, res.IsSensitive) }) - t.Run("should_return_non_sensitive_when_choices_is_empty", func(t *testing.T) { - mockModClient := mock_rpc.NewMockModerationSvcClient(t) - modImpl := &moderationImpl{ - modSvcClient: mockModClient, - cacheClient: nil, - } - + t.Run("normal text check pass", func(t *testing.T) { chunk := types.ChatCompletionChunk{ - Choices: []types.ChatCompletionChunkChoice{}, + Choices: []types.ChatCompletionChunkChoice{ + {Delta: types.ChatCompletionChunkChoiceDelta{Content: "hello"}}, + }, } - result, err := modImpl.CheckChatStreamResponse(ctx, chunk, uuid) + mockSvcClient.On("PassLLMRespCheck", ctx, commontypes.LLMCheckRequest{ + Scenario: commontypes.ScenarioLLMResModeration, + Text: "hello", + SessionId: "uuid-2", + Resumable: true, + Stream: true, + }).Return(&rpc.CheckResult{IsSensitive: false}, nil).Once() + res, err := checker.CheckChatStreamResponse(ctx, chunk, "uuid-2") + assert.NoError(t, err) - assert.NotNil(t, result) - assert.Equal(t, false, result.IsSensitive) + assert.False(t, res.IsSensitive) + mockSvcClient.AssertExpectations(t) }) - t.Run("should_return_non_sensitive_when_both_content_and_reasoning_are_empty", func(t *testing.T) { - mockModClient := mock_rpc.NewMockModerationSvcClient(t) - modImpl := &moderationImpl{ - modSvcClient: mockModClient, - cacheClient: nil, - } - + t.Run("sensitive text block", func(t *testing.T) { chunk := types.ChatCompletionChunk{ - Choices: []types.ChatCompletionChunkChoice{{ - Delta: types.ChatCompletionChunkChoiceDelta{ - Content: "", - ReasoningContent: "", - }, - }}, + Choices: []types.ChatCompletionChunkChoice{ + {Delta: types.ChatCompletionChunkChoiceDelta{Content: "bad words"}}, + }, } - result, err := modImpl.CheckChatStreamResponse(ctx, chunk, uuid) + mockSvcClient.On("PassLLMRespCheck", ctx, commontypes.LLMCheckRequest{ + Scenario: commontypes.ScenarioLLMResModeration, + Text: "bad words", + SessionId: "uuid-3", + Resumable: true, + 
Stream: true, + }).Return(&rpc.CheckResult{IsSensitive: true, Reason: "toxic"}, nil).Once() + res, err := checker.CheckChatStreamResponse(ctx, chunk, "uuid-3") + assert.NoError(t, err) - assert.NotNil(t, result) - assert.Equal(t, false, result.IsSensitive) + assert.True(t, res.IsSensitive) + assert.Equal(t, "toxic", res.Reason) + mockSvcClient.AssertExpectations(t) }) +} - t.Run("should_call_PassLLMRespCheck_and_return_non_sensitive_when_content_not_empty", func(t *testing.T) { - mockModClient := mock_rpc.NewMockModerationSvcClient(t) - mockModClient.EXPECT().PassLLMRespCheck(ctx, "test content", uuid). - Return(&rpc.CheckResult{IsSensitive: false}, nil).Once() - modImpl := &moderationImpl{ - modSvcClient: mockModClient, - cacheClient: nil, - } +func TestSyncStreamChecker_CloseStreamCheck(t *testing.T) { + checker := &syncStreamChecker{} + res, err := checker.CloseStreamCheck(context.Background(), "uuid-1") + assert.NoError(t, err) + assert.False(t, res.IsSensitive) +} - chunk := types.ChatCompletionChunk{ - Choices: []types.ChatCompletionChunkChoice{{ - Delta: types.ChatCompletionChunkChoiceDelta{ - Content: "test content", - }, - }}, - } +func TestAsyncStreamChecker_CheckChatStreamResponse(t *testing.T) { + ctx := context.Background() + mockSvcClient := new(MockModerationSvcClient) - result, err := modImpl.CheckChatStreamResponse(ctx, chunk, uuid) - assert.NoError(t, err) - assert.NotNil(t, result) - assert.Equal(t, false, result.IsSensitive) - mockModClient.AssertExpectations(t) - }) + modImpl := &moderationImpl{ + modSvcClient: mockSvcClient, + } - t.Run("should_call_PassLLMRespCheck_and_return_sensitive_when_content_is_sensitive", func(t *testing.T) { - mockModClient := mock_rpc.NewMockModerationSvcClient(t) - mockModClient.EXPECT().PassLLMRespCheck(ctx, "sensitive content", uuid). 
- Return(&rpc.CheckResult{IsSensitive: true, Reason: "inappropriate language"}, nil).Once() - modImpl := &moderationImpl{ - modSvcClient: mockModClient, - cacheClient: nil, - } + sessionCache, _ := lru.New[string, *sessionState](100) - chunk := types.ChatCompletionChunk{ - Choices: []types.ChatCompletionChunkChoice{{ - Delta: types.ChatCompletionChunkChoiceDelta{ - Content: "sensitive content", - }, - }}, - } + checker := &asyncStreamChecker{ + modImpl: modImpl, + sessionCache: sessionCache, + maxChars: 10, + } - result, err := modImpl.CheckChatStreamResponse(ctx, chunk, uuid) + t.Run("empty chunk", func(t *testing.T) { + res, err := checker.CheckChatStreamResponse(ctx, types.ChatCompletionChunk{}, "uuid-1") assert.NoError(t, err) - assert.NotNil(t, result) - assert.Equal(t, true, result.IsSensitive) - mockModClient.AssertExpectations(t) + assert.False(t, res.IsSensitive) }) - t.Run("should_check_reasoning_content_when_content_is_whitespace", func(t *testing.T) { - mockModClient := mock_rpc.NewMockModerationSvcClient(t) - mockModClient.EXPECT().PassLLMRespCheck(ctx, "reasoning content", uuid). 
- Return(&rpc.CheckResult{IsSensitive: false}, nil).Once() - modImpl := &moderationImpl{ - modSvcClient: mockModClient, - cacheClient: nil, - } - - chunk := types.ChatCompletionChunk{ - Choices: []types.ChatCompletionChunkChoice{{ - Delta: types.ChatCompletionChunkChoiceDelta{ - Content: " ", - ReasoningContent: "reasoning content", - }, - }}, + t.Run("accumulate chunks", func(t *testing.T) { + // First chunk - short, should not trigger check + chunk1 := types.ChatCompletionChunk{ + Choices: []types.ChatCompletionChunkChoice{ + {Delta: types.ChatCompletionChunkChoiceDelta{Content: "hello"}}, + }, } - result, err := modImpl.CheckChatStreamResponse(ctx, chunk, uuid) + res1, err := checker.CheckChatStreamResponse(ctx, chunk1, "uuid-2") assert.NoError(t, err) - assert.NotNil(t, result) - assert.Equal(t, false, result.IsSensitive) - mockModClient.AssertExpectations(t) - }) + assert.False(t, res1.IsSensitive) - t.Run("should_call_PassLLMRespCheck_when_reasoning_content_not_empty", func(t *testing.T) { - mockModClient := mock_rpc.NewMockModerationSvcClient(t) - mockModClient.EXPECT().PassLLMRespCheck(ctx, "reasoning content", uuid). 
- Return(&rpc.CheckResult{IsSensitive: false}, nil).Once() - modImpl := &moderationImpl{ - modSvcClient: mockModClient, - cacheClient: nil, + // Second chunk - total length > maxChars, should trigger async check + chunk2 := types.ChatCompletionChunk{ + Choices: []types.ChatCompletionChunkChoice{ + {Delta: types.ChatCompletionChunkChoiceDelta{Content: " world"}}, + }, } - chunk := types.ChatCompletionChunk{ - Choices: []types.ChatCompletionChunkChoice{{ - Delta: types.ChatCompletionChunkChoiceDelta{ - Content: "", - ReasoningContent: "reasoning content", - }, - }}, - } + // Setup mock for the async call + mockSvcClient.On("PassLLMRespCheck", mock.Anything, commontypes.LLMCheckRequest{ + Scenario: commontypes.ScenarioLLMResModeration, + Text: "hello world", + SessionId: "uuid-2", + Resumable: true, + Stream: true, + }).Return(&rpc.CheckResult{IsSensitive: false}, nil).Once() - result, err := modImpl.CheckChatStreamResponse(ctx, chunk, uuid) + res2, err := checker.CheckChatStreamResponse(ctx, chunk2, "uuid-2") assert.NoError(t, err) - assert.NotNil(t, result) - assert.Equal(t, false, result.IsSensitive) - mockModClient.AssertExpectations(t) - }) + assert.False(t, res2.IsSensitive) - t.Run("should_return_error_when_PassLLMRespCheck_fails", func(t *testing.T) { - mockModClient := mock_rpc.NewMockModerationSvcClient(t) - mockModClient.EXPECT().PassLLMRespCheck(ctx, "test content", uuid). 
- Return(&rpc.CheckResult{IsSensitive: false}, assert.AnError).Once() - modImpl := &moderationImpl{ - modSvcClient: mockModClient, - cacheClient: nil, - } - - chunk := types.ChatCompletionChunk{ - Choices: []types.ChatCompletionChunkChoice{{ - Delta: types.ChatCompletionChunkChoiceDelta{ - Content: "test content", - }, - }}, - } - - result, err := modImpl.CheckChatStreamResponse(ctx, chunk, uuid) - assert.Error(t, err) - assert.NotNil(t, result) - mockModClient.AssertExpectations(t) + // Wait a bit for the async goroutine to complete + time.Sleep(100 * time.Millisecond) + mockSvcClient.AssertExpectations(t) }) - t.Run("should_return_default_result_when_both_content_and_reasoning_are_whitespace", func(t *testing.T) { - mockModClient := mock_rpc.NewMockModerationSvcClient(t) - modImpl := &moderationImpl{ - modSvcClient: mockModClient, - cacheClient: nil, + t.Run("sensitive async result updates cache", func(t *testing.T) { + chunk1 := types.ChatCompletionChunk{ + Choices: []types.ChatCompletionChunkChoice{ + {Delta: types.ChatCompletionChunkChoiceDelta{Content: "very bad word here"}}, + }, } - chunk := types.ChatCompletionChunk{ - Choices: []types.ChatCompletionChunkChoice{{ - Delta: types.ChatCompletionChunkChoiceDelta{ - Content: " ", - ReasoningContent: " ", - }, - }}, - } + mockSvcClient.On("PassLLMRespCheck", mock.Anything, commontypes.LLMCheckRequest{ + Scenario: commontypes.ScenarioLLMResModeration, + Text: "very bad word here", + SessionId: "uuid-3", + Resumable: true, + Stream: true, + }).Return(&rpc.CheckResult{IsSensitive: true, Reason: "toxic"}, nil).Once() - result, err := modImpl.CheckChatStreamResponse(ctx, chunk, uuid) + res1, err := checker.CheckChatStreamResponse(ctx, chunk1, "uuid-3") assert.NoError(t, err) - assert.NotNil(t, result) - assert.Equal(t, false, result.IsSensitive) - }) -} - -func TestModerationImpl_CheckChatNonStreamResponse(t *testing.T) { - ctx := context.Background() - 
t.Run("should_call_PassLLMRespCheck_and_return_sensitive_when_content_is_sensitive", func(t *testing.T) { - mockModClient := mock_rpc.NewMockModerationSvcClient(t) - mockModClient.EXPECT().PassTextCheck(ctx, common_types.ScenarioChatDetection, "sensitive content"). - Return(&rpc.CheckResult{IsSensitive: true, Reason: "inappropriate language"}, nil).Once() - modImpl := &moderationImpl{ - modSvcClient: mockModClient, - cacheClient: nil, - } - - completion := types.ChatCompletion{ - ChatCompletion: openai.ChatCompletion{ - Choices: []openai.ChatCompletionChoice{{ - Message: openai.ChatCompletionMessage{ - Content: "sensitive content", - }, - }}, - }, - } + assert.False(t, res1.IsSensitive) // Initial response is always non-sensitive while async check runs - result, err := modImpl.CheckChatNonStreamResponse(ctx, completion) - assert.NoError(t, err) - assert.NotNil(t, result) - assert.Equal(t, true, result.IsSensitive) - mockModClient.AssertExpectations(t) - }) - t.Run("should_call_PassLLMRespCheck_and_return_not_sensitive_when_content_is_not_sensitive", func(t *testing.T) { - mockModClient := mock_rpc.NewMockModerationSvcClient(t) - mockModClient.EXPECT().PassTextCheck(ctx, common_types.ScenarioChatDetection, "not sensitive content"). 
- Return(&rpc.CheckResult{IsSensitive: false, Reason: "appropriate language"}, nil).Once() - modImpl := &moderationImpl{ - modSvcClient: mockModClient, - cacheClient: nil, - } + // Wait for async check to complete and update cache + time.Sleep(100 * time.Millisecond) + mockSvcClient.AssertExpectations(t) - completion := types.ChatCompletion{ - ChatCompletion: openai.ChatCompletion{ - Choices: []openai.ChatCompletionChoice{{ - Message: openai.ChatCompletionMessage{ - Content: "not sensitive content", - }, - }}, + // Next chunk should be blocked immediately based on cache + chunk2 := types.ChatCompletionChunk{ + Choices: []types.ChatCompletionChunkChoice{ + {Delta: types.ChatCompletionChunkChoiceDelta{Content: "more"}}, }, } - result, err := modImpl.CheckChatNonStreamResponse(ctx, completion) + res2, err := checker.CheckChatStreamResponse(ctx, chunk2, "uuid-3") assert.NoError(t, err) - assert.NotNil(t, result) - assert.Equal(t, false, result.IsSensitive) - mockModClient.AssertExpectations(t) + assert.True(t, res2.IsSensitive) + assert.Equal(t, "toxic", res2.Reason) }) } -// TestModerationImpl_CheckLLMPrompt_CacheCheck tests the cache checking logic in moderation.go -func TestModerationImpl_CheckLLMPrompt_CacheCheck(t *testing.T) { +func TestAsyncStreamChecker_CloseStreamCheck(t *testing.T) { ctx := context.Background() - key := "test-key" - - // case 1: cache hit - t.Run("cache_client_exists_and_cache_has_sensitive_content", func(t *testing.T) { - mockModClient := mock_rpc.NewMockModerationSvcClient(t) - mockCacheClient := mock_cache.NewMockRedisClient(t) - modImpl := &moderationImpl{ - cacheClient: mockCacheClient, - modSvcClient: mockModClient, - } + mockSvcClient := new(MockModerationSvcClient) - sensitiveChunk := "this is a sensitive chunk of content" - safeContent := strings.Repeat("safe content. ", 200) - testContent := sensitiveChunk + ". 
" + safeContent - - chunkHash := md5.Sum([]byte(sensitiveChunk)) - cacheKey := moderationCachePrpmptPrefix + fmt.Sprintf("%x", chunkHash) + modImpl := &moderationImpl{ + modSvcClient: mockSvcClient, + } - sensitiveResult := &rpc.CheckResult{IsSensitive: true, Reason: "contains inappropriate content"} - resultJSON, _ := json.Marshal(sensitiveResult) - mockCacheClient.EXPECT().Get(ctx, cacheKey).Return(string(resultJSON), nil).Once() + sessionCache, _ := lru.New[string, *sessionState](100) - result, err := modImpl.CheckChatPrompts(ctx, []openai.ChatCompletionMessageParamUnion{ - { - OfSystem: &openai.ChatCompletionSystemMessageParam{ - Content: openai.ChatCompletionSystemMessageParamContentUnion{ - OfString: param.Opt[string]{Value: testContent}, - }, - }, - }, - }, key) + checker := &asyncStreamChecker{ + modImpl: modImpl, + sessionCache: sessionCache, + maxChars: 10, + } + t.Run("not in cache", func(t *testing.T) { + res, err := checker.CloseStreamCheck(ctx, "uuid-unknown") assert.NoError(t, err) - assert.NotNil(t, result) - assert.True(t, result.IsSensitive) - assert.Equal(t, "contains inappropriate content", result.Reason) - mockCacheClient.AssertExpectations(t) - mockModClient.AssertExpectations(t) + assert.False(t, res.IsSensitive) }) - // case 2: cache failed - t.Run("cache_get_error_but_does_not_affect_overall_functionality", func(t *testing.T) { - mockModClient := mock_rpc.NewMockModerationSvcClient(t) - mockCacheClient := mock_cache.NewMockRedisClient(t) - - modImpl := &moderationImpl{ - cacheClient: mockCacheClient, - modSvcClient: mockModClient, - } - testChunk := "this is a test chunk of content" - testContent := testChunk + ". 
" + strings.Repeat("y", slidingWindowSize*2) - - chunkHash := md5.Sum([]byte(testChunk)) - cacheKey1 := moderationCachePrpmptPrefix + fmt.Sprintf("%x", chunkHash) - cacheKey2 := moderationCachePrpmptPrefix + fmt.Sprintf("%x", md5.Sum([]byte(strings.Repeat("y", slidingWindowSize)))) - - mockCacheClient.EXPECT().Get(mock.Anything, cacheKey1).Return("", errors.New("cache error")) - - mockCacheClient.EXPECT().Get(mock.Anything, cacheKey2).Return("", errors.New("cache error")) - - mockModClient.EXPECT().PassLLMPromptCheck(mock.Anything, mock.Anything, key). - Return(&rpc.CheckResult{IsSensitive: false}, nil) - - mockCacheClient.EXPECT().SetEx(mock.Anything, cacheKey1, mock.Anything, cacheTTL). - Return(nil) - mockCacheClient.EXPECT().SetEx(mock.Anything, cacheKey2, mock.Anything, cacheTTL). - Return(nil) - result, err := modImpl.CheckChatPrompts(ctx, []openai.ChatCompletionMessageParamUnion{ - { - OfSystem: &openai.ChatCompletionSystemMessageParam{ - Content: openai.ChatCompletionSystemMessageParamContentUnion{ - OfString: param.Opt[string]{Value: testContent}, - }, - }, + t.Run("check remaining buffer", func(t *testing.T) { + // Put something in buffer first + chunk := types.ChatCompletionChunk{ + Choices: []types.ChatCompletionChunkChoice{ + {Delta: types.ChatCompletionChunkChoiceDelta{Content: "short"}}, }, - }, key) + } + res1, err := checker.CheckChatStreamResponse(ctx, chunk, "uuid-4") assert.NoError(t, err) - assert.NotNil(t, result) - assert.False(t, result.IsSensitive) - }) -} - -// TestModerationImpl_PostCheck tests the postCheck function with ModerationBypassSensitiveCheck config -func TestModerationImpl_PostCheck(t *testing.T) { - ctx := context.Background() - - t.Run("should_not_modify_non_sensitive_result", func(t *testing.T) { - modImpl := &moderationImpl{ - config: &config.Config{}, - } - result := &rpc.CheckResult{IsSensitive: false} - modImpl.postCheck(ctx, result) - assert.NotNil(t, result) - assert.False(t, result.IsSensitive) - }) + assert.False(t, 
res1.IsSensitive) + + // Close should trigger check on remaining "short" + mockSvcClient.On("PassLLMRespCheck", ctx, commontypes.LLMCheckRequest{ + Scenario: commontypes.ScenarioLLMResModeration, + Text: "short", + SessionId: "uuid-4", + Resumable: false, + Stream: true, + }).Return(&rpc.CheckResult{IsSensitive: false}, nil).Once() + + res2, err := checker.CloseStreamCheck(ctx, "uuid-4") + assert.NoError(t, err) + assert.False(t, res2.IsSensitive) + mockSvcClient.AssertExpectations(t) - t.Run("should_block_sensitive_content_when_ModerationBypassSensitiveCheck_is_false", func(t *testing.T) { - cfg := &config.Config{} - cfg.AIGateway.ModerationBypassSensitiveCheck = false - modImpl := &moderationImpl{ - config: cfg, - } - result := &rpc.CheckResult{IsSensitive: true, Reason: "test reason"} - modImpl.postCheck(ctx, result) - assert.NotNil(t, result) - assert.True(t, result.IsSensitive, "should keep IsSensitive as true when ModerationBypassSensitiveCheck is false") - assert.Equal(t, "test reason", result.Reason) + // Verify it was removed from cache + _, exists := sessionCache.Get("uuid-4") + assert.False(t, exists) }) - t.Run("should_not_block_sensitive_content_when_ModerationBypassSensitiveCheck_is_true", func(t *testing.T) { - cfg := &config.Config{} - cfg.AIGateway.ModerationBypassSensitiveCheck = true - modImpl := &moderationImpl{ - config: cfg, + t.Run("already marked sensitive", func(t *testing.T) { + // Create a sensitive state directly in cache + state := &sessionState{ + sensitive: true, + reason: "toxic", } - result := &rpc.CheckResult{IsSensitive: true, Reason: "test reason"} - modImpl.postCheck(ctx, result) - assert.NotNil(t, result) - assert.False(t, result.IsSensitive, "should change IsSensitive to false when ModerationBypassSensitiveCheck is true") - assert.Equal(t, "", result.Reason, "should clear the reason when bypassing") - }) + sessionCache.Add("uuid-5", state) - t.Run("should_block_sensitive_content_when_config_is_nil", func(t *testing.T) { - 
modImpl := &moderationImpl{ - config: nil, - } - result := &rpc.CheckResult{IsSensitive: true, Reason: "test reason"} - modImpl.postCheck(ctx, result) - assert.NotNil(t, result) - assert.True(t, result.IsSensitive, "should keep IsSensitive as true when config is nil (default behavior)") - assert.Equal(t, "test reason", result.Reason) + res, err := checker.CloseStreamCheck(ctx, "uuid-5") + assert.NoError(t, err) + assert.True(t, res.IsSensitive) + assert.Equal(t, "toxic", res.Reason) }) } -// TestModerationImpl_CheckChatPrompts_WithModerationBypass tests CheckChatPrompts with ModerationBypassSensitiveCheck config -func TestModerationImpl_CheckChatPrompts_WithModerationBypass(t *testing.T) { +func TestModerationImpl_CheckChatStreamResponse(t *testing.T) { ctx := context.Background() - key := "test-key" - - t.Run("should_block_sensitive_content_when_ModerationBypassSensitiveCheck_is_false", func(t *testing.T) { - mockClient := mock_rpc.NewMockModerationSvcClient(t) - content := "sensitive content" - - mockClient.On("PassLLMPromptCheck", ctx, content, key). 
- Return(&rpc.CheckResult{IsSensitive: true, Reason: "inappropriate"}, nil).Once() - - cfg := &config.Config{} - cfg.AIGateway.ModerationBypassSensitiveCheck = false - modImpl := &moderationImpl{ - modSvcClient: mockClient, - cacheClient: nil, - config: cfg, - } - - result, err := modImpl.CheckChatPrompts(ctx, []openai.ChatCompletionMessageParamUnion{ - { - OfSystem: &openai.ChatCompletionSystemMessageParam{ - Content: openai.ChatCompletionSystemMessageParamContentUnion{ - OfString: param.Opt[string]{Value: content}, - }, - }, - }, - }, key) - - assert.NoError(t, err) - assert.NotNil(t, result) - assert.True(t, result.IsSensitive, "should block sensitive content when ModerationBypassSensitiveCheck is false") - mockClient.AssertExpectations(t) - }) + mockChecker := new(MockStreamChecker) + modImpl := &moderationImpl{ + streamChecker: mockChecker, + } - t.Run("should_not_block_sensitive_content_when_ModerationBypassSensitiveCheck_is_true", func(t *testing.T) { - mockClient := mock_rpc.NewMockModerationSvcClient(t) - content := "sensitive content" + chunk := types.ChatCompletionChunk{ID: "test-id"} + uuid := "uuid-1" + expectedResult := &rpc.CheckResult{IsSensitive: true, Reason: "toxic"} - mockClient.On("PassLLMPromptCheck", ctx, content, key). 
- Return(&rpc.CheckResult{IsSensitive: true, Reason: "inappropriate"}, nil).Once() + mockChecker.On("CheckChatStreamResponse", ctx, chunk, uuid).Return(expectedResult, nil).Once() - cfg := &config.Config{} - cfg.AIGateway.ModerationBypassSensitiveCheck = true - modImpl := &moderationImpl{ - modSvcClient: mockClient, - cacheClient: nil, - config: cfg, - } + res, err := modImpl.CheckChatStreamResponse(ctx, chunk, uuid) - result, err := modImpl.CheckChatPrompts(ctx, []openai.ChatCompletionMessageParamUnion{ - { - OfSystem: &openai.ChatCompletionSystemMessageParam{ - Content: openai.ChatCompletionSystemMessageParamContentUnion{ - OfString: param.Opt[string]{Value: content}, - }, - }, - }, - }, key) - - assert.NoError(t, err) - assert.NotNil(t, result) - assert.False(t, result.IsSensitive, "should not block sensitive content when ModerationBypassSensitiveCheck is true") - mockClient.AssertExpectations(t) - }) + assert.NoError(t, err) + assert.Equal(t, expectedResult, res) + mockChecker.AssertExpectations(t) } -// TestModerationImpl_CheckChatStreamResponse_WithModerationBypass tests CheckChatStreamResponse with ModerationBypassSensitiveCheck config -func TestModerationImpl_CheckChatStreamResponse_WithModerationBypass(t *testing.T) { +func TestModerationImpl_CloseStreamCheck(t *testing.T) { ctx := context.Background() - uuid := "test-uuid" + mockChecker := new(MockStreamChecker) + modImpl := &moderationImpl{ + streamChecker: mockChecker, + } - t.Run("should_block_sensitive_content_when_ModerationBypassSensitiveCheck_is_false", func(t *testing.T) { - mockModClient := mock_rpc.NewMockModerationSvcClient(t) - mockModClient.EXPECT().PassLLMRespCheck(ctx, "sensitive content", uuid). 
- Return(&rpc.CheckResult{IsSensitive: true, Reason: "inappropriate language"}, nil).Once() + uuid := "uuid-1" + expectedResult := &rpc.CheckResult{IsSensitive: false} - cfg := &config.Config{} - cfg.AIGateway.ModerationBypassSensitiveCheck = false - modImpl := &moderationImpl{ - modSvcClient: mockModClient, - cacheClient: nil, - config: cfg, - } + mockChecker.On("CloseStreamCheck", ctx, uuid).Return(expectedResult, nil).Once() - chunk := types.ChatCompletionChunk{ - Choices: []types.ChatCompletionChunkChoice{{ - Delta: types.ChatCompletionChunkChoiceDelta{ - Content: "sensitive content", - }, - }}, - } + res, err := modImpl.CloseStreamCheck(ctx, uuid) - result, err := modImpl.CheckChatStreamResponse(ctx, chunk, uuid) - assert.NoError(t, err) - assert.NotNil(t, result) - assert.True(t, result.IsSensitive, "should block sensitive content when ModerationBypassSensitiveCheck is false") - mockModClient.AssertExpectations(t) - }) - - t.Run("should_not_block_sensitive_content_when_ModerationBypassSensitiveCheck_is_true", func(t *testing.T) { - mockModClient := mock_rpc.NewMockModerationSvcClient(t) - mockModClient.EXPECT().PassLLMRespCheck(ctx, "sensitive content", uuid). 
- Return(&rpc.CheckResult{IsSensitive: true, Reason: "inappropriate language"}, nil).Once() + assert.NoError(t, err) + assert.Equal(t, expectedResult, res) + mockChecker.AssertExpectations(t) +} +func TestInitStreamChecker(t *testing.T) { + t.Run("sync mode", func(t *testing.T) { cfg := &config.Config{} - cfg.AIGateway.ModerationBypassSensitiveCheck = true - modImpl := &moderationImpl{ - modSvcClient: mockModClient, - cacheClient: nil, - config: cfg, - } + cfg.SensitiveCheck.StreamCheckMode = StreamCheckModeSync + modImpl := &moderationImpl{config: cfg} - chunk := types.ChatCompletionChunk{ - Choices: []types.ChatCompletionChunkChoice{{ - Delta: types.ChatCompletionChunkChoiceDelta{ - Content: "sensitive content", - }, - }}, - } + initStreamChecker(modImpl) - result, err := modImpl.CheckChatStreamResponse(ctx, chunk, uuid) - assert.NoError(t, err) - assert.NotNil(t, result) - assert.False(t, result.IsSensitive, "should not block sensitive content when ModerationBypassSensitiveCheck is true") - mockModClient.AssertExpectations(t) + _, ok := modImpl.streamChecker.(*syncStreamChecker) + assert.True(t, ok) }) -} - -// TestModerationImpl_CheckChatNonStreamResponse_WithModerationBypass tests CheckChatNonStreamResponse with ModerationBypassSensitiveCheck config -func TestModerationImpl_CheckChatNonStreamResponse_WithModerationBypass(t *testing.T) { - ctx := context.Background() - - t.Run("should_block_sensitive_content_when_ModerationBypassSensitiveCheck_is_false", func(t *testing.T) { - mockModClient := mock_rpc.NewMockModerationSvcClient(t) - mockModClient.EXPECT().PassTextCheck(ctx, common_types.ScenarioChatDetection, "sensitive content"). 
- Return(&rpc.CheckResult{IsSensitive: true, Reason: "inappropriate language"}, nil).Once() + t.Run("async mode", func(t *testing.T) { cfg := &config.Config{} - cfg.AIGateway.ModerationBypassSensitiveCheck = false - modImpl := &moderationImpl{ - modSvcClient: mockModClient, - cacheClient: nil, - config: cfg, - } + cfg.SensitiveCheck.StreamCheckMode = StreamCheckModeAsync + modImpl := &moderationImpl{config: cfg} - completion := types.ChatCompletion{ - ChatCompletion: openai.ChatCompletion{ - Choices: []openai.ChatCompletionChoice{{ - Message: openai.ChatCompletionMessage{ - Content: "sensitive content", - }, - }}, - }, - } + initStreamChecker(modImpl) - result, err := modImpl.CheckChatNonStreamResponse(ctx, completion) - assert.NoError(t, err) - assert.NotNil(t, result) - assert.True(t, result.IsSensitive, "should block sensitive content when ModerationBypassSensitiveCheck is false") - mockModClient.AssertExpectations(t) + checker, ok := modImpl.streamChecker.(*asyncStreamChecker) + assert.True(t, ok) + assert.NotNil(t, checker.sessionCache) + assert.Equal(t, DefaultAsyncBufferMaxChars, checker.maxChars) }) - t.Run("should_not_block_sensitive_content_when_ModerationBypassSensitiveCheck_is_true", func(t *testing.T) { - mockModClient := mock_rpc.NewMockModerationSvcClient(t) - mockModClient.EXPECT().PassTextCheck(ctx, common_types.ScenarioChatDetection, "sensitive content"). 
- Return(&rpc.CheckResult{IsSensitive: true, Reason: "inappropriate language"}, nil).Once() - + t.Run("async mode with custom max chars", func(t *testing.T) { cfg := &config.Config{} - cfg.AIGateway.ModerationBypassSensitiveCheck = true - modImpl := &moderationImpl{ - modSvcClient: mockModClient, - cacheClient: nil, - config: cfg, - } + cfg.SensitiveCheck.StreamCheckMode = StreamCheckModeAsync + cfg.SensitiveCheck.AsyncBufferMaxChars = 100 + modImpl := &moderationImpl{config: cfg} - completion := types.ChatCompletion{ - ChatCompletion: openai.ChatCompletion{ - Choices: []openai.ChatCompletionChoice{{ - Message: openai.ChatCompletionMessage{ - Content: "sensitive content", - }, - }}, - }, - } + initStreamChecker(modImpl) - result, err := modImpl.CheckChatNonStreamResponse(ctx, completion) - assert.NoError(t, err) - assert.NotNil(t, result) - assert.False(t, result.IsSensitive, "should not block sensitive content when ModerationBypassSensitiveCheck is true") - mockModClient.AssertExpectations(t) + checker, ok := modImpl.streamChecker.(*asyncStreamChecker) + assert.True(t, ok) + assert.Equal(t, 100, checker.maxChars) }) } diff --git a/aigateway/component/openai.go b/aigateway/component/openai.go index 2bad02b5f..8093942ad 100644 --- a/aigateway/component/openai.go +++ b/aigateway/component/openai.go @@ -290,8 +290,9 @@ func (m *openaiComponentImpl) getExternalModels(c context.Context) []types.Model }, Endpoint: extModel.ApiEndpoint, ExternalModelInfo: types.ExternalModelInfo{ - Provider: extModel.Provider, - AuthHead: extModel.AuthHeader, + Provider: extModel.Provider, + AuthHead: extModel.AuthHeader, + NeedSensitiveCheck: extModel.NeedSensitiveCheck, }, } models = append(models, m) diff --git a/aigateway/handler/openai.go b/aigateway/handler/openai.go index bde6dd681..5bf9fd4b9 100644 --- a/aigateway/handler/openai.go +++ b/aigateway/handler/openai.go @@ -337,9 +337,11 @@ func (h *OpenAIHandlerImpl) Chat(c *gin.Context) { sceneValue := 
c.Request.Header.Get(commonType.SceneHeaderKey) // Check balance before processing request - if err := h.openaiComponent.CheckBalance(c.Request.Context(), username, userUUID); err != nil { - h.handleInsufficientBalance(c, chatReq.Stream, username, modelID, err) - return + if !model.SkipBalance() { + if err := h.openaiComponent.CheckBalance(c.Request.Context(), username, userUUID); err != nil { + h.handleInsufficientBalance(c, chatReq.Stream, username, modelID, err) + return + } } // marshal updated request map back to JSON bytes @@ -352,26 +354,26 @@ func (h *OpenAIHandlerImpl) Chat(c *gin.Context) { c.String(http.StatusInternalServerError, fmt.Errorf("failed to create reverse proxy:%w", err).Error()) return } - slog.InfoContext(c.Request.Context(), "proxy chat request to model target", slog.Any("target", target), slog.Any("host", host), - slog.Any("user", username), slog.Any("model_name", modelName)) - // Create a combined key using userUUID and modelID for caching and tracking - key := fmt.Sprintf("%s:%s", userUUID, modelID) - result, err := h.modComponent.CheckChatPrompts(c.Request.Context(), chatReq.Messages, key) - if err != nil { - c.String(http.StatusInternalServerError, fmt.Errorf("failed to call moderation error:%w", err).Error()) - return - } - if result.IsSensitive { - slog.DebugContext(c.Request.Context(), "sensitive content", slog.String("reason", result.Reason)) - errorChunk := generateSensitiveRespForPrompt() - errorChunkJson, _ := json.Marshal(errorChunk) - _, err := c.Writer.Write([]byte("data: " + string(errorChunkJson) + "\n\n" + "[DONE]")) + + var modComponent component.Moderation = nil + if model.NeedSensitiveCheck { + modComponent = h.modComponent + // Create a combined key using userUUID and modelID for caching and tracking + key := fmt.Sprintf("%s:%s", userUUID, modelID) + result, err := h.modComponent.CheckChatPrompts(c.Request.Context(), chatReq.Messages, key, chatReq.Stream) if err != nil { - slog.ErrorContext(c.Request.Context(), "write 
into resp error:", slog.String("err", err.Error())) + c.String(http.StatusInternalServerError, fmt.Errorf("failed to call moderation error:%w", err).Error()) + return + } + if result.IsSensitive { + handleSensitiveResponse(c, chatReq.Stream, result) + return } - c.Writer.Flush() - return } + + slog.InfoContext(c.Request.Context(), "proxy chat request to model target", slog.Any("target", target), slog.Any("host", host), + slog.Any("user", username), slog.Any("model_name", modelName)) + tokenCounter := h.tokenCounterFactory.NewChat(token.CreateParam{ Endpoint: target, Host: host, @@ -379,7 +381,8 @@ func (h *OpenAIHandlerImpl) Chat(c *gin.Context) { ImageID: model.ImageID, Provider: model.Provider, }) - w := NewResponseWriterWrapper(c.Writer, chatReq.Stream, h.modComponent, tokenCounter) + + w := NewResponseWriterWrapper(c.Writer, chatReq.Stream, modComponent, tokenCounter) defer w.ClearBuffer() tokenCounter.AppendPrompts(chatReq.Messages) diff --git a/aigateway/handler/openai_test.go b/aigateway/handler/openai_test.go index fcc075072..88c1ee944 100644 --- a/aigateway/handler/openai_test.go +++ b/aigateway/handler/openai_test.go @@ -272,9 +272,12 @@ func TestOpenAIHandler_GetModel(t *testing.T) { tester, c, w := setupTest(t) model := &types.Model{ BaseModel: types.BaseModel{ - ID: "model1:svc1", - Object: "model", - OwnedBy: "testuser", + ID: "model1", + Object: "model", + OwnedBy: "testuser", + }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, }, } c.Params = []gin.Param{{Key: "model", Value: "model1:svc1"}} @@ -345,13 +348,13 @@ func TestOpenAIHandler_Chat(t *testing.T) { model := &types.Model{ BaseModel: types.BaseModel{ - ID: "model1:svc1", - Object: "model", - OwnedBy: "testuser", - }, + ID: "model1", + Object: "model", + OwnedBy: "testuser", + }, InternalModelInfo: types.InternalModelInfo{ - ClusterID: "test-cls", - SvcName: "test-svc", + ClusterID: "test-cls", + SvcName: "test-svc", CSGHubModelID: "model1", }, } @@ -387,6 +390,9 @@ 
func TestOpenAIHandler_Chat(t *testing.T) { SvcName: "test-svc", CSGHubModelID: "model1", }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + }, Endpoint: "test-endpoint", } tester.mocks.mockClsComp.EXPECT().GetClusterByID(mock.Anything, "test-cls").Return(&database.ClusterInfo{ @@ -396,7 +402,7 @@ func TestOpenAIHandler_Chat(t *testing.T) { tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser", "testuuid").Return(nil) expectReq := ChatCompletionRequest{} _ = json.Unmarshal(body, &expectReq) - tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, expectReq.Messages, "testuuid:"+model.ID). + tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, expectReq.Messages, "testuuid:"+model.ID, false). Return(&rpc.CheckResult{IsSensitive: true}, nil) tester.handler.Chat(c) @@ -432,6 +438,9 @@ func TestOpenAIHandler_Chat(t *testing.T) { SvcName: "test-svc", CSGHubModelID: "model1", }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + }, Endpoint: testServer.URL, } tester.mocks.mockClsComp.EXPECT().GetClusterByID(mock.Anything, "test-cls").Return(&database.ClusterInfo{ @@ -441,7 +450,7 @@ func TestOpenAIHandler_Chat(t *testing.T) { tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser", "testuuid").Return(nil) expectReq := ChatCompletionRequest{} _ = json.Unmarshal(body, &expectReq) - tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, expectReq.Messages, "testuuid:"+model.ID). + tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, expectReq.Messages, "testuuid:"+model.ID, false). 
Return(nil, errors.New("some error")) tester.handler.Chat(c) @@ -468,15 +477,18 @@ func TestOpenAIHandler_Chat(t *testing.T) { model := &types.Model{ BaseModel: types.BaseModel{ - ID: "model1:svc1", - Object: "model", - OwnedBy: "testuser", - }, + ID: "model1:svc1", + Object: "model", + OwnedBy: "testuser", + }, InternalModelInfo: types.InternalModelInfo{ ClusterID: "test-cls", SvcName: "test-svc", CSGHubModelID: "model1", }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + }, Endpoint: testServer.URL, } tester.mocks.mockClsComp.EXPECT().GetClusterByID(mock.Anything, "test-cls").Return(&database.ClusterInfo{ @@ -486,7 +498,7 @@ func TestOpenAIHandler_Chat(t *testing.T) { tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser", "testuuid").Return(nil) expectReq := ChatCompletionRequest{} _ = json.Unmarshal(body, &expectReq) - tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, expectReq.Messages, "testuuid:"+model.ID). + tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, expectReq.Messages, "testuuid:"+model.ID, false). Return(&rpc.CheckResult{IsSensitive: false}, nil) llmTokenCounter := mocktoken.NewMockChatTokenCounter(t) tester.mocks.tokenCounterFactory.EXPECT().NewChat( @@ -495,6 +507,7 @@ func TestOpenAIHandler_Chat(t *testing.T) { Host: "", Model: "model1", ImageID: model.ImageID, + Provider: model.Provider, }). 
Return(llmTokenCounter) llmTokenCounter.EXPECT().AppendPrompts(expectReq.Messages).Return() @@ -539,6 +552,9 @@ func TestOpenAIHandler_Chat(t *testing.T) { SvcName: "test-svc", CSGHubModelID: "model1", }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + }, Endpoint: testServer.URL, } tester.mocks.mockClsComp.EXPECT().GetClusterByID(mock.Anything, "test-cls").Return(&database.ClusterInfo{ @@ -548,7 +564,7 @@ func TestOpenAIHandler_Chat(t *testing.T) { tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser", "testuuid").Return(nil) expectReq := ChatCompletionRequest{} _ = json.Unmarshal(body, &expectReq) - tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, expectReq.Messages, "testuuid:"+model.ID). + tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, expectReq.Messages, "testuuid:"+model.ID, false). Return(&rpc.CheckResult{IsSensitive: false}, nil) llmTokenCounter := mocktoken.NewMockChatTokenCounter(t) tester.mocks.tokenCounterFactory.EXPECT().NewChat( @@ -557,6 +573,7 @@ func TestOpenAIHandler_Chat(t *testing.T) { Host: "", Model: "model1", ImageID: model.ImageID, + Provider: model.Provider, }). Return(llmTokenCounter) llmTokenCounter.EXPECT().AppendPrompts(expectReq.Messages).Return() @@ -597,13 +614,16 @@ func TestOpenAIHandler_Chat(t *testing.T) { InternalModelInfo: types.InternalModelInfo{ SvcName: "", }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, + }, Endpoint: testServer.URL, } tester.mocks.openAIComp.EXPECT().GetModelByID(mock.Anything, "testuser", "external-model-id").Return(model, nil) tester.mocks.openAIComp.EXPECT().CheckBalance(mock.Anything, "testuser", "testuuid").Return(nil) expectReq := ChatCompletionRequest{} _ = json.Unmarshal(body, &expectReq) - tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, expectReq.Messages, "testuuid:"+model.ID). 
+ tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, expectReq.Messages, "testuuid:"+model.ID, false). Return(&rpc.CheckResult{IsSensitive: false}, nil) llmTokenCounter := mocktoken.NewMockChatTokenCounter(t) tester.mocks.tokenCounterFactory.EXPECT().NewChat( @@ -612,6 +632,7 @@ func TestOpenAIHandler_Chat(t *testing.T) { Host: "", Model: model.ID, ImageID: model.ImageID, + Provider: model.Provider, }). Return(llmTokenCounter) llmTokenCounter.EXPECT().AppendPrompts(expectReq.Messages).Return() @@ -743,9 +764,12 @@ func TestOpenAIHandler_Embedding(t *testing.T) { model := &types.Model{ BaseModel: types.BaseModel{ - ID: "model1:svc1", - Object: "model", - OwnedBy: "testuser", + ID: "model1:svc1", + Object: "model", + OwnedBy: "testuser", + }, + ExternalModelInfo: types.ExternalModelInfo{ + NeedSensitiveCheck: true, }, InternalModelInfo: types.InternalModelInfo{ ClusterID: "test-cls", diff --git a/aigateway/handler/response_writer_wrapper.go b/aigateway/handler/response_writer_wrapper.go index 62bcea7a7..cdad791af 100644 --- a/aigateway/handler/response_writer_wrapper.go +++ b/aigateway/handler/response_writer_wrapper.go @@ -16,6 +16,7 @@ import ( "opencsg.com/csghub-server/aigateway/component" "opencsg.com/csghub-server/aigateway/token" "opencsg.com/csghub-server/aigateway/types" + rpc "opencsg.com/csghub-server/builder/rpc" ) type CommonResponseWriter interface { @@ -69,6 +70,7 @@ func (rw *ResponseWriterWrapper) Header() http.Header { } func (rw *ResponseWriterWrapper) WriteHeader(statusCode int) { + rw.internalWritter.Header().Del("Content-Length") rw.internalWritter.WriteHeader(statusCode) } @@ -84,6 +86,19 @@ func (rw *ResponseWriterWrapper) streamWrite(data []byte) (int, error) { continue } if string(event.Data) == "[DONE]" { + // trigger async check for sensitive content on remaining buffer + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + res, err := rw.closeStreamCheck(ctx, rw.id) + if err != nil { 
+ slog.Error("ResponseWriterWrapper streamWrite closeStreamCheck error", slog.Any("err", err)) + rw.writeInternal(event.Raw) + continue + } + if res != nil && res.IsSensitive { + return rw.handleSensitiveResult(res, event.Raw, types.ChatCompletionChunk{}) + } + rw.writeInternal(event.Raw) return len(data), nil } @@ -101,21 +116,14 @@ func (rw *ResponseWriterWrapper) streamWrite(data []byte) (int, error) { // call moderation service ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - result, err := rw.moderationComponent.CheckChatStreamResponse(ctx, chunk, rw.id) + result, err := rw.checkChatStreamResponse(ctx, chunk, rw.id) if err != nil { - slog.Error("ResponseWriterWrapper streamWrite checkChatResponse error", slog.Any("err", err)) + slog.Error("ResponseWriterWrapper streamWrite checkChatStreamResponse error", slog.Any("err", err)) rw.writeInternal(event.Raw) continue } - if result.IsSensitive { - slog.Debug("ResponseWriterWrapper streamWrite checkresult is sensitive", - slog.Any("content", chunk), - slog.Any("reason", result.Reason)) - chunk = rw.generateSensitiveRespForContent(chunk) - chunkJson, _ := json.Marshal(chunk) - rw.writeInternal([]byte("data: " + string(chunkJson) + "\n\n")) - rw.writeInternal([]byte("data: [DONE]\n\n")) - return 0, ErrSensitiveContent + if result != nil && result.IsSensitive { + return rw.handleSensitiveResult(result, event.Raw, chunk) } rw.writeInternal(event.Raw) } @@ -123,6 +131,31 @@ func (rw *ResponseWriterWrapper) streamWrite(data []byte) (int, error) { return len(data), nil } +func (rw *ResponseWriterWrapper) closeStreamCheck(ctx context.Context, id string) (*rpc.CheckResult, error) { + if rw.moderationComponent == nil { + return nil, nil + } + return rw.moderationComponent.CloseStreamCheck(ctx, id) +} + +func (rw *ResponseWriterWrapper) checkChatStreamResponse(ctx context.Context, chunk types.ChatCompletionChunk, id string) (*rpc.CheckResult, error) { + if rw.moderationComponent == nil { 
+ return nil, nil + } + return rw.moderationComponent.CheckChatStreamResponse(ctx, chunk, id) +} + +func (rw *ResponseWriterWrapper) handleSensitiveResult(result *rpc.CheckResult, rawData []byte, chunk types.ChatCompletionChunk) (int, error) { + slog.Debug("ResponseWriterWrapper streamWrite checkresult is sensitive", + slog.Any("content", chunk), + slog.Any("reason", result.Reason)) + chunk = rw.generateSensitiveRespForContent(chunk) + chunkJson, _ := json.Marshal(chunk) + rw.writeInternal([]byte("data: " + string(chunkJson) + "\n\n")) + rw.writeInternal([]byte("data: [DONE]\n\n")) + return 0, ErrSensitiveContent +} + func (rw *ResponseWriterWrapper) writeInternal(data []byte) { slog.Debug("writeInternal", slog.String("data", string(data))) _, err := rw.internalWritter.Write(data) @@ -147,6 +180,10 @@ func (rw *ResponseWriterWrapper) writeInternal(data []byte) { // } func (rw *ResponseWriterWrapper) generateSensitiveRespForContent(curChunk types.ChatCompletionChunk) types.ChatCompletionChunk { + var index int64 = 0 + if len(curChunk.Choices) > 0 { + index = curChunk.Choices[0].Index + } newChunk := types.ChatCompletionChunk{ ID: curChunk.ID, Model: curChunk.Model, @@ -156,7 +193,7 @@ func (rw *ResponseWriterWrapper) generateSensitiveRespForContent(curChunk types. Content: "The message includes inappropriate content and has been blocked. 
We appreciate your understanding and cooperation.", }, FinishReason: "sensitive", - Index: curChunk.Choices[0].Index, + Index: index, }, }, SystemFingerprint: curChunk.SystemFingerprint, @@ -181,6 +218,39 @@ func generateSensitiveRespForPrompt() types.ChatCompletionChunk { return newChunk } +func handleSensitiveResponse(c *gin.Context, stream bool, checkResult *rpc.CheckResult) { + slog.DebugContext( + c.Request.Context(), + "sensitive content detected", + slog.String("reason", checkResult.Reason), + ) + + resp := generateSensitiveRespForPrompt() + if stream { + writeSensitiveStreamResponse(c, resp) + return + } + writeSensitiveJSONResponse(c, resp) +} + +func writeSensitiveStreamResponse(c *gin.Context, resp any) { + errorChunkJson, err := json.Marshal(resp) + if err != nil { + slog.ErrorContext(c.Request.Context(), "marshal error:", slog.String("err", err.Error())) + c.Status(http.StatusInternalServerError) + return + } + _, err = c.Writer.Write([]byte("data: " + string(errorChunkJson) + "\n\n" + "[DONE]")) + if err != nil { + slog.ErrorContext(c.Request.Context(), "write into resp error:", slog.String("err", err.Error())) + } + c.Writer.Flush() +} + +func writeSensitiveJSONResponse(c *gin.Context, resp any) { + c.JSON(http.StatusOK, resp) +} + func generateInsufficientBalanceResp(frontendURL string) types.ChatCompletionChunk { rechargeURL := fmt.Sprintf("%s/settings/recharge-payment", frontendURL) message := fmt.Sprintf( diff --git a/aigateway/handler/response_writer_wrapper_non_stream.go b/aigateway/handler/response_writer_wrapper_non_stream.go index f998d8f4b..d5b08610b 100644 --- a/aigateway/handler/response_writer_wrapper_non_stream.go +++ b/aigateway/handler/response_writer_wrapper_non_stream.go @@ -10,6 +10,8 @@ import ( "net/http" "time" + "opencsg.com/csghub-server/builder/rpc" + "github.com/gin-gonic/gin" "opencsg.com/csghub-server/aigateway/component" "opencsg.com/csghub-server/aigateway/token" @@ -37,6 +39,7 @@ func (nsw *nonStreamResponseWriter) 
Header() http.Header { } func (nsw *nonStreamResponseWriter) WriteHeader(statusCode int) { + nsw.internalWritter.Header().Del("Content-Length") nsw.internalWritter.WriteHeader(statusCode) } @@ -93,12 +96,12 @@ func (nsw *nonStreamResponseWriter) nonStreamWrite(originData []byte) (int, erro // Step 6: Perform content moderation if service is available ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - result, err := nsw.moderationComponent.CheckChatNonStreamResponse(ctx, completion) + result, err := nsw.CheckChatNonStreamResponse(ctx, completion) if err != nil { slog.Error("NonStreamResponseWriter nonStreamWrite failed to call moderation service", slog.Any("err", err)) // Continue with original content if moderation service fails - } else if result.IsSensitive { + } else if result != nil && result.IsSensitive { // Replace sensitive content with block message slog.Debug("NonStreamResponseWriter nonStreamWrite checkresult is sensitive", slog.Any("content", completion), @@ -115,6 +118,13 @@ func (nsw *nonStreamResponseWriter) nonStreamWrite(originData []byte) (int, erro return originLen, nsw.writeToInternal(nsw.buffer.Bytes()) } +func (nsw *nonStreamResponseWriter) CheckChatNonStreamResponse(ctx context.Context, completion types.ChatCompletion) (*rpc.CheckResult, error) { + if nsw.moderationComponent == nil { + return nil, nil + } + return nsw.moderationComponent.CheckChatNonStreamResponse(ctx, completion) +} + // writeToInternal encapsulates writing to the internal writer with error logging and buffer cleanup func (nsw *nonStreamResponseWriter) writeToInternal(data []byte) error { _, err := nsw.internalWritter.Write(data) diff --git a/aigateway/handler/response_writer_wrapper_non_stream_test.go b/aigateway/handler/response_writer_wrapper_non_stream_test.go index 25927244f..6f5db8021 100644 --- a/aigateway/handler/response_writer_wrapper_non_stream_test.go +++ b/aigateway/handler/response_writer_wrapper_non_stream_test.go @@ -69,6 
+69,31 @@ func TestNonStreamResponseWriter_Write(t *testing.T) { require.Equal(t, len(data), n) }) + t.Run("write with nil moderation component", func(t *testing.T) { + w := httptest.NewRecorder() + w.Header().Set("Content-Encoding", "") + ctx, _ := gin.CreateTestContext(w) + nsw := newNonStreamResponseWriter(ctx.Writer, nil, nil) + + completion := types.ChatCompletion{ + ChatCompletion: openai.ChatCompletion{ + Choices: []openai.ChatCompletionChoice{ + { + Message: openai.ChatCompletionMessage{ + Content: "This is a test response", + }, + }, + }, + }, + } + + data, _ := json.Marshal(completion) + n, err := nsw.Write(data) + + require.NoError(t, err) + require.Equal(t, len(data), n) + }) + t.Run("write with empty choices", func(t *testing.T) { // Prepare test data w := httptest.NewRecorder() diff --git a/aigateway/handler/response_writer_wrapper_stream_test.go b/aigateway/handler/response_writer_wrapper_stream_test.go index 87a55be29..2b7751757 100644 --- a/aigateway/handler/response_writer_wrapper_stream_test.go +++ b/aigateway/handler/response_writer_wrapper_stream_test.go @@ -89,7 +89,10 @@ func TestResponseWriterWrapper_Write_NormalContent(t *testing.T) { func TestResponseWriterWrapper_Write_DoneMessage(t *testing.T) { w := httptest.NewRecorder() ctx, _ := gin.CreateTestContext(w) - rw := newStreamResponseWriter(ctx.Writer, component.NewMockModeration(t), nil) + mockMod := component.NewMockModeration(t) + rw := newStreamResponseWriter(ctx.Writer, mockMod, nil) + + mockMod.EXPECT().CloseStreamCheck(mock.Anything, rw.id).Return(&rpc.CheckResult{IsSensitive: false}, nil) doneData := []byte("data: [DONE]\n\n") _, err := rw.Write(doneData) @@ -216,6 +219,89 @@ func TestResponseWriterWrapper_Write_ModerationServiceError(t *testing.T) { } } +func TestResponseWriterWrapper_Write_NilModerationComponent(t *testing.T) { + w := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(w) + rw := newStreamResponseWriter(ctx.Writer, nil, nil) + + // Test 
CheckChatStreamResponse with nil component + normalChunk := types.ChatCompletionChunk{ + ID: "test-id", + Choices: []types.ChatCompletionChunkChoice{ + { + Delta: types.ChatCompletionChunkChoiceDelta{ + Content: "normal content", + }, + }, + }, + } + chunkJSON, _ := json.Marshal(normalChunk) + streamData := []byte("data: " + string(chunkJSON) + "\n\n") + + n, err := rw.Write(streamData) + if err != nil { + t.Errorf("Write should not return error with nil moderation component: %v", err) + } + if n != len(streamData) { + t.Errorf("Expected to write %d bytes, got %d", len(streamData), n) + } + if !bytes.Contains(w.Body.Bytes(), streamData) { + t.Error("Original data should be written with nil moderation component") + } + + // Test CloseStreamCheck with nil component + doneData := []byte("data: [DONE]\n\n") + n, err = rw.Write(doneData) + if err != nil { + t.Errorf("Write should not return error for DONE message with nil moderation component: %v", err) + } + if n != len(doneData) { + t.Errorf("Expected to write %d bytes, got %d", len(doneData), n) + } + if !bytes.Contains(w.Body.Bytes(), doneData) { + t.Error("DONE message should be written with nil moderation component") + } +} + +func TestResponseWriterWrapper_Write_DoneMessageSensitive(t *testing.T) { + w := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(w) + mockMod := component.NewMockModeration(t) + rw := newStreamResponseWriter(ctx.Writer, mockMod, nil) + + mockMod.EXPECT().CloseStreamCheck(mock.Anything, rw.id).Return(&rpc.CheckResult{IsSensitive: true, Reason: "sensitive done"}, nil) + + doneData := []byte("data: [DONE]\n\n") + _, err := rw.Write(doneData) + if !errors.Is(err, ErrSensitiveContent) { + t.Errorf("Write should return ErrSensitiveContent when done check is sensitive, got: %v", err) + } + + responseBody := w.Body.String() + if !bytes.Contains([]byte(responseBody), []byte("The message includes inappropriate content")) { + t.Error("Response should include sensitive content warning") + } +} 
+ +func TestResponseWriterWrapper_Write_DoneMessageError(t *testing.T) { + w := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(w) + mockMod := component.NewMockModeration(t) + rw := newStreamResponseWriter(ctx.Writer, mockMod, nil) + + mockMod.EXPECT().CloseStreamCheck(mock.Anything, rw.id).Return(nil, errors.New("done check error")) + + doneData := []byte("data: [DONE]\n\n") + _, err := rw.Write(doneData) + if err != nil { + t.Errorf("Write should not return error when done check fails, got: %v", err) + } + + if !bytes.Contains(w.Body.Bytes(), doneData) { + t.Error("Response should include original done message when check fails") + } +} + func TestResponseWriterWrapper_Write_InvalidJSON(t *testing.T) { w := httptest.NewRecorder() ctx, _ := gin.CreateTestContext(w) diff --git a/aigateway/handler/response_writer_wrapper_test.go b/aigateway/handler/response_writer_wrapper_test.go index d059befd3..0ddfdf228 100644 --- a/aigateway/handler/response_writer_wrapper_test.go +++ b/aigateway/handler/response_writer_wrapper_test.go @@ -61,6 +61,80 @@ func TestResponseWriterWrapper_StreamWrite(t *testing.T) { } } +func TestGenerateInsufficientBalanceResp(t *testing.T) { + frontendURL := "http://localhost:8080" + chunk := generateInsufficientBalanceResp(frontendURL) + assert.Len(t, chunk.Choices, 1) + assert.Equal(t, "insufficient_balance", chunk.Choices[0].FinishReason) + assert.Contains(t, chunk.Choices[0].Delta.Content, "**Insufficient balance**") + assert.Contains(t, chunk.Choices[0].Delta.Content, frontendURL+"/settings/recharge-payment") +} + +func TestWriteSensitiveStreamResponse(t *testing.T) { + w := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(w) + ctx.Request = httptest.NewRequest("GET", "/", nil) + resp := generateSensitiveRespForPrompt() + + writeSensitiveStreamResponse(ctx, resp) + + assert.Equal(t, 200, w.Code) + body := w.Body.String() + assert.Contains(t, body, "data: {") + assert.Contains(t, body, "sensitive") + assert.Contains(t, body, 
"[DONE]") +} + +func TestWriteSensitiveJSONResponse(t *testing.T) { + w := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(w) + ctx.Request = httptest.NewRequest("GET", "/", nil) + resp := generateSensitiveRespForPrompt() + + writeSensitiveJSONResponse(ctx, resp) + + assert.Equal(t, 200, w.Code) + body := w.Body.String() + assert.Contains(t, body, "sensitive") + var parsedResp openai.ChatCompletionChunk + err := json.Unmarshal([]byte(body), &parsedResp) + assert.NoError(t, err) + assert.Equal(t, "sensitive", parsedResp.Choices[0].FinishReason) +} + +func TestHandleSensitiveResponse(t *testing.T) { + t.Run("stream", func(t *testing.T) { + w := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(w) + ctx.Request = httptest.NewRequest("GET", "/", nil) + + checkResult := &rpc.CheckResult{Reason: "test reason"} + handleSensitiveResponse(ctx, true, checkResult) + + assert.Equal(t, 200, w.Code) + body := w.Body.String() + assert.Contains(t, body, "data: {") + assert.Contains(t, body, "sensitive") + assert.Contains(t, body, "[DONE]") + }) + + t.Run("json", func(t *testing.T) { + w := httptest.NewRecorder() + ctx, _ := gin.CreateTestContext(w) + ctx.Request = httptest.NewRequest("GET", "/", nil) + + checkResult := &rpc.CheckResult{Reason: "test reason"} + handleSensitiveResponse(ctx, false, checkResult) + + assert.Equal(t, 200, w.Code) + body := w.Body.String() + var parsedResp openai.ChatCompletionChunk + err := json.Unmarshal([]byte(body), &parsedResp) + assert.NoError(t, err) + assert.Equal(t, "sensitive", parsedResp.Choices[0].FinishReason) + }) +} + func TestResponseWriterWrapper_StreamWrite_WithWhiteSpace(t *testing.T) { chunk := openai.ChatCompletionChunk{ ID: "test-id", diff --git a/aigateway/types/chat_completion.go b/aigateway/types/chat_completion.go index 7c8dcad65..dbc4dc8be 100644 --- a/aigateway/types/chat_completion.go +++ b/aigateway/types/chat_completion.go @@ -67,3 +67,11 @@ type ChatCompletionChunkChoiceDelta struct { type ChatCompletion 
struct { openai.ChatCompletion } + +type Role string + +const ( + RoleUser Role = "user" + RoleAssistant Role = "assistant" + RoleSystem Role = "system" +) diff --git a/aigateway/types/openai.go b/aigateway/types/openai.go index 93426aa6e..adda10efa 100644 --- a/aigateway/types/openai.go +++ b/aigateway/types/openai.go @@ -36,6 +36,10 @@ type InternalModelInfo struct { type ExternalModelInfo struct { Provider string `json:"-"` // external provider name, like openai, anthropic etc AuthHead string `json:"-"` // the auth header to access the external model + // NeedSensitiveCheck controls whether requests for this model should go + // through sensitive content detection in aigateway. Set to false to skip + // the check (e.g. for guard models or trusted internal models). + NeedSensitiveCheck bool `json:"-"` } type Model struct { @@ -65,17 +69,19 @@ func (m Model) MarshalJSON() ([]byte, error) { ImageID *string `json:"image_id,omitempty"` AuthHead *string `json:"auth_head,omitempty"` Provider *string `json:"provider,omitempty"` + NeedSensitiveCheck bool `json:"need_sensitive_check"` } resp := internalModelResponse{ - ID: m.ID, - Object: m.Object, - Created: m.Created, - OwnedBy: m.OwnedBy, - Task: m.Task, - DisplayName: m.DisplayName, - Public: m.Public, - Endpoint: m.Endpoint, - Metadata: m.Metadata, + ID: m.ID, + Object: m.Object, + Created: m.Created, + OwnedBy: m.OwnedBy, + Task: m.Task, + DisplayName: m.DisplayName, + Public: m.Public, + Endpoint: m.Endpoint, + Metadata: m.Metadata, + NeedSensitiveCheck: m.NeedSensitiveCheck, } if m.SupportFunctionCall { @@ -121,6 +127,7 @@ func (m *Model) UnmarshalJSON(data []byte) error { ImageID string `json:"image_id,omitempty"` AuthHead string `json:"auth_head,omitempty"` Provider string `json:"provider,omitempty"` + NeedSensitiveCheck bool `json:"need_sensitive_check"` } var aux internalModelResponse if err := json.Unmarshal(data, &aux); err != nil { @@ -141,6 +148,7 @@ func (m *Model) UnmarshalJSON(data []byte) error { 
m.ImageID = aux.ImageID m.AuthHead = aux.AuthHead m.Provider = aux.Provider + m.NeedSensitiveCheck = aux.NeedSensitiveCheck return nil } @@ -156,6 +164,19 @@ func (m *Model) ForExternalResponse() *Model { return m } +// SkipBalance set the model for skip balance mode +func (m *Model) SkipBalance() bool { + // MetaTaskKey values is array of strings, check if MetaTaskValGuard is in it + if tasks, ok := m.Metadata[MetaTaskKey].([]interface{}); ok { + for _, t := range tasks { + if task, ok := t.(string); ok && task == MetaTaskValGuard { + return true + } + } + } + return false +} + // ModelList represents the model list response type ModelList struct { Object string `json:"object"` @@ -188,7 +209,9 @@ type UserPreferenceRequest struct { const OpenCSGAppNameHeader string = "OpenCSG-App-Name" const ( - AgenticHubApp = "Agentichub" + AgenticHubApp = "Agentichub" + MetaTaskKey = "task" + MetaTaskValGuard = "guard" ) // ModelSource represents the source of a model diff --git a/aigateway/types/openai_test.go b/aigateway/types/openai_test.go index 3ceba80b2..8fcae8a99 100644 --- a/aigateway/types/openai_test.go +++ b/aigateway/types/openai_test.go @@ -2,6 +2,7 @@ package types import ( "encoding/json" + "github.com/stretchr/testify/require" "strings" "testing" ) @@ -169,3 +170,59 @@ func TestModelUnmarshal(t *testing.T) { t.Errorf("Model list unmarshal failed, got: %v", modelList) } } + +func TestModel_SkipBalance(t *testing.T) { + tests := []struct { + name string + metadata map[string]any + expected bool + }{ + { + name: "Metadata is nil", + metadata: nil, + expected: false, + }, + { + name: "Metadata does not have MetaTaskKey", + metadata: map[string]any{}, + expected: false, + }, + { + name: "MetaTaskKey value is not a slice", + metadata: map[string]any{MetaTaskKey: "not a slice"}, + expected: false, + }, + { + name: "MetaTaskKey value is slice but not of strings", + metadata: map[string]any{MetaTaskKey: []int{1, 2, 3}}, + expected: false, + }, + { + name: "MetaTaskKey 
value is slice of strings but does not contain MetaTaskValGuard", + metadata: map[string]any{MetaTaskKey: []interface{}{"text-generation", "text-to-image"}}, + expected: false, + }, + { + name: "MetaTaskKey value is slice of strings and contains MetaTaskValGuard", + metadata: map[string]any{MetaTaskKey: []interface{}{"text-generation", MetaTaskValGuard}}, + expected: true, + }, + { + name: "MetaTaskKey value is slice of mixed types with MetaTaskValGuard", + metadata: map[string]any{MetaTaskKey: []interface{}{1, "text-generation", MetaTaskValGuard, 3.14}}, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + model := &Model{ + BaseModel: BaseModel{ + Metadata: tt.metadata, + }, + } + result := model.SkipBalance() + require.Equal(t, tt.expected, result) + }) + } +} diff --git a/builder/llm/client.go b/builder/llm/client.go index cfcfe12cb..ab75197d3 100644 --- a/builder/llm/client.go +++ b/builder/llm/client.go @@ -9,6 +9,7 @@ import ( "io" "log/slog" "net/http" + "opencsg.com/csghub-server/builder/rpc" "strings" "opencsg.com/csghub-server/common/types" @@ -19,12 +20,12 @@ type LLMSvcClient interface { } type Client struct { - client *http.Client + client rpc.HttpDoer } func NewClient() *Client { return &Client{ - client: http.DefaultClient, + client: rpc.NewHttpClient("").WithRetry(2), } } diff --git a/builder/rpc/moderation_svc_client.go b/builder/rpc/moderation_svc_client.go index 7b12e7396..2714c405c 100644 --- a/builder/rpc/moderation_svc_client.go +++ b/builder/rpc/moderation_svc_client.go @@ -13,8 +13,8 @@ type ModerationSvcClient interface { PassTextCheck(ctx context.Context, scenario types.SensitiveScenario, text string) (*CheckResult, error) PassImageCheck(ctx context.Context, scenario types.SensitiveScenario, ossBucketName, ossObjectName string) (*CheckResult, error) PassImageURLCheck(ctx context.Context, scenario types.SensitiveScenario, imageURL string) (*CheckResult, error) - PassLLMRespCheck(ctx 
context.Context, text, sessionId string) (*CheckResult, error) - PassLLMPromptCheck(ctx context.Context, text, accountId string) (*CheckResult, error) + PassLLMRespCheck(ctx context.Context, req types.LLMCheckRequest) (*CheckResult, error) + PassLLMPromptCheck(ctx context.Context, req types.LLMCheckRequest) (*CheckResult, error) SubmitRepoCheck(ctx context.Context, repoType types.RepositoryType, namespace, name string) error } @@ -58,22 +58,8 @@ func (c *ModerationSvcHttpClient) PassTextCheck(ctx context.Context, scenario ty } // If sessionID is set, used to check stream response; if not set, check non-stream. -func (c *ModerationSvcHttpClient) PassLLMRespCheck(ctx context.Context, text, sessionId string) (*CheckResult, error) { - type ServiceParameters struct { - Content string `json:"content"` - SessionId string `json:"sessionId"` - } - type CheckRequest struct { - Service string `json:"Service"` - ServiceParameters ServiceParameters `json:"ServiceParameters"` - } - req := &CheckRequest{ - Service: string(types.ScenarioLLMResModeration), - ServiceParameters: ServiceParameters{ - Content: text, - SessionId: sessionId, - }, - } +func (c *ModerationSvcHttpClient) PassLLMRespCheck(ctx context.Context, req types.LLMCheckRequest) (*CheckResult, error) { + req.Scenario = types.ScenarioLLMResModeration const path = "/api/v1/llmresp" var resp httpbase.R resp.Data = &CheckResult{} @@ -163,22 +149,8 @@ func (c *ModerationSvcHttpClient) SubmitRepoCheck(ctx context.Context, repoType return nil } -func (c *ModerationSvcHttpClient) PassLLMPromptCheck(ctx context.Context, text, accountId string) (*CheckResult, error) { - type ServiceParameters struct { - Content string `json:"content"` - SessionId string `json:"sessionId"` - } - type CheckRequest struct { - Service string `json:"Service"` - ServiceParameters ServiceParameters `json:"ServiceParameters"` - } - req := &CheckRequest{ - Service: string(types.ScenarioLLMQueryModeration), - ServiceParameters: ServiceParameters{ - 
Content: text, - SessionId: accountId, - }, - } +func (c *ModerationSvcHttpClient) PassLLMPromptCheck(ctx context.Context, req types.LLMCheckRequest) (*CheckResult, error) { + req.Scenario = types.ScenarioLLMQueryModeration const path = "/api/v1/llmprompt" var resp httpbase.R resp.Data = &CheckResult{} diff --git a/builder/rpc/moderation_svc_client_test.go b/builder/rpc/moderation_svc_client_test.go index 7473ee3ce..259461271 100644 --- a/builder/rpc/moderation_svc_client_test.go +++ b/builder/rpc/moderation_svc_client_test.go @@ -104,18 +104,12 @@ func TestModerationSvcHttpClient_PassLLMRespCheck(t *testing.T) { assert.Equal(t, "/api/v1/llmresp", r.URL.Path) assert.Equal(t, http.MethodPost, r.Method) - var req struct { - Service string `json:"Service"` - ServiceParameters struct { - Content string `json:"content"` - SessionId string `json:"sessionId"` - } `json:"ServiceParameters"` - } + var req types.LLMCheckRequest err := json.NewDecoder(r.Body).Decode(&req) assert.NoError(t, err) - assert.Equal(t, string(types.ScenarioLLMResModeration), req.Service) - assert.Equal(t, "test_text", req.ServiceParameters.Content) - assert.Equal(t, "test_session", req.ServiceParameters.SessionId) + assert.Equal(t, types.ScenarioLLMResModeration, req.Scenario) + assert.Equal(t, "test_text", req.Text) + assert.Equal(t, "test_session", req.SessionId) resp := httpbase.R{ Data: CheckResult{ @@ -137,7 +131,11 @@ func TestModerationSvcHttpClient_PassLLMRespCheck(t *testing.T) { client := &ModerationSvcHttpClient{ hc: hc, } - res, err := client.PassLLMRespCheck(context.Background(), "test_text", "test_session") + res, err := client.PassLLMRespCheck(context.Background(), types.LLMCheckRequest{ + Scenario: types.ScenarioLLMResModeration, + Text: "test_text", + SessionId: "test_session", + }) assert.NoError(t, err) assert.NotNil(t, res) assert.True(t, res.IsSensitive) @@ -149,18 +147,12 @@ func TestModerationSvcHttpClient_PassLLMPromptCheck(t *testing.T) { assert.Equal(t, "/api/v1/llmprompt", 
r.URL.Path) assert.Equal(t, http.MethodPost, r.Method) - var req struct { - Service string `json:"Service"` - ServiceParameters struct { - Content string `json:"content"` - SessionId string `json:"sessionId"` - } `json:"ServiceParameters"` - } + var req types.LLMCheckRequest err := json.NewDecoder(r.Body).Decode(&req) assert.NoError(t, err) - assert.Equal(t, "llm_query_moderation", req.Service) - assert.Equal(t, "test_prompt", req.ServiceParameters.Content) - assert.Equal(t, "test_account", req.ServiceParameters.SessionId) + assert.Equal(t, types.ScenarioLLMQueryModeration, req.Scenario) + assert.Equal(t, "test_prompt", req.Text) + assert.Equal(t, "test_account", req.AccountId) resp := httpbase.R{ Data: CheckResult{ @@ -182,7 +174,12 @@ func TestModerationSvcHttpClient_PassLLMPromptCheck(t *testing.T) { client := &ModerationSvcHttpClient{ hc: hc, } - res, err := client.PassLLMPromptCheck(context.Background(), "test_prompt", "test_account") + res, err := client.PassLLMPromptCheck(context.Background(), types.LLMCheckRequest{ + Scenario: types.ScenarioLLMQueryModeration, + Text: "test_prompt", + AccountId: "test_account", + }) + assert.NoError(t, err) assert.NotNil(t, res) assert.False(t, res.IsSensitive) diff --git a/builder/sensitive/aho_corasick.go b/builder/sensitive/aho_corasick.go index 35d2933d8..f6c37f47b 100644 --- a/builder/sensitive/aho_corasick.go +++ b/builder/sensitive/aho_corasick.go @@ -77,14 +77,14 @@ func (iac *ACAutomation) PassImageURLCheck(ctx context.Context, scenario types.S } // PassLLMCheck implements the SensitiveChecker interface for ImmutableAC -func (iac *ACAutomation) PassLLMCheck(ctx context.Context, scenario types.SensitiveScenario, text string, sessionId string, accountId string) (*CheckResult, error) { - if scenario != types.ScenarioLLMQueryModeration && scenario != types.ScenarioLLMResModeration { - slog.WarnContext(ctx, "PassLLMCheck received unsupported scenario", slog.String("scenario", string(scenario))) +func (iac *ACAutomation) 
PassLLMCheck(ctx context.Context, req *types.LLMCheckRequest) (*CheckResult, error) { + if req.Scenario != types.ScenarioLLMQueryModeration && req.Scenario != types.ScenarioLLMResModeration { + slog.WarnContext(ctx, "PassLLMCheck received unsupported scenario", slog.String("scenario", string(req.Scenario))) return &CheckResult{ IsSensitive: false, }, nil } - detectResult := iac.detect(text) + detectResult := iac.detect(req.Text) if detectResult != nil { slog.InfoContext(ctx, "ACAutomation PassLLMCheck detected sensitive word", slog.String("reason", *detectResult.Reason)) diff --git a/builder/sensitive/aho_corasick_test.go b/builder/sensitive/aho_corasick_test.go index 377958d25..242aac6bc 100644 --- a/builder/sensitive/aho_corasick_test.go +++ b/builder/sensitive/aho_corasick_test.go @@ -77,22 +77,34 @@ func TestAC_PassLLMCheck(t *testing.T) { ctx := context.Background() // Test case 1: LLM query with sensitive word - result, err := checker.PassLLMCheck(ctx, types.ScenarioLLMQueryModeration, "test word1", "", "") + result, err := checker.PassLLMCheck(ctx, &types.LLMCheckRequest{ + Scenario: types.ScenarioLLMQueryModeration, + Text: "test word1", + }) assert.NoError(t, err) assert.True(t, result.IsSensitive) // Test case 2: LLM response with sensitive word - result, err = checker.PassLLMCheck(ctx, types.ScenarioLLMResModeration, "test word2", "", "") + result, err = checker.PassLLMCheck(ctx, &types.LLMCheckRequest{ + Scenario: types.ScenarioLLMResModeration, + Text: "test word2", + }) assert.NoError(t, err) assert.True(t, result.IsSensitive) // Test case 3: LLM query with no sensitive words - result, err = checker.PassLLMCheck(ctx, types.ScenarioLLMQueryModeration, "test word", "", "") + result, err = checker.PassLLMCheck(ctx, &types.LLMCheckRequest{ + Scenario: types.ScenarioLLMQueryModeration, + Text: "test word", + }) assert.NoError(t, err) assert.False(t, result.IsSensitive) // Test case 4: Unsupported scenario - result, err = checker.PassLLMCheck(ctx, 
"unsupported", "test word3", "", "") + result, err = checker.PassLLMCheck(ctx, &types.LLMCheckRequest{ + Scenario: types.SensitiveScenario("unsupported"), + Text: "test word3", + }) assert.NoError(t, err) assert.False(t, result.IsSensitive) } diff --git a/builder/sensitive/aliyun_green.go b/builder/sensitive/aliyun_green.go index 445040e15..2918b616f 100644 --- a/builder/sensitive/aliyun_green.go +++ b/builder/sensitive/aliyun_green.go @@ -237,39 +237,40 @@ func (*AliyunGreenChecker) SplitTasks(text string) []map[string]string { return tasks } -func (c *AliyunGreenChecker) PassLLMCheck(ctx context.Context, scenario types.SensitiveScenario, text string, sessionId string, accountId string) (*CheckResult, error) { +func (c *AliyunGreenChecker) PassLLMCheck(ctx context.Context, req *types.LLMCheckRequest) (*CheckResult, error) { // Build parameter map paramMap := map[string]interface{}{ - "content": text, + "content": req.Text, } // Add different ID field based on idType - if sessionId != "" && accountId != "" { + if req.SessionId != "" && req.AccountId != "" { return nil, fmt.Errorf("fail to call aliyun TextModerationPlusWithOptions, can't set sessionId and accountId both") } - if sessionId != "" { - paramMap["sessionId"] = sessionId + if req.SessionId != "" { + paramMap["sessionId"] = req.SessionId } - if accountId != "" { - if text == "" { + if req.AccountId != "" { + if req.Text == "" { return &CheckResult{IsSensitive: false}, nil } - paramMap["accountId"] = accountId + paramMap["accountId"] = req.AccountId } serviceParameters, _ := json.Marshal(paramMap) - req := &green20220302.TextModerationPlusRequest{ - Service: tea.String(string(scenario)), + request := &green20220302.TextModerationPlusRequest{ + Service: tea.String(string(req.Scenario)), ServiceParameters: tea.String(string(serviceParameters)), } + options := &util.RuntimeOptions{ ReadTimeout: tea.Int(500), ConnectTimeout: tea.Int(500), } - resp, err := c.green2022.TextModerationPlusWithOptions(req, options) 
+ resp, err := c.green2022.TextModerationPlusWithOptions(request, options) if err != nil { - slog.Error("fail to call aliyun TextModerationPlusWithOptions", slog.String("content", text), slog.Any("error", err)) - return nil, err + slog.Error("fail to call aliyun TextModerationPlusWithOptions", slog.String("content", req.Text), slog.Any("error", err)) + return nil, fmt.Errorf("fail to call aliyun TextModerationPlusWithOptions: %w", err) } if *resp.StatusCode != http.StatusOK { @@ -291,11 +292,10 @@ func (c *AliyunGreenChecker) PassLLMCheck(ctx context.Context, scenario types.Se if !strings.Contains(*result.Label, "political") { continue } - slog.Info("sensitive content detected", slog.String("content", text), slog.String("reason", *result.RiskWords), + slog.Info("sensitive content detected", slog.String("content", req.Text), slog.String("reason", *result.RiskWords), slog.String("label", *result.Label), slog.String("aliyun_request_id", *resp.Body.RequestId)) return &CheckResult{IsSensitive: true, Reason: fmt.Sprintf("label:%s,reason:%s,requestId:%s", *result.Label, *result.RiskWords, *resp.Body.RequestId)}, nil } - return &CheckResult{IsSensitive: false}, nil } diff --git a/builder/sensitive/aliyun_green_test.go b/builder/sensitive/aliyun_green_test.go index 75f2c22d1..c4ff428a7 100644 --- a/builder/sensitive/aliyun_green_test.go +++ b/builder/sensitive/aliyun_green_test.go @@ -348,7 +348,11 @@ func TestSensitiveChecker_PassLLMCheck(t *testing.T) { }, }, }, nil).Once() - result, err := checker.PassLLMCheck(context.Background(), "foo", "foo", id, "") + result, err := checker.PassLLMCheck(context.Background(), &types.LLMCheckRequest{ + Scenario: types.SensitiveScenario("foo"), + Text: "foo", + SessionId: id, + }) require.Nil(t, err) require.Equal(t, c.isSensitive, result.IsSensitive) require.Equal(t, c.wantReason, result.Reason) @@ -382,7 +386,11 @@ func TestSensitiveChecker_PassLLMCheck(t *testing.T) { }, }, }, nil).Once() - result, err := 
checker.PassLLMCheck(context.Background(), "foo", "foo", "", id) + result, err := checker.PassLLMCheck(context.Background(), &types.LLMCheckRequest{ + Scenario: types.SensitiveScenario("foo"), + Text: "foo", + AccountId: id, + }) require.Nil(t, err) require.Equal(t, c.isSensitive, result.IsSensitive) require.Equal(t, c.wantReason, result.Reason) diff --git a/builder/sensitive/chain.go b/builder/sensitive/chain.go index 6ca87deb3..32044f616 100644 --- a/builder/sensitive/chain.go +++ b/builder/sensitive/chain.go @@ -52,6 +52,18 @@ func WithMutableACAutomaton(loader internal.Loader) ChainOption { } } +// WithOpenAILLMChecker adds an OpenAI LLM sensitive checker to the chain +func WithOpenAILLMChecker() ChainOption { + return func(config *config.Config, c *chainImpl) { + if config.SensitiveCheck.LLM.Enable { + checker := NewOpenAILLMChecker(config) + c.checkers = append(c.checkers, checker) + } else { + slog.Warn("sensitive config for LLM moderation service not enabled") + } + } +} + // NewChainChecker create a chain sensitive checker // // It will run all checkers in order by the options provided @@ -111,9 +123,9 @@ func (c *chainImpl) PassImageURLCheck(ctx context.Context, scenario types.Sensit return &CheckResult{IsSensitive: false}, nil } -func (c *chainImpl) PassLLMCheck(ctx context.Context, scenario types.SensitiveScenario, text string, sessionId string, accountId string) (*CheckResult, error) { +func (c *chainImpl) PassLLMCheck(ctx context.Context, req *types.LLMCheckRequest) (*CheckResult, error) { for _, checker := range c.checkers { - res, err := checker.PassLLMCheck(ctx, scenario, text, sessionId, accountId) + res, err := checker.PassLLMCheck(ctx, req) if err != nil { return nil, err } diff --git a/builder/sensitive/chain_test.go b/builder/sensitive/chain_test.go index 3fc9855a8..b5cfa4188 100644 --- a/builder/sensitive/chain_test.go +++ b/builder/sensitive/chain_test.go @@ -283,7 +283,11 @@ func TestChainImpl_AliYun_PassLLMCheck(t *testing.T) { }, }, nil) 
- result, err := chain.PassLLMCheck(ctx, scenario, text, sessionId, "") + result, err := chain.PassLLMCheck(ctx, &types.LLMCheckRequest{ + Scenario: scenario, + Text: text, + SessionId: sessionId, + }) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -338,7 +342,11 @@ func TestChainImpl_AliYun_PassLLMCheck_Sensitive(t *testing.T) { }, }, nil) - result, err := chain.PassLLMCheck(ctx, scenario, text, sessionId, "") + result, err := chain.PassLLMCheck(ctx, &types.LLMCheckRequest{ + Scenario: scenario, + Text: text, + SessionId: sessionId, + }) if err != nil { t.Fatalf("unexpected error: %v", err) } diff --git a/builder/sensitive/guard_llm.go b/builder/sensitive/guard_llm.go new file mode 100644 index 000000000..cbb1e7504 --- /dev/null +++ b/builder/sensitive/guard_llm.go @@ -0,0 +1,148 @@ +package sensitive + +import ( + "context" + "fmt" + "log/slog" + "strings" + "time" + + gwtype "opencsg.com/csghub-server/aigateway/types" + "opencsg.com/csghub-server/builder/llm" + "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/common/types" +) + +type OpenAILLMChecker struct { + config *config.Config + llmClient *llm.Client + parser LLMResponseParser +} + +func NewOpenAILLMChecker(cfg *config.Config) *OpenAILLMChecker { + if cfg.SensitiveCheck.LLM.Endpoint == "" { + panic("SensitiveCheck.LLM.Endpoint is empty") + } + if cfg.SensitiveCheck.LLM.GuardStreamModel == "" { + panic("SensitiveCheck.LLM.GuardStreamModel is empty") + } + if cfg.SensitiveCheck.LLM.GuardModel == "" { + panic("SensitiveCheck.LLM.GuardModel is empty") + } + return &OpenAILLMChecker{ + config: cfg, + llmClient: llm.NewClient(), + parser: NewChainParser(cfg.SensitiveCheck.LLM.SafetyRegex), + } +} + +func (c *OpenAILLMChecker) PassTextCheck(ctx context.Context, scenario types.SensitiveScenario, text string) (*CheckResult, error) { + // Chunk text logic + maxChars := c.config.SensitiveCheck.StreamContextCache.MaxChars + if maxChars <= 0 { + maxChars = 2000 + } + + req := 
&types.LLMCheckRequest{ + Text: text, + MaxTokens: c.config.SensitiveCheck.LLM.MaxTokens, + ModelName: c.config.SensitiveCheck.LLM.GuardModel, + Resumable: false, + Stream: false, + Role: string(gwtype.RoleUser), + } + if len(text) <= maxChars { + return c.doCheck(ctx, req) + } + + // Simple chunking for large text + for i := 0; i < len(text); i += maxChars { + end := i + maxChars + if end > len(text) { + end = len(text) + } + req.Text = text[i:end] + res, err := c.doCheck(ctx, req) + if err != nil { + return nil, err + } + if res.IsSensitive { + return res, nil + } + } + + return &CheckResult{IsSensitive: false}, nil +} + +func (c *OpenAILLMChecker) PassImageCheck(ctx context.Context, scenario types.SensitiveScenario, ossBucketName, ossObjectName string) (*CheckResult, error) { + // Not supported by text LLM, default to pass + return &CheckResult{IsSensitive: false}, nil +} + +func (c *OpenAILLMChecker) PassImageURLCheck(ctx context.Context, scenario types.SensitiveScenario, imageURL string) (*CheckResult, error) { + // Not supported by text LLM, default to pass + return &CheckResult{IsSensitive: false}, nil +} + +func (c *OpenAILLMChecker) PassLLMCheck(ctx context.Context, req *types.LLMCheckRequest) (*CheckResult, error) { + return c.doCheck(ctx, req) +} + +func (c *OpenAILLMChecker) doCheck(ctx context.Context, req *types.LLMCheckRequest) (*CheckResult, error) { + if req.Text == "" { + return &CheckResult{IsSensitive: false}, nil + } + + // API Request + reqBody := types.LLMReqBody{ + Model: req.ModelName, + Messages: []types.LLMMessage{ + {Role: req.Role, Content: req.Text}, + }, + Stream: false, + Temperature: c.config.SensitiveCheck.LLM.Temperature, + MaxTokens: req.MaxTokens, + RawJSON: req.RawJSON, + } + + endpoint := c.config.SensitiveCheck.LLM.Endpoint + headers := make(map[string]string) + headers["x-session-id"] = req.SessionId + headers["x-resumable"] = fmt.Sprintf("%v", req.Resumable) + if c.config.SensitiveCheck.LLM.APIKey != "" { + 
headers["Authorization"] = "Bearer " + c.config.SensitiveCheck.LLM.APIKey + } + + // Retry mechanism for 429 + var content string + var err error + maxRetries := 3 + for i := 0; i < maxRetries; i++ { + timeoutCtx, cancel := context.WithTimeout(ctx, time.Duration(c.config.SensitiveCheck.LLM.TimeoutMS)*time.Millisecond) + content, err = c.llmClient.Chat(timeoutCtx, endpoint, "", headers, reqBody) + cancel() + if err == nil { + break + } + // If not 429, don't retry (assuming llmClient returns "unexpected http status code:429") + if !strings.Contains(err.Error(), "429") { + break + } + // exponential backoff or simple sleep + time.Sleep(100 * time.Millisecond) + } + + if err != nil { + // Check if it's a 429 error (assuming llmClient returns a specific error format or we can check string) + // In our current llm.Client, non-2xx status returns "unexpected http status code:%d" + slog.ErrorContext(ctx, "llm checker api request failed", slog.Any("error", err)) + // Fail-open + return &CheckResult{IsSensitive: false, Reason: "skipped_api_error"}, nil + } + + if content == "" { + return &CheckResult{IsSensitive: false, Reason: "skipped_empty_response"}, nil + } + + return c.parser.Parse(content), nil +} diff --git a/builder/sensitive/guard_llm_more_test.go b/builder/sensitive/guard_llm_more_test.go new file mode 100644 index 000000000..a186c2b78 --- /dev/null +++ b/builder/sensitive/guard_llm_more_test.go @@ -0,0 +1,82 @@ +package sensitive + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/common/types" +) + +func TestOpenAILLMChecker_PassImageAndURLCheck(t *testing.T) { + cfg := &config.Config{} + cfg.SensitiveCheck.LLM.Endpoint = "http://localhost" + cfg.SensitiveCheck.LLM.GuardModel = "test-model" + cfg.SensitiveCheck.LLM.GuardStreamModel = "test-stream-model" + checker := NewOpenAILLMChecker(cfg) + + ctx := 
context.Background() + res, err := checker.PassImageCheck(ctx, types.ScenarioCommentDetection, "bucket", "obj") + require.NoError(t, err) + require.False(t, res.IsSensitive) + + res, err = checker.PassImageURLCheck(ctx, types.ScenarioCommentDetection, "http://example.com/img.png") + require.NoError(t, err) + require.False(t, res.IsSensitive) +} + +func TestOpenAILLMChecker_PassLLMCheck(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + resp := map[string]interface{}{ + "id": "test-id", + "choices": []map[string]interface{}{ + { + "message": map[string]interface{}{ + "content": `{"is_sensitive": false}`, + }, + }, + }, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + cfg := &config.Config{} + cfg.SensitiveCheck.LLM.Endpoint = server.URL + cfg.SensitiveCheck.LLM.GuardModel = "test-model" + cfg.SensitiveCheck.LLM.GuardStreamModel = "test-stream-model" + cfg.SensitiveCheck.LLM.TimeoutMS = 1000 + + checker := NewOpenAILLMChecker(cfg) + + ctx := context.Background() + req := &types.LLMCheckRequest{ + Text: "hello world", + MaxTokens: 100, + } + res, err := checker.PassLLMCheck(ctx, req) + require.NoError(t, err) + require.False(t, res.IsSensitive) +} + +func TestOpenAILLMChecker_doCheckEmptyText(t *testing.T) { + cfg := &config.Config{} + cfg.SensitiveCheck.LLM.Endpoint = "http://localhost" + cfg.SensitiveCheck.LLM.GuardModel = "test-model" + cfg.SensitiveCheck.LLM.GuardStreamModel = "test-stream-model" + checker := NewOpenAILLMChecker(cfg) + + ctx := context.Background() + req := &types.LLMCheckRequest{ + Text: "", + ModelName: "test-model", + } + res, err := checker.doCheck(ctx, req) + require.NoError(t, err) + require.False(t, res.IsSensitive) +} diff --git a/builder/sensitive/guard_llm_resp.go b/builder/sensitive/guard_llm_resp.go new file mode 100644 index 000000000..c93680cdc --- /dev/null +++ b/builder/sensitive/guard_llm_resp.go @@ -0,0 +1,106 @@ 
+package sensitive + +import ( + "encoding/json" + "regexp" + "strings" +) + +type RiskLevel string + +const ( + RiskLevelSafe RiskLevel = "Safe" + RiskLevelUnsafe RiskLevel = "Unsafe" + RiskLevelControversial RiskLevel = "Controversial" +) +const ( + SafetyRegex = `Safety:\s*(Safe|Unsafe|Controversial)` +) + +type LLMResponseParser interface { + Parse(content string) *CheckResult +} + +type ChainParser struct { + parsers []LLMResponseParser +} + +func NewChainParser(safetyRegex string) LLMResponseParser { + return &ChainParser{ + parsers: []LLMResponseParser{ + &QwenGuardRegexParser{SafetyRegex: safetyRegex}, + &JSONParser{}, + }, + } +} + +func (c *ChainParser) Parse(content string) *CheckResult { + for _, parser := range c.parsers { + res := parser.Parse(content) + if res != nil { + return res + } + } + return &CheckResult{IsSensitive: false} +} + +// QwenGuardRegexParser implements the parsing logic for Qwen3Guard model format +type QwenGuardRegexParser struct { + SafetyRegex string +} + +func (p *QwenGuardRegexParser) Parse(content string) *CheckResult { + safetyRegex := p.SafetyRegex + if safetyRegex == "" { + safetyRegex = SafetyRegex + } + safePattern := regexp.MustCompile(safetyRegex) + + safeMatch := safePattern.FindStringSubmatch(content) + if len(safeMatch) < 2 { + return nil // Not matched, try next parser + } + + label := safeMatch[1] + + // If it's safe, return early + if label != string(RiskLevelUnsafe) { + return &CheckResult{IsSensitive: false} + } + + return &CheckResult{IsSensitive: true, Reason: content} +} + +// JSONParser tries to parse standard JSON response +type JSONParser struct{} + +func (p *JSONParser) Parse(content string) *CheckResult { + content = strings.TrimSpace(content) + + // Remove markdown code block if present + if strings.HasPrefix(content, "```json") { + content = strings.TrimPrefix(content, "```json") + content = strings.TrimSuffix(content, "```") + content = strings.TrimSpace(content) + } + + var result LLMCheckResult + 
err := json.Unmarshal([]byte(content), &result) + if err == nil { + if result.IsSensitive() { + return &CheckResult{IsSensitive: true, Reason: content} + } + return &CheckResult{IsSensitive: false} + } + + return nil // Not matched, try next parser +} + +type LLMCheckResult struct { + RiskLevel string `json:"risk_level"` + CategoryLabels string `json:"category_labels"` +} + +func (p *LLMCheckResult) IsSensitive() bool { + return p.RiskLevel == string(RiskLevelUnsafe) +} diff --git a/builder/sensitive/guard_llm_test.go b/builder/sensitive/guard_llm_test.go new file mode 100644 index 000000000..7ffcad5d8 --- /dev/null +++ b/builder/sensitive/guard_llm_test.go @@ -0,0 +1,216 @@ +package sensitive + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + "opencsg.com/csghub-server/common/config" + "opencsg.com/csghub-server/common/types" +) + +func TestOpenAILLMChecker_PassTextCheck(t *testing.T) { + // Mock server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "application/json", r.Header.Get("Content-Type")) + require.Equal(t, "Bearer test-key", r.Header.Get("Authorization")) + + var reqBody types.LLMReqBody + err := json.NewDecoder(r.Body).Decode(&reqBody) + require.NoError(t, err) + + // Mock sensitive response + if strings.Contains(reqBody.Messages[0].Content, "bad word") { + w.WriteHeader(http.StatusOK) + resp := map[string]interface{}{ + "id": "test-id", + "choices": []map[string]interface{}{ + { + "message": map[string]interface{}{ + "content": `{"risk_level": "Unsafe", "category_labels": "politics"}`, + }, + }, + }, + } + _ = json.NewEncoder(w).Encode(resp) + return + } + + // Mock normal response + w.WriteHeader(http.StatusOK) + resp := map[string]interface{}{ + "id": "test-id", + "choices": []map[string]interface{}{ + { + "message": map[string]interface{}{ + "content": `{"risk_level": "Safe"}`, + 
}, + }, + }, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + cfg := &config.Config{} + cfg.SensitiveCheck.LLM.Enable = true + cfg.SensitiveCheck.LLM.Endpoint = server.URL + cfg.SensitiveCheck.LLM.APIKey = "test-key" + cfg.SensitiveCheck.LLM.GuardModel = "test-model" + cfg.SensitiveCheck.LLM.GuardStreamModel = "test-stream-model" + cfg.SensitiveCheck.LLM.TimeoutMS = 1000 + + checker := NewOpenAILLMChecker(cfg) + + ctx := context.Background() + res, err := checker.PassTextCheck(ctx, types.ScenarioCommentDetection, "hello world") + require.NoError(t, err) + require.False(t, res.IsSensitive) + + res, err = checker.PassTextCheck(ctx, types.ScenarioCommentDetection, "bad word") + require.NoError(t, err) + require.True(t, res.IsSensitive) + require.Equal(t, `{"risk_level": "Unsafe", "category_labels": "politics"}`, res.Reason) +} + +func TestOpenAILLMChecker_RetryOn429(t *testing.T) { + requests := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requests++ + if requests < 3 { + w.WriteHeader(http.StatusTooManyRequests) + return + } + // Success on 3rd try + w.WriteHeader(http.StatusOK) + resp := map[string]interface{}{ + "id": "test-id", + "choices": []map[string]interface{}{ + { + "message": map[string]interface{}{ + "content": `{"risk_level": "Safe"}`, + }, + }, + }, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + cfg := &config.Config{} + cfg.SensitiveCheck.LLM.Enable = true + cfg.SensitiveCheck.LLM.Endpoint = server.URL + cfg.SensitiveCheck.LLM.GuardModel = "test-model" + cfg.SensitiveCheck.LLM.GuardStreamModel = "test-stream-model" + cfg.SensitiveCheck.LLM.TimeoutMS = 1000 + + checker := NewOpenAILLMChecker(cfg) + + start := time.Now() + res, err := checker.PassTextCheck(context.Background(), types.ScenarioCommentDetection, "test") + duration := time.Since(start) + + require.NoError(t, err) + require.False(t, res.IsSensitive) + require.Equal(t, 3, requests) // 
Should retry 2 times, total 3 requests + require.True(t, duration >= 200*time.Millisecond) // 100ms + 200ms sleep +} + +func TestOpenAILLMChecker_ChunkedTextCheck(t *testing.T) { + var receivedTexts []string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var reqBody types.LLMReqBody + _ = json.NewDecoder(r.Body).Decode(&reqBody) + receivedTexts = append(receivedTexts, reqBody.Messages[0].Content) + + w.WriteHeader(http.StatusOK) + resp := map[string]interface{}{ + "choices": []map[string]interface{}{{ + "message": map[string]interface{}{ + "content": `{"risk_level": "Safe"}`, + }, + }}, + } + _ = json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + cfg := &config.Config{} + cfg.SensitiveCheck.LLM.Endpoint = server.URL + cfg.SensitiveCheck.LLM.GuardModel = "test-model" + cfg.SensitiveCheck.LLM.GuardStreamModel = "test-stream-model" + cfg.SensitiveCheck.LLM.TimeoutMS = 1000 + cfg.SensitiveCheck.StreamContextCache.MaxChars = 10 // chunk size 10 + + checker := NewOpenAILLMChecker(cfg) + + // Text length 25 should be split into 3 chunks: 10, 10, 5 + text := "1234567890123456789012345" + res, err := checker.PassTextCheck(context.Background(), types.ScenarioCommentDetection, text) + require.NoError(t, err) + require.False(t, res.IsSensitive) + + require.Equal(t, 3, len(receivedTexts)) + require.Equal(t, "1234567890", receivedTexts[0]) + require.Equal(t, "1234567890", receivedTexts[1]) + require.Equal(t, "12345", receivedTexts[2]) +} + +func TestParseLLMResponse(t *testing.T) { + parser := NewChainParser(SafetyRegex) + + tests := []struct { + name string + content string + expected *CheckResult + }{ + { + name: "QwenGuard Safe", + content: "Safety: Safe\nCategories: None", + expected: &CheckResult{IsSensitive: false}, + }, + { + name: "QwenGuard Unsafe with categories", + content: "Safety: Unsafe\nCategories: Violent\nCategories: PII", + expected: &CheckResult{IsSensitive: true, Reason: "Safety: 
Unsafe\nCategories: Violent\nCategories: PII"}, + }, + { + name: "valid json sensitive", + content: `{"risk_level": "Unsafe", "category_labels": "porn"}`, + expected: &CheckResult{IsSensitive: true, Reason: `{"risk_level": "Unsafe", "category_labels": "porn"}`}, + }, + { + name: "valid json non-sensitive", + content: `{"risk_level": "Safe"}`, + expected: &CheckResult{IsSensitive: false}, + }, + { + name: "markdown json block sensitive", + content: "```json\n{\"risk_level\": \"Unsafe\", \"category_labels\": \"politics\"}\n```", + expected: &CheckResult{IsSensitive: true, Reason: `{"risk_level": "Unsafe", "category_labels": "politics"}`}, + }, + { + name: "fallback text sensitive 1", + content: "The content violates rules. risk_level: Unsafe", + expected: &CheckResult{IsSensitive: false}, + }, + { + name: "fallback text non-sensitive", + content: "The content is safe. risk_level=Safe", + expected: &CheckResult{IsSensitive: false}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + res := parser.Parse(tt.content) + require.Equal(t, tt.expected.IsSensitive, res.IsSensitive) + require.Equal(t, tt.expected.Reason, res.Reason) + }) + } +} diff --git a/builder/sensitive/internal/loader_test.go b/builder/sensitive/internal/loader_test.go index 919347b8b..26868e8cc 100644 --- a/builder/sensitive/internal/loader_test.go +++ b/builder/sensitive/internal/loader_test.go @@ -21,19 +21,9 @@ func (to *testObserver) Update(data *SensitiveWordData) error { } func TestConfigLoader(t *testing.T) { - cfg := &config.Config{ - SensitiveCheck: struct { - Enable bool `env:"STARHUB_SERVER_SENSITIVE_CHECK_ENABLE" default:"false"` - AccessKeyID string `env:"STARHUB_SERVER_SENSITIVE_CHECK_ACCESS_KEY_ID"` - AccessKeySecret string `env:"STARHUB_SERVER_SENSITIVE_CHECK_ACCESS_KEY_SECRET"` - Region string `env:"STARHUB_SERVER_SENSITIVE_CHECK_REGION"` - Endpoint string `env:"STARHUB_SERVER_SENSITIVE_CHECK_ENDPOINT" default:"oss-cn-beijing.aliyuncs.com"` - EnableSSL bool 
`env:"STARHUB_SERVER_SENSITIVE_CHECK_ENABLE_SSL" default:"true"` - DictDir string `env:"STARHUB_SERVER_SENSITIVE_CHECK_DICT_DIR" default:"/starhub-bin/vocabulary"` - }{ - DictDir: "./config.yaml", - }, - } + cfg := &config.Config{} + cfg.SensitiveCheck.DictDir = "./config.yaml" + loader := NewConfigLoader(cfg) observer := &testObserver{} @@ -58,19 +48,9 @@ func TestConfigLoader(t *testing.T) { } func TestMultipleObservers(t *testing.T) { - cfg := &config.Config{ - SensitiveCheck: struct { - Enable bool `env:"STARHUB_SERVER_SENSITIVE_CHECK_ENABLE" default:"false"` - AccessKeyID string `env:"STARHUB_SERVER_SENSITIVE_CHECK_ACCESS_KEY_ID"` - AccessKeySecret string `env:"STARHUB_SERVER_SENSITIVE_CHECK_ACCESS_KEY_SECRET"` - Region string `env:"STARHUB_SERVER_SENSITIVE_CHECK_REGION"` - Endpoint string `env:"STARHUB_SERVER_SENSITIVE_CHECK_ENDPOINT" default:"oss-cn-beijing.aliyuncs.com"` - EnableSSL bool `env:"STARHUB_SERVER_SENSITIVE_CHECK_ENABLE_SSL" default:"true"` - DictDir string `env:"STARHUB_SERVER_SENSITIVE_CHECK_DICT_DIR" default:"/starhub-bin/vocabulary"` - }{ - DictDir: "./config.yaml", - }, - } + cfg := &config.Config{} + cfg.SensitiveCheck.DictDir = "./config.yaml" + loader := NewConfigLoader(cfg) observer1 := &testObserver{} @@ -94,19 +74,9 @@ func TestMultipleObservers(t *testing.T) { } func TestUnsubscribe(t *testing.T) { - cfg := &config.Config{ - SensitiveCheck: struct { - Enable bool `env:"STARHUB_SERVER_SENSITIVE_CHECK_ENABLE" default:"false"` - AccessKeyID string `env:"STARHUB_SERVER_SENSITIVE_CHECK_ACCESS_KEY_ID"` - AccessKeySecret string `env:"STARHUB_SERVER_SENSITIVE_CHECK_ACCESS_KEY_SECRET"` - Region string `env:"STARHUB_SERVER_SENSITIVE_CHECK_REGION"` - Endpoint string `env:"STARHUB_SERVER_SENSITIVE_CHECK_ENDPOINT" default:"oss-cn-beijing.aliyuncs.com"` - EnableSSL bool `env:"STARHUB_SERVER_SENSITIVE_CHECK_ENABLE_SSL" default:"true"` - DictDir string `env:"STARHUB_SERVER_SENSITIVE_CHECK_DICT_DIR" default:"/starhub-bin/vocabulary"` - }{ - DictDir: 
"./config.yaml", - }, - } + cfg := &config.Config{} + cfg.SensitiveCheck.DictDir = "./config.yaml" + loader := NewConfigLoader(cfg) observer := &testObserver{} diff --git a/builder/sensitive/sensitive_checker.go b/builder/sensitive/sensitive_checker.go index e0a2ead21..5f701343c 100644 --- a/builder/sensitive/sensitive_checker.go +++ b/builder/sensitive/sensitive_checker.go @@ -10,7 +10,7 @@ type SensitiveChecker interface { PassTextCheck(ctx context.Context, scenario types.SensitiveScenario, text string) (*CheckResult, error) PassImageCheck(ctx context.Context, scenario types.SensitiveScenario, ossBucketName, ossObjectName string) (*CheckResult, error) PassImageURLCheck(ctx context.Context, scenario types.SensitiveScenario, imageURL string) (*CheckResult, error) - PassLLMCheck(ctx context.Context, scenario types.SensitiveScenario, text string, sessionId string, accountId string) (*CheckResult, error) + PassLLMCheck(ctx context.Context, req *types.LLMCheckRequest) (*CheckResult, error) } type ImageCheckReq struct { diff --git a/builder/store/database/license_test.go b/builder/store/database/license_test.go index 512eddc9e..f7e7ee595 100644 --- a/builder/store/database/license_test.go +++ b/builder/store/database/license_test.go @@ -26,7 +26,7 @@ func TestLicenseStore_CRUD(t *testing.T) { Product: "test", Edition: "standard", MaxUser: 10, - StartTime: time.Now(), + StartTime: time.Now().Add(-1 * time.Minute), ExpireTime: time.Now().Add(-1 * time.Hour), UserUUID: "test-user-uuid", Issuer: "tester", diff --git a/builder/store/database/llm_config.go b/builder/store/database/llm_config.go index 83c26afd5..6489924f3 100644 --- a/builder/store/database/llm_config.go +++ b/builder/store/database/llm_config.go @@ -25,6 +25,10 @@ type LLMConfig struct { Enabled bool `bun:",notnull" json:"enabled"` Provider string `bun:"," json:"provider"` Metadata map[string]any `bun:",type:jsonb,nullzero" json:"metadata"` + // NeedSensitiveCheck controls whether requests for this model 
should go + // through sensitive content detection in aigateway. Set to false to skip + // the check (e.g. for guard models or trusted internal models). + NeedSensitiveCheck bool `bun:",notnull,default:true" json:"need_sensitive_check"` + times } diff --git a/builder/store/database/migrations/20260330154817_add_need_sensitive_check_to_llm_configs.down.sql b/builder/store/database/migrations/20260330154817_add_need_sensitive_check_to_llm_configs.down.sql new file mode 100644 index 000000000..c16e1d893 --- /dev/null +++ b/builder/store/database/migrations/20260330154817_add_need_sensitive_check_to_llm_configs.down.sql @@ -0,0 +1,5 @@ +SET statement_timeout = 0; + +--bun:split + +ALTER TABLE llm_configs DROP COLUMN IF EXISTS need_sensitive_check; diff --git a/builder/store/database/migrations/20260330154817_add_need_sensitive_check_to_llm_configs.up.sql b/builder/store/database/migrations/20260330154817_add_need_sensitive_check_to_llm_configs.up.sql new file mode 100644 index 000000000..a8666b0fb --- /dev/null +++ b/builder/store/database/migrations/20260330154817_add_need_sensitive_check_to_llm_configs.up.sql @@ -0,0 +1,5 @@ +SET statement_timeout = 0; + +--bun:split + +ALTER TABLE llm_configs ADD COLUMN IF NOT EXISTS need_sensitive_check BOOLEAN NOT NULL DEFAULT true; diff --git a/common/config/config.go b/common/config/config.go index aedaa9dac..dc7417612 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -123,13 +123,37 @@ type Config struct { } SensitiveCheck struct { - Enable bool `env:"STARHUB_SERVER_SENSITIVE_CHECK_ENABLE" default:"false"` - AccessKeyID string `env:"STARHUB_SERVER_SENSITIVE_CHECK_ACCESS_KEY_ID"` - AccessKeySecret string `env:"STARHUB_SERVER_SENSITIVE_CHECK_ACCESS_KEY_SECRET"` - Region string `env:"STARHUB_SERVER_SENSITIVE_CHECK_REGION"` - Endpoint string `env:"STARHUB_SERVER_SENSITIVE_CHECK_ENDPOINT" default:"oss-cn-beijing.aliyuncs.com"` - EnableSSL bool `env:"STARHUB_SERVER_SENSITIVE_CHECK_ENABLE_SSL" default:"true"` - DictDir string 
`env:"STARHUB_SERVER_SENSITIVE_CHECK_DICT_DIR" default:"/starhub-bin/vocabulary"` + Enable bool `env:"STARHUB_SERVER_SENSITIVE_CHECK_ENABLE" default:"false"` + AccessKeyID string `env:"STARHUB_SERVER_SENSITIVE_CHECK_ACCESS_KEY_ID"` + AccessKeySecret string `env:"STARHUB_SERVER_SENSITIVE_CHECK_ACCESS_KEY_SECRET"` + Region string `env:"STARHUB_SERVER_SENSITIVE_CHECK_REGION"` + Endpoint string `env:"STARHUB_SERVER_SENSITIVE_CHECK_ENDPOINT" default:"oss-cn-beijing.aliyuncs.com"` + EnableSSL bool `env:"STARHUB_SERVER_SENSITIVE_CHECK_ENABLE_SSL" default:"true"` + DictDir string `env:"STARHUB_SERVER_SENSITIVE_CHECK_DICT_DIR" default:"/starhub-bin/vocabulary"` + CheckChain []string `env:"STARHUB_SERVER_SENSITIVE_CHECK_CHECK_CHAIN" default:"[ac_automaton,mutable_ac_automaton,aliyun_green]"` + StreamCheckMode string `env:"STARHUB_SERVER_SENSITIVE_CHECK_STREAM_CHECK_MODE" default:"async"` // sync | async + AsyncBufferMaxChars int `env:"STARHUB_SERVER_SENSITIVE_CHECK_ASYNC_BUFFER_MAX_CHARS" default:"50"` + + LLM struct { + Enable bool `env:"STARHUB_SERVER_SENSITIVE_CHECK_LLM_ENABLE" default:"false"` + Endpoint string `env:"STARHUB_SERVER_SENSITIVE_CHECK_LLM_ENDPOINT"` + APIKey string `env:"STARHUB_SERVER_SENSITIVE_CHECK_LLM_API_KEY"` + GuardModel string `env:"STARHUB_SERVER_SENSITIVE_CHECK_LLM_GUARD_MODEL" default:"Qwen/Qwen3Guard-Gen-0.6B"` + GuardStreamModel string `env:"STARHUB_SERVER_SENSITIVE_CHECK_LLM_GUARD_STREAM_MODEL" default:"Qwen/Qwen3Guard-Gen-Stream-0.6B"` + TimeoutMS int `env:"STARHUB_SERVER_SENSITIVE_CHECK_LLM_TIMEOUT_MS" default:"3000"` + MaxTokens int `env:"STARHUB_SERVER_SENSITIVE_CHECK_LLM_MAX_TOKENS" default:"128"` + Temperature float64 `env:"STARHUB_SERVER_SENSITIVE_CHECK_LLM_TEMPERATURE" default:"0"` + ResponseMode string `env:"STARHUB_SERVER_SENSITIVE_CHECK_LLM_RESPONSE_MODE" default:"json_or_text"` + SafetyRegex string `env:"STARHUB_SERVER_SENSITIVE_CHECK_LLM_SAFETY_REGEX" default:"Safety:\\s*(Safe|Unsafe|Controversial)"` + } + + StreamContextCache 
struct { + Enable bool `env:"STARHUB_SERVER_SENSITIVE_CHECK_STREAM_CONTEXT_CACHE_ENABLE" default:"true"` + Backend string `env:"STARHUB_SERVER_SENSITIVE_CHECK_STREAM_CONTEXT_CACHE_BACKEND" default:"memory"` // redis | memory + TTLSeconds int `env:"STARHUB_SERVER_SENSITIVE_CHECK_STREAM_CONTEXT_CACHE_TTL_SECONDS" default:"120"` + MaxChunks int `env:"STARHUB_SERVER_SENSITIVE_CHECK_STREAM_CONTEXT_CACHE_MAX_CHUNKS" default:"12"` + MaxChars int `env:"STARHUB_SERVER_SENSITIVE_CHECK_STREAM_CONTEXT_CACHE_MAX_CHARS" default:"2000"` + } } JWT struct { diff --git a/common/types/prompt.go b/common/types/prompt.go index 799dab345..e6b19d25a 100644 --- a/common/types/prompt.go +++ b/common/types/prompt.go @@ -49,6 +49,21 @@ type LLMReqBody struct { Messages []LLMMessage `json:"messages"` Stream bool `json:"stream"` Temperature float64 `json:"temperature"` + MaxTokens int `json:"max_tokens,omitempty"` + RawJSON string `json:"-"` +} + +type LLMCheckRequest struct { + Scenario SensitiveScenario `json:"scenario"` + Text string `json:"text"` + SessionId string `json:"session_id,omitempty"` + AccountId string `json:"account_id,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + RawJSON string `json:"raw_json,omitempty"` + Resumable bool `json:"resumable,omitempty"` + ModelName string `json:"-"` + Role string `json:"role,omitempty"` + Stream bool `json:"stream,omitempty"` } type ConversationMessageReq struct { diff --git a/moderation/checker/text_file_checker_test.go b/moderation/checker/text_file_checker_test.go index 236450fd1..4c3b6c047 100644 --- a/moderation/checker/text_file_checker_test.go +++ b/moderation/checker/text_file_checker_test.go @@ -20,15 +20,9 @@ func TestTextFileChecker_Run(t *testing.T) { t.Run("contains sensitive words", func(t *testing.T) { mockChecker := mocksens.NewMockSensitiveChecker(t) - InitWithContentChecker(&config.Config{SensitiveCheck: struct { - Enable bool "env:\"STARHUB_SERVER_SENSITIVE_CHECK_ENABLE\" default:\"false\"" - AccessKeyID string 
"env:\"STARHUB_SERVER_SENSITIVE_CHECK_ACCESS_KEY_ID\"" - AccessKeySecret string "env:\"STARHUB_SERVER_SENSITIVE_CHECK_ACCESS_KEY_SECRET\"" - Region string "env:\"STARHUB_SERVER_SENSITIVE_CHECK_REGION\"" - Endpoint string "env:\"STARHUB_SERVER_SENSITIVE_CHECK_ENDPOINT\" default:\"oss-cn-beijing.aliyuncs.com\"" - EnableSSL bool "env:\"STARHUB_SERVER_SENSITIVE_CHECK_ENABLE_SSL\" default:\"true\"" - DictDir string "env:\"STARHUB_SERVER_SENSITIVE_CHECK_DICT_DIR\" default:\"/starhub-bin/vocabulary\"" - }{Enable: true}}, mockChecker) + cfg := &config.Config{} + cfg.SensitiveCheck.Enable = true + InitWithContentChecker(cfg, mockChecker) mockChecker.EXPECT().PassTextCheck(mock.Anything, types.ScenarioCommentDetection, "This text contains sensitive word."). Return(&sensitive.CheckResult{IsSensitive: true, Reason: "contains sensitive word"}, nil) checker := NewTextFileChecker() diff --git a/moderation/component/sensitive.go b/moderation/component/sensitive.go index 5d43dd20a..4022a307b 100644 --- a/moderation/component/sensitive.go +++ b/moderation/component/sensitive.go @@ -2,37 +2,63 @@ package component import ( "context" + "log/slog" + "strings" + gwtype "opencsg.com/csghub-server/aigateway/types" "opencsg.com/csghub-server/builder/sensitive" "opencsg.com/csghub-server/common/config" "opencsg.com/csghub-server/common/types" ) +type CheckProvider string + +const ( + CheckProviderACAutomaton CheckProvider = "ac_automaton" + CheckProviderMutableACAutomaton CheckProvider = "mutable_ac_automaton" + CheckProviderAliyunGreen CheckProvider = "aliyun_green" + CheckProviderLLMOpenAI CheckProvider = "guard_llm" +) + type SensitiveComponent interface { PassTextCheck(ctx context.Context, scenario types.SensitiveScenario, text string) (*sensitive.CheckResult, error) PassImageCheck(ctx context.Context, scenario types.SensitiveScenario, ossBucketName, ossObjectName string) (*sensitive.CheckResult, error) PassImageURLCheck(ctx context.Context, scenario types.SensitiveScenario, imageURL 
string) (*sensitive.CheckResult, error) - PassStreamCheck(ctx context.Context, scenario types.SensitiveScenario, text, id string) (*sensitive.CheckResult, error) - PassLLMQueryCheck(ctx context.Context, scenario types.SensitiveScenario, text, id string) (*sensitive.CheckResult, error) + // PassStreamCheck check stream chunk text + PassStreamCheck(ctx context.Context, req *types.LLMCheckRequest) (*sensitive.CheckResult, error) + // PassLLMQueryCheck check LLM prompt text + PassLLMQueryCheck(ctx context.Context, req *types.LLMCheckRequest) (*sensitive.CheckResult, error) } type SensitiveComponentImpl struct { checker sensitive.SensitiveChecker + cfg *config.Config } -func NewSensitiveComponent(checker sensitive.SensitiveChecker) SensitiveComponent { - return SensitiveComponentImpl{ - checker: checker, +func NewSensitiveComponentFromConfig(config *config.Config) SensitiveComponent { + var opts []sensitive.ChainOption + + for _, provider := range config.SensitiveCheck.CheckChain { + p := strings.TrimSpace(provider) + switch p { + case string(CheckProviderACAutomaton): + opts = append(opts, sensitive.WithACAutomaton(sensitive.LoadFromConfig(config))) + case string(CheckProviderMutableACAutomaton): + opts = append(opts, sensitive.WithMutableACAutomaton(sensitive.LoadFromDB())) + case string(CheckProviderAliyunGreen): + opts = append(opts, sensitive.WithAliYunChecker()) + case string(CheckProviderLLMOpenAI): + opts = append(opts, sensitive.WithOpenAILLMChecker()) + default: + if p != "" { + slog.Warn("unknown sensitive check provider ignored", slog.String("provider", p)) + } + } } -} -func NewSensitiveComponentFromConfig(config *config.Config) SensitiveComponent { return SensitiveComponentImpl{ - checker: sensitive.NewChainChecker(config, - sensitive.WithACAutomaton(sensitive.LoadFromConfig(config)), - sensitive.WithMutableACAutomaton(sensitive.LoadFromDB()), - sensitive.WithAliYunChecker(), - ), + checker: sensitive.NewChainChecker(config, opts...), + cfg: config, } } @@ 
-44,12 +70,19 @@ func (c SensitiveComponentImpl) PassImageCheck(ctx context.Context, scenario typ return c.checker.PassImageCheck(ctx, scenario, ossBucketName, ossObjectName) } -func (c SensitiveComponentImpl) PassStreamCheck(ctx context.Context, scenario types.SensitiveScenario, text, id string) (*sensitive.CheckResult, error) { - return c.checker.PassLLMCheck(ctx, scenario, text, id, "") +func (c SensitiveComponentImpl) PassStreamCheck(ctx context.Context, req *types.LLMCheckRequest) (*sensitive.CheckResult, error) { + req.ModelName = c.cfg.SensitiveCheck.LLM.GuardStreamModel + req.Role = string(gwtype.RoleAssistant) + return c.checker.PassLLMCheck(ctx, req) } -func (c SensitiveComponentImpl) PassLLMQueryCheck(ctx context.Context, scenario types.SensitiveScenario, text, id string) (*sensitive.CheckResult, error) { - return c.checker.PassLLMCheck(ctx, scenario, text, "", id) +func (c SensitiveComponentImpl) PassLLMQueryCheck(ctx context.Context, req *types.LLMCheckRequest) (*sensitive.CheckResult, error) { + req.ModelName = c.cfg.SensitiveCheck.LLM.GuardModel + if req.Stream { + req.ModelName = c.cfg.SensitiveCheck.LLM.GuardStreamModel + } + req.Role = string(gwtype.RoleUser) + return c.checker.PassLLMCheck(ctx, req) } func (c SensitiveComponentImpl) PassImageURLCheck(ctx context.Context, scenario types.SensitiveScenario, imageURL string) (*sensitive.CheckResult, error) { diff --git a/moderation/component/sensitive_test.go b/moderation/component/sensitive_test.go index 964c69778..4e7a51732 100644 --- a/moderation/component/sensitive_test.go +++ b/moderation/component/sensitive_test.go @@ -48,30 +48,64 @@ func TestSensitiveComponentImpl_PassImageURLCheck(t *testing.T) { func TestSensitiveComponentImpl_PassLLMQueryCheck(t *testing.T) { mockSeneitive := mock_sensitive.NewMockSensitiveChecker(t) + cfg := &config.Config{} + cfg.SensitiveCheck.LLM.GuardModel = "test-model" component := SensitiveComponentImpl{ checker: mockSeneitive, + cfg: cfg, } - 
mockSeneitive.EXPECT().PassLLMCheck(mock.Anything, types.ScenarioNicknameDetection, "你好", "", "123"). + mockSeneitive.EXPECT().PassLLMCheck(mock.Anything, &types.LLMCheckRequest{ + Scenario: types.ScenarioNicknameDetection, + Text: "你好", + AccountId: "123", + MaxTokens: 0, + RawJSON: "", + ModelName: "test-model", + Role: "user", + }). Return(&sensitive.CheckResult{ IsSensitive: false, }, nil) - result, err := component.PassLLMQueryCheck(context.Background(), - types.ScenarioNicknameDetection, "你好", "123") + result, err := component.PassLLMQueryCheck(context.Background(), &types.LLMCheckRequest{ + Scenario: types.ScenarioNicknameDetection, + Text: "你好", + AccountId: "123", + MaxTokens: 0, + RawJSON: "", + Role: "user", + }) assert.NoError(t, err) assert.False(t, result.IsSensitive) } func TestSensitiveComponentImpl_PassStreamCheck(t *testing.T) { mockSeneitive := mock_sensitive.NewMockSensitiveChecker(t) + cfg := &config.Config{} + cfg.SensitiveCheck.LLM.GuardStreamModel = "test-stream-model" component := SensitiveComponentImpl{ checker: mockSeneitive, + cfg: cfg, } - mockSeneitive.EXPECT().PassLLMCheck(mock.Anything, types.ScenarioNicknameDetection, "你好", "123", ""). + mockSeneitive.EXPECT().PassLLMCheck(mock.Anything, &types.LLMCheckRequest{ + Scenario: types.ScenarioNicknameDetection, + Text: "你好", + SessionId: "123", + MaxTokens: 0, + RawJSON: "", + ModelName: "test-stream-model", + Role: "assistant", + }). 
Return(&sensitive.CheckResult{ IsSensitive: false, }, nil) - result, err := component.PassStreamCheck(context.Background(), - types.ScenarioNicknameDetection, "你好", "123") + result, err := component.PassStreamCheck(context.Background(), &types.LLMCheckRequest{ + Scenario: types.ScenarioNicknameDetection, + Text: "你好", + SessionId: "123", + MaxTokens: 0, + RawJSON: "", + Role: "assistant", + }) assert.NoError(t, err) assert.False(t, result.IsSensitive) } diff --git a/moderation/handler/sensitive.go b/moderation/handler/sensitive.go index 6abcd38d5..518167b22 100644 --- a/moderation/handler/sensitive.go +++ b/moderation/handler/sensitive.go @@ -79,23 +79,14 @@ func (h *SensitiveHandler) Image(ctx *gin.Context) { } func (h *SensitiveHandler) LlmResp(ctx *gin.Context) { - type request struct { - Service string `json:"Service"` - ServiceParameters struct { - Content string `json:"content"` - SessionId string `json:"sessionId"` - } `json:"ServiceParameters"` - } - var ( - r request - err error - ) - if err = ctx.ShouldBindJSON(&r); err != nil { + var req types.LLMCheckRequest + if err := ctx.ShouldBindJSON(&req); err != nil { slog.Error("Bad request format", slog.String("err", err.Error())) httpbase.BadRequest(ctx, err.Error()) return } - result, err := h.c.PassStreamCheck(ctx, types.ScenarioLLMResModeration, r.ServiceParameters.Content, r.ServiceParameters.SessionId) + + result, err := h.c.PassStreamCheck(ctx, &req) if err != nil { httpbase.ServerError(ctx, err) return @@ -104,26 +95,18 @@ func (h *SensitiveHandler) LlmResp(ctx *gin.Context) { } func (h *SensitiveHandler) LlmPrompt(ctx *gin.Context) { - type request struct { - Service string `json:"Service"` - ServiceParameters struct { - Content string `json:"content"` - AccountId string `json:"accountId"` - } `json:"ServiceParameters"` - } - var ( - r request - err error - ) - if err = ctx.ShouldBindJSON(&r); err != nil { + var req types.LLMCheckRequest + if err := ctx.ShouldBindJSON(&req); err != nil { 
slog.Error("Bad request format", slog.String("err", err.Error())) httpbase.BadRequest(ctx, err.Error()) return } - result, err := h.c.PassLLMQueryCheck(ctx, types.ScenarioLLMQueryModeration, r.ServiceParameters.Content, r.ServiceParameters.AccountId) + + result, err := h.c.PassLLMQueryCheck(ctx, &req) if err != nil { httpbase.ServerError(ctx, err) return } + httpbase.OK(ctx, result) } diff --git a/moderation/handler/sensitive_test.go b/moderation/handler/sensitive_test.go index f92a56ba4..3c6737448 100644 --- a/moderation/handler/sensitive_test.go +++ b/moderation/handler/sensitive_test.go @@ -224,21 +224,23 @@ func TestSensitiveHandler_LlmResp(t *testing.T) { t.Run("success", func(t *testing.T) { // Prepare request body - reqBody := map[string]interface{}{ - "Service": "LLMResponseModeration", - "ServiceParameters": map[string]interface{}{ - "content": "This is a safe response", - "sessionId": "test-session-123", - }, + reqBody := types.LLMCheckRequest{ + Scenario: types.ScenarioLLMResModeration, + Text: "This is a safe response", + SessionId: "test-session-123", } reqBodyBytes, _ := json.Marshal(reqBody) // Set mock expectation mockSensitiveComponent.EXPECT().PassStreamCheck( mock.Anything, - types.ScenarioLLMResModeration, - "This is a safe response", - "test-session-123", + &types.LLMCheckRequest{ + Scenario: types.ScenarioLLMResModeration, + Text: "This is a safe response", + SessionId: "test-session-123", + MaxTokens: 0, + RawJSON: "", + }, ).Return(successResult, nil).Once() // Create request @@ -277,22 +279,25 @@ func TestSensitiveHandler_LlmResp(t *testing.T) { t.Run("server error from sensitive component", func(t *testing.T) { // Prepare request body - reqBody := map[string]interface{}{ - "Service": "LLMResponseModeration", - "ServiceParameters": map[string]interface{}{ - "content": "This is a test response", - "sessionId": "test-session-123", - }, + reqBody := types.LLMCheckRequest{ + Scenario: types.ScenarioLLMResModeration, + Text: "This is a test 
response", + SessionId: "test-session-123", } reqBodyBytes, _ := json.Marshal(reqBody) // Set mock expectation to return error expectedErr := assert.AnError + mockSensitiveComponent.EXPECT().PassStreamCheck( mock.Anything, - types.ScenarioLLMResModeration, - "This is a test response", - "test-session-123", + &types.LLMCheckRequest{ + Scenario: types.ScenarioLLMResModeration, + Text: "This is a test response", + SessionId: "test-session-123", + MaxTokens: 0, + RawJSON: "", + }, ).Return(nil, expectedErr).Once() // Create request @@ -309,12 +314,10 @@ func TestSensitiveHandler_LlmResp(t *testing.T) { t.Run("sensitive content detected", func(t *testing.T) { // Prepare request body - reqBody := map[string]interface{}{ - "Service": "LLMResponseModeration", - "ServiceParameters": map[string]interface{}{ - "content": "This is sensitive content", - "sessionId": "test-session-123", - }, + reqBody := types.LLMCheckRequest{ + Scenario: types.ScenarioLLMResModeration, + Text: "This is sensitive content", + SessionId: "test-session-123", } reqBodyBytes, _ := json.Marshal(reqBody) @@ -327,9 +330,13 @@ func TestSensitiveHandler_LlmResp(t *testing.T) { // Set mock expectation mockSensitiveComponent.EXPECT().PassStreamCheck( mock.Anything, - types.ScenarioLLMResModeration, - "This is sensitive content", - "test-session-123", + &types.LLMCheckRequest{ + Scenario: types.ScenarioLLMResModeration, + Text: "This is sensitive content", + SessionId: "test-session-123", + MaxTokens: 0, + RawJSON: "", + }, ).Return(sensitiveResult, nil).Once() // Create request @@ -371,21 +378,23 @@ func TestSensitiveHandler_LlmPrompt(t *testing.T) { t.Run("success", func(t *testing.T) { // Prepare request body - reqBody := map[string]interface{}{ - "Service": "LLMPromptModeration", - "ServiceParameters": map[string]interface{}{ - "content": "This is a safe prompt", - "accountId": "test-account-123", - }, + reqBody := types.LLMCheckRequest{ + Scenario: types.ScenarioLLMQueryModeration, + Text: "This is a 
safe prompt", + AccountId: "test-account-123", } reqBodyBytes, _ := json.Marshal(reqBody) // Set mock expectation mockSensitiveComponent.EXPECT().PassLLMQueryCheck( mock.Anything, - types.ScenarioLLMQueryModeration, - "This is a safe prompt", - "test-account-123", + &types.LLMCheckRequest{ + Scenario: types.ScenarioLLMQueryModeration, + Text: "This is a safe prompt", + AccountId: "test-account-123", + MaxTokens: 0, + RawJSON: "", + }, ).Return(successResult, nil).Once() // Create request @@ -424,12 +433,10 @@ func TestSensitiveHandler_LlmPrompt(t *testing.T) { t.Run("server error from sensitive component", func(t *testing.T) { // Prepare request body - reqBody := map[string]interface{}{ - "Service": "LLMPromptModeration", - "ServiceParameters": map[string]interface{}{ - "content": "This is a test prompt", - "accountId": "test-account-123", - }, + reqBody := types.LLMCheckRequest{ + Scenario: types.ScenarioLLMQueryModeration, + Text: "This is a test prompt", + AccountId: "test-account-123", } reqBodyBytes, _ := json.Marshal(reqBody) @@ -437,9 +444,13 @@ func TestSensitiveHandler_LlmPrompt(t *testing.T) { expectedErr := assert.AnError mockSensitiveComponent.EXPECT().PassLLMQueryCheck( mock.Anything, - types.ScenarioLLMQueryModeration, - "This is a test prompt", - "test-account-123", + &types.LLMCheckRequest{ + Scenario: types.ScenarioLLMQueryModeration, + Text: "This is a test prompt", + AccountId: "test-account-123", + MaxTokens: 0, + RawJSON: "", + }, ).Return(nil, expectedErr).Once() // Create request @@ -456,12 +467,10 @@ func TestSensitiveHandler_LlmPrompt(t *testing.T) { t.Run("sensitive content detected", func(t *testing.T) { // Prepare request body - reqBody := map[string]interface{}{ - "Service": "LLMPromptModeration", - "ServiceParameters": map[string]interface{}{ - "content": "This is sensitive prompt", - "accountId": "test-account-123", - }, + reqBody := types.LLMCheckRequest{ + Scenario: types.ScenarioLLMQueryModeration, + Text: "This is sensitive 
prompt", + AccountId: "test-account-123", } reqBodyBytes, _ := json.Marshal(reqBody) @@ -474,9 +483,13 @@ func TestSensitiveHandler_LlmPrompt(t *testing.T) { // Set mock expectation mockSensitiveComponent.EXPECT().PassLLMQueryCheck( mock.Anything, - types.ScenarioLLMQueryModeration, - "This is sensitive prompt", - "test-account-123", + &types.LLMCheckRequest{ + Scenario: types.ScenarioLLMQueryModeration, + Text: "This is sensitive prompt", + AccountId: "test-account-123", + MaxTokens: 0, + RawJSON: "", + }, ).Return(sensitiveResult, nil).Once() // Create request diff --git a/moderation/router/api.go b/moderation/router/api.go index 5df11bcb7..c47da0c9b 100644 --- a/moderation/router/api.go +++ b/moderation/router/api.go @@ -2,6 +2,7 @@ package router import ( "fmt" + "opencsg.com/csghub-server/builder/instrumentation" "github.com/gin-contrib/pprof"