
Commit abaf90e

csg-pr-bot, Dev Agent, and cemeng authored
Fix the max context length for aliyun-green-checker (#988)
Co-authored-by: Dev Agent <dev-agent@example.com>
Co-authored-by: cemeng <cemengzhang@yntengyun.com>
1 parent e09f7dc commit abaf90e

11 files changed: 168 additions & 29 deletions


aigateway/component/moderation.go

Lines changed: 18 additions & 13 deletions
@@ -21,10 +21,9 @@ import (
 )
 
 const (
-    // max content length
-    maxContentLength = 6144
-    // sliding window size
-    slidingWindowSize = 2000
+    // max content length for moderation
+    defaultMaxContentLength = 2000 // sliding window size
+    slidingWindowSize = 2000
     // cache ttl
     cacheTTL = 24 * time.Hour
     // moderation cache prefix
@@ -59,10 +58,11 @@ type StreamChecker interface {
 }
 
 type moderationImpl struct {
-    modSvcClient  rpc.ModerationSvcClient
-    cacheClient   cache.RedisClient
-    config        *config.Config
-    streamChecker StreamChecker
+    modSvcClient     rpc.ModerationSvcClient
+    cacheClient      cache.RedisClient
+    config           *config.Config
+    streamChecker    StreamChecker
+    maxContentLength int
 }
 
 type syncStreamChecker struct {
@@ -262,10 +262,15 @@ func NewModerationImpl(config *config.Config) Moderation {
 }
 
 func NewModerationImplWithClient(config *config.Config, modSvcClient rpc.ModerationSvcClient, cacheClient cache.RedisClient) Moderation {
+    maxContentLength := config.SensitiveCheck.MaxContentLength
+    if config.SensitiveCheck.MaxContentLength <= 0 {
+        maxContentLength = defaultMaxContentLength
+    }
     modImpl := &moderationImpl{
-        modSvcClient: modSvcClient,
-        cacheClient:  cacheClient,
-        config:       config,
+        modSvcClient:     modSvcClient,
+        cacheClient:      cacheClient,
+        maxContentLength: maxContentLength,
+        config:           config,
     }
 
     initStreamChecker(modImpl)
@@ -478,7 +483,7 @@ func (modImpl *moderationImpl) CheckChatPrompts(ctx context.Context, messages []
 func (modImpl *moderationImpl) checkLLMPrompt(ctx context.Context, content, key string, isStream bool) (*rpc.CheckResult, error) {
     content = strings.ReplaceAll(content, `\\n`, "\n")
     content = strings.ReplaceAll(content, `\n`, "")
-    if len(content) < maxContentLength {
+    if len(content) < modImpl.maxContentLength {
         return modImpl.checkSingleChunk(ctx, content, key, isStream)
     }
 
@@ -544,7 +549,7 @@ func (modImpl *moderationImpl) checkLLMPrompt(ctx context.Context, content, key
         separatorLen = 1 // for "."
     }
 
-    if buffer.Len()+separatorLen+len(chunk) > maxContentLength && buffer.Len() > 0 {
+    if buffer.Len()+separatorLen+len(chunk) > modImpl.maxContentLength && buffer.Len() > 0 {
         result, err := modImpl.checkBuffer(ctx, buffer.String(), currentBufferChunks, key, isStream)
         if err != nil {
             return nil, fmt.Errorf("failed to call moderation on buffer: %w", err)
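
To make the chunk-batching that the last two hunks parameterize concrete, here is a minimal, self-contained sketch of the buffering rule, assuming pre-split chunks. batchChunks and its returned batch slice are illustrative stand-ins (the real code calls checkBuffer on each flush); only the threshold expression mirrors moderation.go.

package main

import (
	"fmt"
	"strings"
)

// batchChunks groups pre-split chunks into batches no longer than
// maxContentLength, joining chunks within a batch with ".".
func batchChunks(chunks []string, maxContentLength int) []string {
	var batches []string
	var buffer strings.Builder
	for _, chunk := range chunks {
		separatorLen := 0
		if buffer.Len() > 0 {
			separatorLen = 1 // for "."
		}
		// The threshold check this commit changes: the limit is now
		// per-instance (modImpl.maxContentLength), not a package const.
		if buffer.Len()+separatorLen+len(chunk) > maxContentLength && buffer.Len() > 0 {
			batches = append(batches, buffer.String()) // stands in for checkBuffer(...)
			buffer.Reset()
		}
		if buffer.Len() > 0 {
			buffer.WriteString(".")
		}
		buffer.WriteString(chunk)
	}
	if buffer.Len() > 0 {
		batches = append(batches, buffer.String())
	}
	return batches
}

func main() {
	// With maxContentLength = 10: "aaaa.bbbb" fits, adding ".cccc" would not.
	fmt.Println(batchChunks([]string{"aaaa", "bbbb", "cccc"}, 10))
	// Output: [aaaa.bbbb cccc]
}

With the old package-level constant this threshold was fixed at 6144 for every checker; routing it through modImpl.maxContentLength lets an aliyun-green deployment run with its 2000-character ceiling.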

aigateway/component/moderation_test.go

Lines changed: 82 additions & 0 deletions
@@ -2,10 +2,12 @@ package component
 
 import (
     "context"
+    "strings"
     "testing"
     "time"
 
     lru "github.com/hashicorp/golang-lru/v2"
+    "github.com/openai/openai-go/v3"
     "github.com/stretchr/testify/assert"
     "github.com/stretchr/testify/mock"
     "opencsg.com/csghub-server/aigateway/types"
@@ -395,3 +397,83 @@ func TestInitStreamChecker(t *testing.T) {
         assert.Equal(t, 100, checker.maxChars)
     })
 }
+
+func TestModerationImpl_checkLLMPrompt(t *testing.T) {
+    ctx := context.Background()
+    mockSvcClient := new(MockModerationSvcClient)
+
+    modImpl := &moderationImpl{
+        modSvcClient:     mockSvcClient,
+        maxContentLength: 10,
+    }
+
+    t.Run("short content", func(t *testing.T) {
+        mockSvcClient.ExpectedCalls = nil
+        mockSvcClient.On("PassLLMPromptCheck", mock.Anything, mock.Anything).Return(&rpc.CheckResult{IsSensitive: false}, nil).Once()
+
+        res, err := modImpl.checkLLMPrompt(ctx, "short", "test-key", false)
+        assert.NoError(t, err)
+        assert.False(t, res.IsSensitive)
+    })
+
+    t.Run("long content chunking", func(t *testing.T) {
+        mockSvcClient.ExpectedCalls = nil
+        // splitContentIntoChunksByWindow splits by the hardcoded
+        // slidingWindowSize (2000) in moderation.go, not by maxContentLength,
+        // so set maxContentLength to 2000 and use 3000 chars to force chunking.
+        modImpl.maxContentLength = 2000
+        longText := strings.Repeat("a", 3000)
+        mockSvcClient.On("PassLLMPromptCheck", mock.Anything, mock.Anything).Return(&rpc.CheckResult{IsSensitive: false}, nil)
+
+        res, err := modImpl.checkLLMPrompt(ctx, longText, "test-key", false)
+        assert.NoError(t, err)
+        assert.False(t, res.IsSensitive)
+    })
+}
+
+func TestModerationImpl_CheckChatPrompts(t *testing.T) {
+    ctx := context.Background()
+    mockSvcClient := new(MockModerationSvcClient)
+
+    modImpl := &moderationImpl{
+        modSvcClient:     mockSvcClient,
+        maxContentLength: 2000,
+    }
+
+    t.Run("nil modSvcClient", func(t *testing.T) {
+        emptyModImpl := &moderationImpl{modSvcClient: nil}
+        res, err := emptyModImpl.CheckChatPrompts(ctx, nil, "uuid", false)
+        assert.NoError(t, err)
+        assert.False(t, res.IsSensitive)
+    })
+
+    t.Run("normal message", func(t *testing.T) {
+        mockSvcClient.ExpectedCalls = nil
+        mockSvcClient.On("PassLLMPromptCheck", mock.Anything, mock.Anything).Return(&rpc.CheckResult{IsSensitive: false}, nil).Once()
+
+        messages := []openai.ChatCompletionMessageParamUnion{
+            openai.UserMessage("Hello"),
+        }
+
+        res, err := modImpl.CheckChatPrompts(ctx, messages, "uuid", false)
+        assert.NoError(t, err)
+        assert.False(t, res.IsSensitive)
+    })
+
+    t.Run("sensitive message", func(t *testing.T) {
+        mockSvcClient.ExpectedCalls = nil
+        mockSvcClient.On("PassLLMPromptCheck", mock.Anything, mock.Anything).Return(&rpc.CheckResult{IsSensitive: true, Reason: "toxic"}, nil).Once()
+
+        messages := []openai.ChatCompletionMessageParamUnion{
+            openai.UserMessage("Bad words"),
+        }
+
+        res, err := modImpl.CheckChatPrompts(ctx, messages, "uuid", false)
+        assert.NoError(t, err)
+        assert.True(t, res.IsSensitive)
+        assert.Equal(t, "toxic", res.Reason)
+    })
+}

aigateway/handler/openai.go

Lines changed: 2 additions & 3 deletions
@@ -362,10 +362,9 @@ func (h *OpenAIHandlerImpl) Chat(c *gin.Context) {
     key := fmt.Sprintf("%s:%s", userUUID, modelID)
     result, err := h.modComponent.CheckChatPrompts(c.Request.Context(), chatReq.Messages, key, chatReq.Stream)
     if err != nil {
-        c.String(http.StatusInternalServerError, fmt.Errorf("failed to call moderation error:%w", err).Error())
-        return
+        slog.ErrorContext(c.Request.Context(), "failed to call moderation", slog.Any("error", err))
     }
-    if result.IsSensitive {
+    if result != nil && result.IsSensitive {
         handleSensitiveResponse(c, chatReq.Stream, result)
         return
     }
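
The net effect of this hunk is a fail-open policy: an error from the moderation component is now logged rather than answered with a 500, and the added nil guard keeps the handler from dereferencing result when the check failed, so the request proceeds as if the content had passed. The test change below asserts http.StatusOK for exactly this path.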

aigateway/handler/openai_test.go

Lines changed: 20 additions & 1 deletion
@@ -521,9 +521,28 @@ func TestOpenAIHandler_Chat(t *testing.T) {
         _ = json.Unmarshal(body, &expectReq)
         tester.mocks.moderationComp.EXPECT().CheckChatPrompts(mock.Anything, expectReq.Messages, "testuuid:"+model.ID, false).
             Return(nil, errors.New("some error"))
+        llmTokenCounter := mocktoken.NewMockChatTokenCounter(t)
+        tester.mocks.tokenCounterFactory.EXPECT().NewChat(
+            token.CreateParam{
+                Endpoint: model.Endpoint,
+                Host:     "",
+                Model:    "model1",
+                ImageID:  model.ImageID,
+                Provider: model.Provider,
+            }).
+            Return(llmTokenCounter)
+        llmTokenCounter.EXPECT().AppendPrompts(expectReq.Messages).Return()
+        var wg sync.WaitGroup
+        wg.Add(1)
+        tester.mocks.openAIComp.EXPECT().RecordUsage(mock.Anything, "testuuid", model, llmTokenCounter).
+            RunAndReturn(func(ctx context.Context, uuid string, model *types.Model, counter token.Counter) error {
+                wg.Done()
+                return nil
+            })
         tester.handler.Chat(c)
+        wg.Wait()
 
-        assert.Equal(t, http.StatusInternalServerError, w.Code)
+        assert.Equal(t, http.StatusOK, w.Code)
     })
     t.Run("success", func(t *testing.T) {
         tester, c, w := setupTest(t)

aigateway/handler/response_writer_wrapper_test.go

Lines changed: 2 additions & 2 deletions
(whitespace-only change: a blank line with trailing whitespace is rewritten in each hunk)

@@ -107,7 +107,7 @@ func TestHandleSensitiveResponse(t *testing.T) {
         w := httptest.NewRecorder()
         ctx, _ := gin.CreateTestContext(w)
         ctx.Request = httptest.NewRequest("GET", "/", nil)
-
+
         checkResult := &rpc.CheckResult{Reason: "test reason"}
         handleSensitiveResponse(ctx, true, checkResult)
 
@@ -122,7 +122,7 @@ func TestHandleSensitiveResponse(t *testing.T) {
         w := httptest.NewRecorder()
         ctx, _ := gin.CreateTestContext(w)
         ctx.Request = httptest.NewRequest("GET", "/", nil)
-
+
         checkResult := &rpc.CheckResult{Reason: "test reason"}
         handleSensitiveResponse(ctx, false, checkResult)

builder/rpc/moderation_svc_client.go

Lines changed: 9 additions & 3 deletions
@@ -7,8 +7,11 @@ import (
     "opencsg.com/csghub-server/api/httpbase"
     "opencsg.com/csghub-server/common/errorx"
     "opencsg.com/csghub-server/common/types"
+    utils "opencsg.com/csghub-server/common/utils/common"
 )
 
+const PRINT_STRING_LEN = 1000
+
 type ModerationSvcClient interface {
     PassTextCheck(ctx context.Context, scenario types.SensitiveScenario, text string) (*CheckResult, error)
     PassImageCheck(ctx context.Context, scenario types.SensitiveScenario, ossBucketName, ossObjectName string) (*CheckResult, error)
@@ -48,6 +51,7 @@ func (c *ModerationSvcHttpClient) PassTextCheck(ctx context.Context, scenario ty
     resp.Data = &CheckResult{}
     err := c.hc.Post(ctx, path, req, &resp)
     if err != nil {
+        slog.ErrorContext(ctx, "call moderation service failed", slog.String("error", err.Error()), slog.Any("req", req))
         return nil, errorx.RemoteSvcFail(err,
             errorx.Ctx().
                 Set("service", "moderation service").
@@ -65,7 +69,8 @@ func (c *ModerationSvcHttpClient) PassLLMRespCheck(ctx context.Context, req type
     resp.Data = &CheckResult{}
     err := c.hc.Post(ctx, path, req, &resp)
     if err != nil {
-        slog.Error("call moderation service failed", slog.String("error", err.Error()))
+        req.Text = utils.TruncStringByRune(req.Text, PRINT_STRING_LEN)
+        slog.ErrorContext(ctx, "call moderation service failed", slog.String("error", err.Error()), slog.Any("req", req))
         return nil, errorx.RemoteSvcFail(err,
             errorx.Ctx().
                 Set("service", "moderation service").
@@ -140,7 +145,7 @@ func (c *ModerationSvcHttpClient) SubmitRepoCheck(ctx context.Context, repoType
     var resp httpbase.R
     err := c.hc.Post(ctx, path, req, &resp)
     if err != nil {
-        slog.Error("call moderation service failed", slog.String("error", err.Error()))
+        slog.ErrorContext(ctx, "call moderation service failed", slog.String("error", err.Error()), slog.Any("req", req))
         return errorx.RemoteSvcFail(err,
             errorx.Ctx().
                 Set("service", "moderation service").
@@ -156,7 +161,8 @@ func (c *ModerationSvcHttpClient) PassLLMPromptCheck(ctx context.Context, req ty
     resp.Data = &CheckResult{}
     err := c.hc.Post(ctx, path, req, &resp)
     if err != nil {
-        slog.Error("call moderation service failed", slog.String("error", err.Error()))
+        req.Text = utils.TruncStringByRune(req.Text, PRINT_STRING_LEN)
+        slog.ErrorContext(ctx, "call moderation service failed", slog.String("error", err.Error()), slog.Any("req", req))
        return nil, errorx.RemoteSvcFail(err,
             errorx.Ctx().
                 Set("service", "moderation service").

builder/sensitive/guard_llm.go

Lines changed: 1 addition & 1 deletion
(whitespace-only change: trailing whitespace removed from a blank line)

@@ -112,7 +112,7 @@ func (c *OpenAILLMChecker) doCheck(ctx context.Context, req *types.LLMCheckReque
     if c.config.SensitiveCheck.LLM.APIKey != "" {
         headers["Authorization"] = "Bearer " + c.config.SensitiveCheck.LLM.APIKey
     }
-
+
     // Retry mechanism for 429
     var content string
     var err error

common/config/config.go

Lines changed: 3 additions & 1 deletion
@@ -133,13 +133,15 @@ type Config struct {
     CheckChain          []string `env:"STARHUB_SERVER_SENSITIVE_CHECK_CHECK_CHAIN" default:"[ac_automaton,mutable_ac_automaton,aliyun_green]"`
     StreamCheckMode     string   `env:"STARHUB_SERVER_SENSITIVE_CHECK_STREAM_CHECK_MODE" default:"async"` // sync | async
     AsyncBufferMaxChars int      `env:"STARHUB_SERVER_SENSITIVE_CHECK_ASYNC_BUFFER_MAX_CHARS" default:"50"`
+    // aliyun green max content length: 2000 | qwen guard max content length: 7000
+    MaxContentLength int `env:"STARHUB_SERVER_SENSITIVE_CHECK_MAX_CONTENT_LENGTH" default:"2000"`
 
     LLM struct {
         Enable           bool    `env:"STARHUB_SERVER_SENSITIVE_CHECK_LLM_ENABLE" default:"false"`
         Endpoint         string  `env:"STARHUB_SERVER_SENSITIVE_CHECK_LLM_ENDPOINT"`
         APIKey           string  `env:"STARHUB_SERVER_SENSITIVE_CHECK_LLM_API_KEY"`
         GuardModel       string  `env:"STARHUB_SERVER_SENSITIVE_CHECK_LLM_GUARD_MODEL" default:"Qwen/Qwen3Guard-Gen-0.6B"`
-        GuardStreamModel string  `env:"STARHUB_SERVER_SENSITIVE_CHECK_LLM_GUARD_STREAM_MODEL" default:"Qwen/Qwen/Qwen3Guard-Gen-Stream-0.6B"`
+        GuardStreamModel string  `env:"STARHUB_SERVER_SENSITIVE_CHECK_LLM_GUARD_STREAM_MODEL" default:"Qwen/Qwen3Guard-Gen-Stream-0.6B"`
         TimeoutMS        int     `env:"STARHUB_SERVER_SENSITIVE_CHECK_LLM_TIMEOUT_MS" default:"3000"`
         MaxTokens        int     `env:"STARHUB_SERVER_SENSITIVE_CHECK_LLM_MAX_TOKENS" default:"128"`
         Temperature      float64 `env:"STARHUB_SERVER_SENSITIVE_CHECK_LLM_TEMPERATURE" default:"0"`
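
The new field's comment captures the motivation behind the commit title: the aliyun-green checker accepts at most 2000 characters per request, while a qwen-guard checker accepts around 7000, so the default stays at 2000 and deployments can raise it via STARHUB_SERVER_SENSITIVE_CHECK_MAX_CONTENT_LENGTH. A minimal sketch of the fallback that NewModerationImplWithClient (moderation.go above) applies to this value; the function name here is illustrative, not from the repository:

package main

import "fmt"

// resolveMaxContentLength mirrors the constructor's guard: an unset or
// non-positive configured value falls back to the default.
func resolveMaxContentLength(configured int) int {
	if configured <= 0 {
		return 2000 // defaultMaxContentLength in aigateway/component/moderation.go
	}
	return configured
}

func main() {
	fmt.Println(resolveMaxContentLength(0))    // 2000 (fallback)
	fmt.Println(resolveMaxContentLength(7000)) // 7000 (e.g. a qwen guard deployment)
}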

common/utils/common/string.go

Lines changed: 14 additions & 0 deletions
@@ -35,6 +35,20 @@ func TruncString(s string, limit int) string {
     return string(s1)
 }
 
+func TruncStringByRune(s string, limit int) string {
+    runes := []rune(s)
+    if len(runes) <= limit {
+        return s
+    }
+
+    // Reserve 3 runes for "..."
+    if limit <= 3 {
+        return string(runes[:limit])
+    }
+
+    return string(runes[:limit-3]) + "..."
+}
+
 func MD5Hash(s string) string {
     hash := md5.New()
     hash.Write([]byte(s))
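
TruncStringByRune truncates by runes rather than bytes, so payloads containing multi-byte characters are cut at character boundaries instead of mid-codepoint. A self-contained illustration of the new helper's behavior; the function body is copied verbatim from the hunk above:

package main

import "fmt"

func TruncStringByRune(s string, limit int) string {
	runes := []rune(s)
	if len(runes) <= limit {
		return s
	}

	// Reserve 3 runes for "..."
	if limit <= 3 {
		return string(runes[:limit])
	}

	return string(runes[:limit-3]) + "..."
}

func main() {
	fmt.Println(TruncStringByRune("hello world", 8)) // hello...
	fmt.Println(TruncStringByRune("héllo wörld", 8)) // héllo... (rune-safe, no split codepoints)
	fmt.Println(TruncStringByRune("short", 10))      // short (at or under the limit: unchanged)
	fmt.Println(TruncStringByRune("abcdef", 3))      // abc (limit <= 3: hard cut, no ellipsis)
}

In builder/rpc/moderation_svc_client.go above, this caps req.Text at PRINT_STRING_LEN (1000) runes before a failed request is logged, keeping error logs bounded even for very long prompts.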

component/llm_service_test.go

Lines changed: 2 additions & 2 deletions
(gofmt alignment-only change; exact spacing is not preserved in this view)

@@ -140,9 +140,9 @@ func TestLLMServiceComponent_UpdateLLMConfig(t *testing.T) {
     newName := "new-model"
     metadata := map[string]any{"tasks": []any{"text-to-image"}}
     req := &types.UpdateLLMConfigReq{
-        ID:        123,
+        ID:       123,
         ModelName: &newName,
-        Metadata:  &metadata,
+        Metadata: &metadata,
     }
     dbLLMConfig := &database.LLMConfig{
         ID: 123,
