@@ -2,10 +2,12 @@ package component
22
33import (
44 "context"
5+ "strings"
56 "testing"
67 "time"
78
89 lru "github.com/hashicorp/golang-lru/v2"
10+ "github.com/openai/openai-go/v3"
911 "github.com/stretchr/testify/assert"
1012 "github.com/stretchr/testify/mock"
1113 "opencsg.com/csghub-server/aigateway/types"
@@ -395,3 +397,83 @@ func TestInitStreamChecker(t *testing.T) {
395397 assert .Equal (t , 100 , checker .maxChars )
396398 })
397399}
400+
401+ func TestModerationImpl_checkLLMPrompt (t * testing.T ) {
402+ ctx := context .Background ()
403+ mockSvcClient := new (MockModerationSvcClient )
404+
405+ modImpl := & moderationImpl {
406+ modSvcClient : mockSvcClient ,
407+ maxContentLength : 10 ,
408+ }
409+
410+ t .Run ("short content" , func (t * testing.T ) {
411+ mockSvcClient .ExpectedCalls = nil
412+ mockSvcClient .On ("PassLLMPromptCheck" , mock .Anything , mock .Anything ).Return (& rpc.CheckResult {IsSensitive : false }, nil ).Once ()
413+
414+ res , err := modImpl .checkLLMPrompt (ctx , "short" , "test-key" , false )
415+ assert .NoError (t , err )
416+ assert .False (t , res .IsSensitive )
417+ })
418+
419+ t .Run ("long content chunking" , func (t * testing.T ) {
420+ mockSvcClient .ExpectedCalls = nil
421+ // 20 chars, max length is 10, so it will be chunked
422+ // splitContentIntoChunksByWindow logic: if chunk size is maxContentLength (10)?
423+ // wait, splitContentIntoChunksByWindow splits by 2000!
424+ // Actually, splitContentIntoChunksByWindow has slidingWindowSize = 2000 hardcoded in moderation.go
425+
426+ // If we use 3000 chars, it will be chunked
427+ modImpl .maxContentLength = 2000
428+ longText := strings .Repeat ("a" , 3000 )
429+ mockSvcClient .On ("PassLLMPromptCheck" , mock .Anything , mock .Anything ).Return (& rpc.CheckResult {IsSensitive : false }, nil )
430+
431+ res , err := modImpl .checkLLMPrompt (ctx , longText , "test-key" , false )
432+ assert .NoError (t , err )
433+ assert .False (t , res .IsSensitive )
434+ })
435+ }
436+
437+ func TestModerationImpl_CheckChatPrompts (t * testing.T ) {
438+ ctx := context .Background ()
439+ mockSvcClient := new (MockModerationSvcClient )
440+
441+ modImpl := & moderationImpl {
442+ modSvcClient : mockSvcClient ,
443+ maxContentLength : 2000 ,
444+ }
445+
446+ t .Run ("nil modSvcClient" , func (t * testing.T ) {
447+ emptyModImpl := & moderationImpl {modSvcClient : nil }
448+ res , err := emptyModImpl .CheckChatPrompts (ctx , nil , "uuid" , false )
449+ assert .NoError (t , err )
450+ assert .False (t , res .IsSensitive )
451+ })
452+
453+ t .Run ("normal message" , func (t * testing.T ) {
454+ mockSvcClient .ExpectedCalls = nil
455+ mockSvcClient .On ("PassLLMPromptCheck" , mock .Anything , mock .Anything ).Return (& rpc.CheckResult {IsSensitive : false }, nil ).Once ()
456+
457+ messages := []openai.ChatCompletionMessageParamUnion {
458+ openai .UserMessage ("Hello" ),
459+ }
460+
461+ res , err := modImpl .CheckChatPrompts (ctx , messages , "uuid" , false )
462+ assert .NoError (t , err )
463+ assert .False (t , res .IsSensitive )
464+ })
465+
466+ t .Run ("sensitive message" , func (t * testing.T ) {
467+ mockSvcClient .ExpectedCalls = nil
468+ mockSvcClient .On ("PassLLMPromptCheck" , mock .Anything , mock .Anything ).Return (& rpc.CheckResult {IsSensitive : true , Reason : "toxic" }, nil ).Once ()
469+
470+ messages := []openai.ChatCompletionMessageParamUnion {
471+ openai .UserMessage ("Bad words" ),
472+ }
473+
474+ res , err := modImpl .CheckChatPrompts (ctx , messages , "uuid" , false )
475+ assert .NoError (t , err )
476+ assert .True (t , res .IsSensitive )
477+ assert .Equal (t , "toxic" , res .Reason )
478+ })
479+ }
0 commit comments