Skip to content

Commit 5d6811b

Browse files
committed
feat: add comprehensive prompt filter system with 73 security rules
- Add prompt filter module with weighted scoring (50+ threshold) - Implement 73 built-in rules covering malware, exploits, reverse engineering, etc. - Support Chinese and English keywords for dual-language detection - Add frontend UI with category filtering, batch enable/disable, and pagination - Integrate filter into /v1/responses, /v1/chat/completions, /v1/messages endpoints - Add admin API for rule management and filter logs - Support monitor/warn/block modes with configurable thresholds - Add database tables for filter logs and audit trail
1 parent 29b1c94 commit 5d6811b

24 files changed

Lines changed: 3533 additions & 13 deletions

admin/handler.go

Lines changed: 113 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"github.com/codex2api/database"
2727
"github.com/codex2api/proxy"
2828
"github.com/codex2api/security"
29+
"github.com/codex2api/security/promptfilter"
2930
"github.com/gin-gonic/gin"
3031
"github.com/tidwall/gjson"
3132
)
@@ -133,6 +134,10 @@ func (h *Handler) RegisterRoutes(r *gin.Engine) {
133134
api.GET("/ops/overview", h.GetOpsOverview)
134135
api.GET("/settings", h.GetSettings)
135136
api.PUT("/settings", h.UpdateSettings)
137+
api.GET("/prompt-filter/logs", h.ListPromptFilterLogs)
138+
api.DELETE("/prompt-filter/logs", h.ClearPromptFilterLogs)
139+
api.POST("/prompt-filter/test", h.TestPromptFilter)
140+
api.GET("/prompt-filter/rules", h.GetPromptFilterRules)
136141
api.GET("/models", h.ListModels)
137142
api.POST("/models/sync", h.SyncModels)
138143
api.GET("/image-prompts", h.ListImagePromptTemplates)
@@ -297,10 +302,10 @@ type accountResponse struct {
297302
Locked bool `json:"locked"`
298303
AllowedAPIKeyIDs []int64 `json:"allowed_api_key_ids"`
299304
// 图片配额信息
300-
ImageQuotaRemaining *int `json:"image_quota_remaining,omitempty"`
301-
ImageQuotaTotal *int `json:"image_quota_total,omitempty"`
302-
TodayUsedCount *int `json:"today_used_count,omitempty"`
303-
ImageQuotaResetAt string `json:"image_quota_reset_at,omitempty"`
305+
ImageQuotaRemaining *int `json:"image_quota_remaining,omitempty"`
306+
ImageQuotaTotal *int `json:"image_quota_total,omitempty"`
307+
TodayUsedCount *int `json:"today_used_count,omitempty"`
308+
ImageQuotaResetAt string `json:"image_quota_reset_at,omitempty"`
304309
}
305310

306311
type accountUsageWindow struct {
@@ -2090,6 +2095,15 @@ type settingsResponse struct {
20902095
ModelMapping string `json:"model_mapping"`
20912096
ResinURL string `json:"resin_url"`
20922097
ResinPlatformName string `json:"resin_platform_name"`
2098+
PromptFilterEnabled bool `json:"prompt_filter_enabled"`
2099+
PromptFilterMode string `json:"prompt_filter_mode"`
2100+
PromptFilterThreshold int `json:"prompt_filter_threshold"`
2101+
PromptFilterStrictThreshold int `json:"prompt_filter_strict_threshold"`
2102+
PromptFilterLogMatches bool `json:"prompt_filter_log_matches"`
2103+
PromptFilterMaxTextLength int `json:"prompt_filter_max_text_length"`
2104+
PromptFilterSensitiveWords string `json:"prompt_filter_sensitive_words"`
2105+
PromptFilterCustomPatterns string `json:"prompt_filter_custom_patterns"`
2106+
PromptFilterDisabledPatterns string `json:"prompt_filter_disabled_patterns"`
20932107
}
20942108

20952109
type updateSettingsReq struct {
@@ -2116,6 +2130,15 @@ type updateSettingsReq struct {
21162130
ModelMapping *string `json:"model_mapping"`
21172131
ResinURL *string `json:"resin_url"`
21182132
ResinPlatformName *string `json:"resin_platform_name"`
2133+
PromptFilterEnabled *bool `json:"prompt_filter_enabled"`
2134+
PromptFilterMode *string `json:"prompt_filter_mode"`
2135+
PromptFilterThreshold *int `json:"prompt_filter_threshold"`
2136+
PromptFilterStrictThreshold *int `json:"prompt_filter_strict_threshold"`
2137+
PromptFilterLogMatches *bool `json:"prompt_filter_log_matches"`
2138+
PromptFilterMaxTextLength *int `json:"prompt_filter_max_text_length"`
2139+
PromptFilterSensitiveWords *string `json:"prompt_filter_sensitive_words"`
2140+
PromptFilterCustomPatterns *string `json:"prompt_filter_custom_patterns"`
2141+
PromptFilterDisabledPatterns *string `json:"prompt_filter_disabled_patterns"`
21192142
}
21202143

21212144
// GetSettings 获取当前系统设置
@@ -2133,6 +2156,7 @@ func (h *Handler) GetSettings(c *gin.Context) {
21332156
resinURL = dbSettings.ResinURL
21342157
resinPlatformName = dbSettings.ResinPlatformName
21352158
}
2159+
promptFilterCfg := h.store.GetPromptFilterConfig()
21362160
c.JSON(http.StatusOK, settingsResponse{
21372161
MaxConcurrency: h.store.GetMaxConcurrency(),
21382162
GlobalRPM: h.rateLimiter.GetRPM(),
@@ -2162,6 +2186,15 @@ func (h *Handler) GetSettings(c *gin.Context) {
21622186
ModelMapping: h.store.GetModelMapping(),
21632187
ResinURL: resinURL,
21642188
ResinPlatformName: resinPlatformName,
2189+
PromptFilterEnabled: promptFilterCfg.Enabled,
2190+
PromptFilterMode: promptFilterCfg.Mode,
2191+
PromptFilterThreshold: promptFilterCfg.Threshold,
2192+
PromptFilterStrictThreshold: promptFilterCfg.StrictThreshold,
2193+
PromptFilterLogMatches: promptFilterCfg.LogMatches,
2194+
PromptFilterMaxTextLength: promptFilterCfg.MaxTextLength,
2195+
PromptFilterSensitiveWords: promptFilterCfg.SensitiveWords,
2196+
PromptFilterCustomPatterns: promptfilter.MarshalCustomPatterns(promptFilterCfg.CustomPatterns),
2197+
PromptFilterDisabledPatterns: promptfilter.MarshalDisabledPatterns(promptFilterCfg.DisabledPatterns),
21652198
})
21662199
}
21672200

@@ -2363,6 +2396,64 @@ func (h *Handler) UpdateSettings(c *gin.Context) {
23632396
log.Printf("设置已更新: model_mapping")
23642397
}
23652398

2399+
promptFilterCfg := h.store.GetPromptFilterConfig()
2400+
promptFilterChanged := false
2401+
if req.PromptFilterEnabled != nil {
2402+
promptFilterCfg.Enabled = *req.PromptFilterEnabled
2403+
promptFilterChanged = true
2404+
}
2405+
if req.PromptFilterMode != nil {
2406+
promptFilterCfg.Mode = *req.PromptFilterMode
2407+
promptFilterChanged = true
2408+
}
2409+
if req.PromptFilterThreshold != nil {
2410+
promptFilterCfg.Threshold = *req.PromptFilterThreshold
2411+
promptFilterChanged = true
2412+
}
2413+
if req.PromptFilterStrictThreshold != nil {
2414+
promptFilterCfg.StrictThreshold = *req.PromptFilterStrictThreshold
2415+
promptFilterChanged = true
2416+
}
2417+
if req.PromptFilterLogMatches != nil {
2418+
promptFilterCfg.LogMatches = *req.PromptFilterLogMatches
2419+
promptFilterChanged = true
2420+
}
2421+
if req.PromptFilterMaxTextLength != nil {
2422+
promptFilterCfg.MaxTextLength = *req.PromptFilterMaxTextLength
2423+
promptFilterChanged = true
2424+
}
2425+
if req.PromptFilterSensitiveWords != nil {
2426+
promptFilterCfg.SensitiveWords = *req.PromptFilterSensitiveWords
2427+
promptFilterChanged = true
2428+
}
2429+
if req.PromptFilterCustomPatterns != nil {
2430+
patterns, err := promptfilter.ParseCustomPatterns(*req.PromptFilterCustomPatterns)
2431+
if err != nil {
2432+
writeError(c, http.StatusBadRequest, "Prompt 检查自定义规则 JSON 无效: "+err.Error())
2433+
return
2434+
}
2435+
promptFilterCfg.CustomPatterns = patterns
2436+
promptFilterChanged = true
2437+
}
2438+
if req.PromptFilterDisabledPatterns != nil {
2439+
disabled, err := promptfilter.ParseDisabledPatterns(*req.PromptFilterDisabledPatterns)
2440+
if err != nil {
2441+
writeError(c, http.StatusBadRequest, "Prompt 检查禁用规则 JSON 无效: "+err.Error())
2442+
return
2443+
}
2444+
promptFilterCfg.DisabledPatterns = disabled
2445+
promptFilterChanged = true
2446+
}
2447+
if promptFilterChanged {
2448+
promptFilterCfg = promptfilter.NormalizeConfig(promptFilterCfg)
2449+
if _, err := promptfilter.NewEngine(promptFilterCfg); err != nil {
2450+
writeError(c, http.StatusBadRequest, "Prompt 检查规则无效: "+err.Error())
2451+
return
2452+
}
2453+
h.store.SetPromptFilterConfig(promptFilterCfg)
2454+
log.Printf("设置已更新: prompt_filter enabled=%t mode=%s threshold=%d", promptFilterCfg.Enabled, promptFilterCfg.Mode, promptFilterCfg.Threshold)
2455+
}
2456+
23662457
// Resin 粘性代理池配置
23672458
resinURL := ""
23682459
resinPlatformName := ""
@@ -2417,6 +2508,15 @@ func (h *Handler) UpdateSettings(c *gin.Context) {
24172508
ModelMapping: h.store.GetModelMapping(),
24182509
ResinURL: resinURL,
24192510
ResinPlatformName: resinPlatformName,
2511+
PromptFilterEnabled: promptFilterCfg.Enabled,
2512+
PromptFilterMode: promptFilterCfg.Mode,
2513+
PromptFilterThreshold: promptFilterCfg.Threshold,
2514+
PromptFilterStrictThreshold: promptFilterCfg.StrictThreshold,
2515+
PromptFilterLogMatches: promptFilterCfg.LogMatches,
2516+
PromptFilterMaxTextLength: promptFilterCfg.MaxTextLength,
2517+
PromptFilterSensitiveWords: promptFilterCfg.SensitiveWords,
2518+
PromptFilterCustomPatterns: promptfilter.MarshalCustomPatterns(promptFilterCfg.CustomPatterns),
2519+
PromptFilterDisabledPatterns: promptfilter.MarshalDisabledPatterns(promptFilterCfg.DisabledPatterns),
24202520
})
24212521
if err != nil {
24222522
log.Printf("无法持久化保存设置: %v", err)
@@ -2465,6 +2565,15 @@ func (h *Handler) UpdateSettings(c *gin.Context) {
24652565
ModelMapping: h.store.GetModelMapping(),
24662566
ResinURL: resinURL,
24672567
ResinPlatformName: resinPlatformName,
2568+
PromptFilterEnabled: promptFilterCfg.Enabled,
2569+
PromptFilterMode: promptFilterCfg.Mode,
2570+
PromptFilterThreshold: promptFilterCfg.Threshold,
2571+
PromptFilterStrictThreshold: promptFilterCfg.StrictThreshold,
2572+
PromptFilterLogMatches: promptFilterCfg.LogMatches,
2573+
PromptFilterMaxTextLength: promptFilterCfg.MaxTextLength,
2574+
PromptFilterSensitiveWords: promptFilterCfg.SensitiveWords,
2575+
PromptFilterCustomPatterns: promptfilter.MarshalCustomPatterns(promptFilterCfg.CustomPatterns),
2576+
PromptFilterDisabledPatterns: promptfilter.MarshalDisabledPatterns(promptFilterCfg.DisabledPatterns),
24682577
})
24692578
}
24702579

admin/image_studio.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,9 @@ func (h *Handler) CreateImageGenerationJob(c *gin.Context) {
277277
}
278278
paramsJSON, _ := json.Marshal(req)
279279
keyID, keyName, keyMasked := imageJobAPIKeyMeta(apiKey)
280+
if h.inspectImageStudioPromptFilter(c, proxy.AppendImageStyleToPrompt(req.Prompt, req.Style), req.Model, keyID, keyName, keyMasked) {
281+
return
282+
}
280283
jobID, err := h.db.InsertImageGenerationJob(ctx, database.ImageGenerationJobInput{
281284
Prompt: req.Prompt,
282285
ParamsJSON: string(paramsJSON),

admin/prompt_filter.go

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
package admin
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"strconv"
7+
"strings"
8+
"time"
9+
10+
"github.com/codex2api/database"
11+
"github.com/codex2api/security/promptfilter"
12+
"github.com/gin-gonic/gin"
13+
)
14+
15+
type promptFilterLogsResponse struct {
16+
Logs []*database.PromptFilterLog `json:"logs"`
17+
Total int `json:"total"`
18+
Page int `json:"page"`
19+
PageSize int `json:"page_size"`
20+
}
21+
22+
type promptFilterTestRequest struct {
23+
Text string `json:"text"`
24+
Endpoint string `json:"endpoint"`
25+
Model string `json:"model"`
26+
}
27+
28+
type promptFilterTestResponse struct {
29+
Verdict promptfilter.Verdict `json:"verdict"`
30+
}
31+
32+
type promptFilterRuleItem struct {
33+
Name string `json:"name"`
34+
Pattern string `json:"pattern"`
35+
Weight int `json:"weight"`
36+
Category string `json:"category,omitempty"`
37+
Strict bool `json:"strict,omitempty"`
38+
Enabled bool `json:"enabled"`
39+
Builtin bool `json:"builtin"`
40+
}
41+
42+
type promptFilterRulesResponse struct {
43+
BuiltinPatterns []promptFilterRuleItem `json:"builtin_patterns"`
44+
CustomPatterns []promptfilter.PatternConfig `json:"custom_patterns"`
45+
DisabledPatterns []string `json:"disabled_patterns"`
46+
}
47+
48+
func (h *Handler) inspectImageStudioPromptFilter(c *gin.Context, text string, model string, keyID int64, keyName string, keyMasked string) bool {
49+
if h == nil || h.store == nil {
50+
return false
51+
}
52+
cfg := h.store.GetPromptFilterConfig()
53+
verdict := promptfilter.InspectText(text, cfg)
54+
if verdict.Action == promptfilter.ActionWarn {
55+
c.Header("X-Prompt-Filter-Warning", verdict.Reason)
56+
return false
57+
}
58+
if verdict.Action != promptfilter.ActionBlock {
59+
return false
60+
}
61+
h.recordPromptFilterLog(c, &database.PromptFilterLogInput{
62+
Source: "local_filter",
63+
Endpoint: "/api/admin/images/jobs",
64+
Model: model,
65+
Action: verdict.Action,
66+
Mode: verdict.Mode,
67+
Score: verdict.Score,
68+
Threshold: verdict.Threshold,
69+
MatchedPatterns: promptfilter.MatchesJSON(verdict.Matched),
70+
TextPreview: verdict.TextPreview,
71+
APIKeyID: keyID,
72+
APIKeyName: keyName,
73+
APIKeyMasked: keyMasked,
74+
ClientIP: c.ClientIP(),
75+
})
76+
writeError(c, http.StatusBadRequest, "Prompt 被检查规则拦截")
77+
return true
78+
}
79+
80+
func (h *Handler) recordPromptFilterLog(c *gin.Context, input *database.PromptFilterLogInput) {
81+
if h == nil || h.db == nil || input == nil {
82+
return
83+
}
84+
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
85+
defer cancel()
86+
_ = h.db.InsertPromptFilterLog(ctx, input)
87+
}
88+
89+
func (h *Handler) ListPromptFilterLogs(c *gin.Context) {
90+
page := positiveQueryInt(c, "page", 1)
91+
pageSize := positiveQueryInt(c, "page_size", positiveQueryInt(c, "limit", 100))
92+
apiKeyID := int64(0)
93+
if raw := strings.TrimSpace(c.Query("api_key_id")); raw != "" {
94+
if parsed, err := strconv.ParseInt(raw, 10, 64); err == nil && parsed > 0 {
95+
apiKeyID = parsed
96+
}
97+
}
98+
ctx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Second)
99+
defer cancel()
100+
logs, total, err := h.db.ListPromptFilterLogsPage(ctx, database.PromptFilterLogQuery{
101+
Page: page,
102+
PageSize: pageSize,
103+
Source: c.Query("source"),
104+
Action: c.Query("action"),
105+
Endpoint: c.Query("endpoint"),
106+
Model: c.Query("model"),
107+
APIKeyID: apiKeyID,
108+
Query: c.Query("q"),
109+
})
110+
if err != nil {
111+
writeInternalError(c, err)
112+
return
113+
}
114+
if logs == nil {
115+
logs = []*database.PromptFilterLog{}
116+
}
117+
c.JSON(http.StatusOK, promptFilterLogsResponse{Logs: logs, Total: total, Page: page, PageSize: pageSize})
118+
}
119+
120+
func (h *Handler) ClearPromptFilterLogs(c *gin.Context) {
121+
ctx, cancel := context.WithTimeout(c.Request.Context(), 10*time.Second)
122+
defer cancel()
123+
if err := h.db.ClearPromptFilterLogs(ctx); err != nil {
124+
writeInternalError(c, err)
125+
return
126+
}
127+
writeMessage(c, http.StatusOK, "Prompt 检查日志已清空")
128+
}
129+
130+
func (h *Handler) TestPromptFilter(c *gin.Context) {
131+
var req promptFilterTestRequest
132+
if err := c.ShouldBindJSON(&req); err != nil {
133+
writeError(c, http.StatusBadRequest, "请求体无效")
134+
return
135+
}
136+
req.Text = strings.TrimSpace(req.Text)
137+
if req.Text == "" {
138+
writeError(c, http.StatusBadRequest, "text 不能为空")
139+
return
140+
}
141+
if len([]rune(req.Text)) > 20000 {
142+
writeError(c, http.StatusBadRequest, "text 不能超过 20000 个字符")
143+
return
144+
}
145+
cfg := h.store.GetPromptFilterConfig()
146+
cfg.Enabled = true
147+
verdict := promptfilter.InspectText(req.Text, cfg)
148+
c.JSON(http.StatusOK, promptFilterTestResponse{Verdict: verdict})
149+
}
150+
151+
func (h *Handler) GetPromptFilterRules(c *gin.Context) {
152+
cfg := h.store.GetPromptFilterConfig()
153+
disabled := map[string]bool{}
154+
for _, name := range cfg.DisabledPatterns {
155+
disabled[strings.ToLower(strings.TrimSpace(name))] = true
156+
}
157+
builtin := promptfilter.BuiltinPatternConfigs()
158+
items := make([]promptFilterRuleItem, 0, len(builtin))
159+
for _, pattern := range builtin {
160+
items = append(items, promptFilterRuleItem{
161+
Name: pattern.Name,
162+
Pattern: pattern.Pattern,
163+
Weight: pattern.Weight,
164+
Category: pattern.Category,
165+
Strict: pattern.Strict,
166+
Enabled: !disabled[strings.ToLower(strings.TrimSpace(pattern.Name))],
167+
Builtin: true,
168+
})
169+
}
170+
c.JSON(http.StatusOK, promptFilterRulesResponse{
171+
BuiltinPatterns: items,
172+
CustomPatterns: cfg.CustomPatterns,
173+
DisabledPatterns: cfg.DisabledPatterns,
174+
})
175+
}
176+
177+
func positiveQueryInt(c *gin.Context, key string, fallback int) int {
178+
raw := strings.TrimSpace(c.Query(key))
179+
if raw == "" {
180+
return fallback
181+
}
182+
parsed, err := strconv.Atoi(raw)
183+
if err != nil || parsed <= 0 {
184+
return fallback
185+
}
186+
return parsed
187+
}

0 commit comments

Comments
 (0)