Skip to content

Commit 90df2ab

Browse files
committed
fix hook matcher validation and preview sanitization
1 parent aea182a commit 90df2ab

5 files changed

Lines changed: 209 additions & 6 deletions

File tree

docs/runtime-hooks-design.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,7 @@ runtime 内置 `HookPointCapability` 作为唯一真源,定义每个点位是
129129
说明:
130130

131131
- `arguments_contains` 基于 `tool_arguments_preview` 字段匹配,不读取 `tool_arguments` 原文。
132-
- `warn_on_tool_call` 的旧参数 `params.tool_name/tool_names` 仍兼容;未配置 `match` 时会自动桥接为 matcher。
133-
-`match` 与旧参数共存,以 `match` 为准,并发出 `hook_notification` 迁移提示事件。
132+
- `warn_on_tool_call` 当前要求显式配置 `match`;旧参数 `params.tool_name/tool_names` 不再承担匹配语义。
134133

135134
### trust gate
136135

internal/runtime/hooks/matcher.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ func CompileHookMatcher(point HookPoint, raw map[string]any) (*HookMatcher, erro
5454
if len(raw) == 0 {
5555
return nil, nil
5656
}
57+
if err := validateHookMatcherFields(raw); err != nil {
58+
return nil, err
59+
}
5760
if !HasHookMatcherConfig(raw) {
5861
return nil, fmt.Errorf("match contains no recognized matcher fields (expected: tool_name, tool_name_regex, arguments_contains)")
5962
}
@@ -104,6 +107,26 @@ func CompileHookMatcher(point HookPoint, raw map[string]any) (*HookMatcher, erro
104107
return matcher, nil
105108
}
106109

110+
// validateHookMatcherFields 校验 matcher 配置中不存在未支持字段,避免拼写错误被静默忽略。
111+
func validateHookMatcherFields(raw map[string]any) error {
112+
if len(raw) == 0 {
113+
return nil
114+
}
115+
for key := range raw {
116+
normalized := strings.ToLower(strings.TrimSpace(key))
117+
switch normalized {
118+
case hookMatcherFieldToolName, hookMatcherFieldToolNameRegex, hookMatcherFieldArgumentsContains:
119+
continue
120+
default:
121+
return fmt.Errorf(
122+
"match contains unknown field %q (allowed: tool_name, tool_name_regex, arguments_contains)",
123+
key,
124+
)
125+
}
126+
}
127+
return nil
128+
}
129+
107130
// IsEmpty 判断 matcher 是否包含可执行维度。
108131
func (m *HookMatcher) IsEmpty() bool {
109132
if m == nil {

internal/runtime/hooks/matcher_test.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,12 @@ func TestCompileHookMatcherValidation(t *testing.T) {
111111
}); err == nil {
112112
t.Fatal("expected completely unknown matcher field to be rejected")
113113
}
114+
if _, err := CompileHookMatcher(HookPointBeforeToolCall, map[string]any{
115+
"tool_name": "bash",
116+
"tool_names": []any{"filesystem"},
117+
}); err == nil {
118+
t.Fatal("expected mixed matcher fields with typo to be rejected")
119+
}
114120

115121
if _, err := CompileHookMatcher(HookPointBeforeToolCall, nil); err != nil {
116122
t.Fatal("nil raw should succeed with nil matcher")
@@ -306,8 +312,8 @@ func TestCompileHookMatcherRegexWhitespaceSkipped(t *testing.T) {
306312
t.Parallel()
307313

308314
matcher, err := CompileHookMatcher(HookPointBeforeToolCall, map[string]any{
309-
"tool_name": "bash",
310-
"tool_name_regex": []string{" ", "\t"},
315+
"tool_name": "bash",
316+
"tool_name_regex": []string{" ", "\t"},
311317
})
312318
if err != nil {
313319
t.Fatalf("CompileHookMatcher() error = %v", err)
@@ -320,7 +326,6 @@ func TestCompileHookMatcherRegexWhitespaceSkipped(t *testing.T) {
320326
}
321327
}
322328

323-
324329
func TestCompileHookMatcherRegexOnly(t *testing.T) {
325330
t.Parallel()
326331

internal/runtime/toolexec.go

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"sort"
1111
"strings"
1212
"sync"
13+
"unicode"
1314

1415
"neo-code/internal/checkpoint"
1516
providertypes "neo-code/internal/provider/types"
@@ -766,10 +767,104 @@ func buildToolArgumentsPreview(arguments string) string {
766767
if trimmed == "" {
767768
return ""
768769
}
769-
masked := hookToolArgumentsSensitivePattern.ReplaceAllString(trimmed, `$1=***`)
770+
masked := sanitizeHookToolArguments(trimmed)
770771
return truncateHookTextByChars(masked, hookToolArgumentsPreviewMaxChars)
771772
}
772773

774+
// sanitizeHookToolArguments 优先按 JSON 结构递归脱敏,非 JSON 输入回退为轻量正则脱敏。
775+
func sanitizeHookToolArguments(arguments string) string {
776+
if masked, ok := sanitizeHookToolArgumentsJSON(arguments); ok {
777+
return masked
778+
}
779+
return hookToolArgumentsSensitivePattern.ReplaceAllString(arguments, `$1=***`)
780+
}
781+
782+
// sanitizeHookToolArgumentsJSON 尝试解析 JSON 并按敏感键递归替换值。
783+
func sanitizeHookToolArgumentsJSON(arguments string) (string, bool) {
784+
var decoded any
785+
if err := json.Unmarshal([]byte(arguments), &decoded); err != nil {
786+
return "", false
787+
}
788+
sanitized := maskHookToolArgumentValue(decoded)
789+
encoded, err := json.Marshal(sanitized)
790+
if err != nil {
791+
return "", false
792+
}
793+
return string(encoded), true
794+
}
795+
796+
// maskHookToolArgumentValue 递归处理 JSON 节点,对敏感键对应的值统一替换为 "***"。
797+
func maskHookToolArgumentValue(value any) any {
798+
switch typed := value.(type) {
799+
case map[string]any:
800+
masked := make(map[string]any, len(typed))
801+
for key, item := range typed {
802+
if isSensitiveHookToolArgumentKey(key) {
803+
masked[key] = "***"
804+
continue
805+
}
806+
masked[key] = maskHookToolArgumentValue(item)
807+
}
808+
return masked
809+
case []any:
810+
masked := make([]any, len(typed))
811+
for index, item := range typed {
812+
masked[index] = maskHookToolArgumentValue(item)
813+
}
814+
return masked
815+
default:
816+
return value
817+
}
818+
}
819+
820+
// isSensitiveHookToolArgumentKey 判断参数键名是否属于敏感信息字段。
821+
func isSensitiveHookToolArgumentKey(key string) bool {
822+
tokens := tokenizeHookToolArgumentKey(key)
823+
if len(tokens) == 0 {
824+
return false
825+
}
826+
for index, token := range tokens {
827+
switch token {
828+
case "password", "passwd", "secret", "token", "auth", "authorization":
829+
return true
830+
case "apikey", "accesskey", "authtoken", "accesstoken":
831+
return true
832+
case "api", "access":
833+
if index+1 < len(tokens) && tokens[index+1] == "key" {
834+
return true
835+
}
836+
case "key":
837+
if index > 0 && (tokens[index-1] == "api" || tokens[index-1] == "access") {
838+
return true
839+
}
840+
}
841+
}
842+
return false
843+
}
844+
845+
// tokenizeHookToolArgumentKey 将参数键拆分为小写词元,兼容 snake/kebab/camelCase。
846+
func tokenizeHookToolArgumentKey(key string) []string {
847+
trimmed := strings.TrimSpace(key)
848+
if trimmed == "" {
849+
return nil
850+
}
851+
var builder strings.Builder
852+
var prev rune
853+
for _, current := range trimmed {
854+
switch {
855+
case unicode.IsLetter(current) || unicode.IsDigit(current):
856+
if unicode.IsUpper(current) && unicode.IsLower(prev) {
857+
builder.WriteByte(' ')
858+
}
859+
builder.WriteRune(unicode.ToLower(current))
860+
default:
861+
builder.WriteByte(' ')
862+
}
863+
prev = current
864+
}
865+
return strings.Fields(builder.String())
866+
}
867+
773868
// truncateHookTextByChars 按字符长度截断文本,避免 metadata 放大。
774869
func truncateHookTextByChars(text string, maxChars int) string {
775870
if maxChars <= 0 {
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
package runtime
2+
3+
import (
4+
"strings"
5+
"testing"
6+
)
7+
8+
func TestBuildToolArgumentsPreviewMaskJSONSensitiveFields(t *testing.T) {
9+
t.Parallel()
10+
11+
raw := `{"api_key":"sk-123","password":"p@ss","nested":{"secret":"abc"},"safe":"ok"}`
12+
preview := buildToolArgumentsPreview(raw)
13+
if strings.Contains(preview, "sk-123") {
14+
t.Fatalf("preview leaked api_key: %q", preview)
15+
}
16+
if strings.Contains(preview, "p@ss") {
17+
t.Fatalf("preview leaked password: %q", preview)
18+
}
19+
if strings.Contains(preview, `"secret":"abc"`) {
20+
t.Fatalf("preview leaked nested secret: %q", preview)
21+
}
22+
if !strings.Contains(preview, `"api_key":"***"`) {
23+
t.Fatalf("preview should mask api_key: %q", preview)
24+
}
25+
if !strings.Contains(preview, `"password":"***"`) {
26+
t.Fatalf("preview should mask password: %q", preview)
27+
}
28+
if !strings.Contains(preview, `"secret":"***"`) {
29+
t.Fatalf("preview should mask nested secret: %q", preview)
30+
}
31+
if !strings.Contains(preview, `"safe":"ok"`) {
32+
t.Fatalf("preview should keep non-sensitive keys: %q", preview)
33+
}
34+
}
35+
36+
func TestBuildToolArgumentsPreviewMaskNonJSONFallback(t *testing.T) {
37+
t.Parallel()
38+
39+
preview := buildToolArgumentsPreview(`token=abc password:xyz arg=ok`)
40+
if strings.Contains(preview, "abc") || strings.Contains(preview, "xyz") {
41+
t.Fatalf("preview leaked fallback credentials: %q", preview)
42+
}
43+
if !strings.Contains(preview, "token=***") {
44+
t.Fatalf("preview should mask token in fallback mode: %q", preview)
45+
}
46+
if !strings.Contains(preview, "password=***") {
47+
t.Fatalf("preview should mask password in fallback mode: %q", preview)
48+
}
49+
}
50+
51+
func TestBuildToolArgumentsPreviewTruncate(t *testing.T) {
52+
t.Parallel()
53+
54+
raw := strings.Repeat("a", hookToolArgumentsPreviewMaxChars+20)
55+
preview := buildToolArgumentsPreview(raw)
56+
if len([]rune(preview)) != hookToolArgumentsPreviewMaxChars {
57+
t.Fatalf("preview length=%d, want %d", len([]rune(preview)), hookToolArgumentsPreviewMaxChars)
58+
}
59+
}
60+
61+
func TestIsSensitiveHookToolArgumentKey(t *testing.T) {
62+
t.Parallel()
63+
64+
cases := []struct {
65+
key string
66+
want bool
67+
}{
68+
{key: "api_key", want: true},
69+
{key: "accessKey", want: true},
70+
{key: "authorization", want: true},
71+
{key: "auth_token", want: true},
72+
{key: "password", want: true},
73+
{key: "author", want: false},
74+
{key: "tool_name", want: false},
75+
}
76+
for _, tc := range cases {
77+
if got := isSensitiveHookToolArgumentKey(tc.key); got != tc.want {
78+
t.Fatalf("isSensitiveHookToolArgumentKey(%q)=%v, want %v", tc.key, got, tc.want)
79+
}
80+
}
81+
}

0 commit comments

Comments
 (0)