@@ -2,14 +2,160 @@ package inspect
22
33import (
44 "context"
5+ "fmt"
6+ "log/slog"
57 "regexp"
68 "strings"
79 "sync"
10+ "time"
811
912 "github.com/GrayCodeAI/inspect/internal/check"
1013 "github.com/GrayCodeAI/inspect/internal/crawler"
1114)
1215
// Time budgets applied to rule-supplied regular expressions as part of the
// ReDoS protection.
const (
	// regexCompileTimeout is how long compileWithTimeout waits for a pattern
	// to compile before giving up.
	regexCompileTimeout = time.Second

	// regexMatchTimeout is how long matchWithTimeout and findWithTimeout wait
	// for a single match attempt before giving up.
	regexMatchTimeout = 100 * time.Millisecond
)
21+
// checkRegexComplexity is a conservative, syntax-only screen for patterns that
// look like they could cause ReDoS. It returns a non-nil error when:
//
//   - groups are nested more than maxDepth (5) levels deep, or
//   - a quantifier (*, +, ?, or a '{' repetition) directly follows the ')' of
//     a group that itself contains a quantifier, either directly or inside a
//     nested group, e.g. (a+)+, (a*)*, ((a+))+.
//
// Escape sequences (\+, \(, \p{Greek}, \x{41}, ...) and bracketed character
// classes ([+*(], [[:alpha:]], ...) are skipped as single literal atoms, so
// metacharacters inside them are never counted as groups or quantifiers.
//
// The scan does not validate the pattern; syntax errors are left for
// regexp.Compile to report. A literal '{' that is not a valid repetition is
// still counted as a quantifier, which keeps the check on the strict side.
func checkRegexComplexity(pattern string) error {
	const maxDepth = 5

	// quantified[k] reports whether the k-th currently open group contains a
	// quantifier, directly or through an already closed nested group.
	var quantified []bool
	// afterQuantifiedGroup is true only immediately after the ')' of a group
	// that contains a quantifier; a quantifier seen in that state is nested.
	afterQuantifiedGroup := false
	// prev is the previous unescaped byte, or 0 at the start of the pattern
	// and after an escape sequence or character class (both plain atoms).
	var prev byte

	for i := 0; i < len(pattern); i++ {
		ch := pattern[i]
		isQuantifier := false

		switch ch {
		case '\\':
			// Jump over the whole escape so that \+, \( etc. are not
			// mistaken for operators.
			i = skipEscape(pattern, i)
			ch = 0
		case '[':
			// Everything up to the closing ']' is a literal set member.
			i = skipCharClass(pattern, i)
			ch = 0
		case '(':
			if len(quantified) >= maxDepth {
				return fmt.Errorf("regex group nesting depth exceeds maximum %d", maxDepth)
			}
			quantified = append(quantified, false)
		case ')':
			if n := len(quantified); n > 0 {
				inner := quantified[n-1]
				quantified = quantified[:n-1]
				// A quantifier inside a nested group is also inside the
				// enclosing group; without this, ((a+))+ slips through.
				if inner && n > 1 {
					quantified[n-2] = true
				}
				afterQuantifiedGroup = inner
				prev = ch
				continue
			}
			// Unbalanced ')': nothing to pop, treat it as a plain atom.
		case '*', '+', '{':
			isQuantifier = true
		case '?':
			// '?' right after '(' or '|' is a group modifier, and right after
			// another quantifier it is the non-greedy marker; in every other
			// position (except the very start) it is the 0-or-1 quantifier.
			isQuantifier = i > 0 && prev != '(' && prev != '|' && prev != '*' && prev != '+' && prev != '?'
		}

		if isQuantifier {
			if n := len(quantified); n > 0 {
				quantified[n-1] = true
			}
			if afterQuantifiedGroup {
				return fmt.Errorf("nested quantifier detected near position %d: quantifier after group containing a quantifier (pattern may cause ReDoS)", i)
			}
		}
		afterQuantifiedGroup = false
		prev = ch
	}
	return nil
}

// skipEscape returns the index of the last byte of the escape sequence that
// starts with the backslash at pattern[i]. \p{..}, \P{..} and \x{..} are
// consumed through their closing brace; any other escape is two bytes long.
func skipEscape(pattern string, i int) int {
	next := i + 1
	if next >= len(pattern) {
		return i // trailing backslash; regexp.Compile will reject it
	}
	if c := pattern[next]; (c == 'p' || c == 'P' || c == 'x') && next+1 < len(pattern) && pattern[next+1] == '{' {
		if end := strings.IndexByte(pattern[next+1:], '}'); end >= 0 {
			return next + 1 + end
		}
	}
	return next
}

// skipCharClass returns the index of the ']' that closes the character class
// opened by the '[' at pattern[i], or the last index of pattern when the class
// is unterminated. A ']' placed first (optionally after '^'), escaped bytes,
// and POSIX classes such as [:alpha:] do not close the class.
func skipCharClass(pattern string, i int) int {
	j := i + 1
	if j < len(pattern) && pattern[j] == '^' {
		j++
	}
	if j < len(pattern) && pattern[j] == ']' {
		j++ // a leading ']' is a literal member of the set
	}
	for j < len(pattern) {
		switch {
		case pattern[j] == '\\':
			j = skipEscape(pattern, j)
		case pattern[j] == '[' && j+1 < len(pattern) && pattern[j+1] == ':':
			if end := strings.Index(pattern[j+2:], ":]"); end >= 0 {
				j += 2 + end + 1 // land on the ']' that ends [:name:]
			}
		case pattern[j] == ']':
			return j
		}
		j++
	}
	return len(pattern) - 1
}
92+
93+ // compileWithTimeout compiles a regex pattern with a timeout to protect against
94+ // pathological compilation times. Returns nil and an error if the pattern is
95+ // rejected by the complexity check or if compilation times out.
96+ func compileWithTimeout (pattern string ) (* regexp.Regexp , error ) {
97+ if err := checkRegexComplexity (pattern ); err != nil {
98+ return nil , err
99+ }
100+
101+ type result struct {
102+ re * regexp.Regexp
103+ err error
104+ }
105+ done := make (chan result , 1 )
106+ go func () {
107+ re , err := regexp .Compile (pattern )
108+ done <- result {re , err }
109+ }()
110+
111+ select {
112+ case res := <- done :
113+ return res .re , res .err
114+ case <- time .After (regexCompileTimeout ):
115+ return nil , fmt .Errorf ("regex compilation timed out after %s" , regexCompileTimeout )
116+ }
117+ }
118+
119+ // matchWithTimeout runs re.MatchString(s) with a timeout. Returns false if
120+ // the match does not complete in time, protecting against ReDoS at runtime.
121+ func matchWithTimeout (re * regexp.Regexp , s string ) bool {
122+ type result struct {
123+ matched bool
124+ }
125+ done := make (chan result , 1 )
126+ go func () {
127+ done <- result {matched : re .MatchString (s )}
128+ }()
129+
130+ select {
131+ case res := <- done :
132+ return res .matched
133+ case <- time .After (regexMatchTimeout ):
134+ slog .Warn ("regex match timed out, skipping" , "pattern" , re .String (), "timeout" , regexMatchTimeout )
135+ return false
136+ }
137+ }
138+
139+ // findWithTimeout runs re.FindString(s) with a timeout. Returns "" if
140+ // the match does not complete in time.
141+ func findWithTimeout (re * regexp.Regexp , s string ) string {
142+ type result struct {
143+ match string
144+ }
145+ done := make (chan result , 1 )
146+ go func () {
147+ done <- result {match : re .FindString (s )}
148+ }()
149+
150+ select {
151+ case res := <- done :
152+ return res .match
153+ case <- time .After (regexMatchTimeout ):
154+ slog .Warn ("regex find timed out, skipping" , "pattern" , re .String (), "timeout" , regexMatchTimeout )
155+ return ""
156+ }
157+ }
158+
13159// Checker is the public interface for custom checks. Implement this to add
14160// your own audit logic and register it with RegisterCheck.
15161type Checker interface {
@@ -95,23 +241,36 @@ type ruleCheckAdapter struct {
95241func newRuleCheckAdapter (rule RuleCheck ) * ruleCheckAdapter {
96242 a := & ruleCheckAdapter {rule : rule }
97243 if rule .URLMatch != "" {
98- a .urlRegex , _ = regexp .Compile (rule .URLMatch )
244+ re , err := compileWithTimeout (rule .URLMatch )
245+ if err != nil {
246+ slog .Error ("rule regex compilation failed (ReDoS protection)" , "rule" , rule .RuleName , "field" , "URLMatch" , "pattern" , rule .URLMatch , "error" , err )
247+ }
248+ a .urlRegex = re
99249 }
100250 a .headerRegexs = make (map [string ]* regexp.Regexp , len (rule .HeaderMatch ))
101251 for header , pattern := range rule .HeaderMatch {
102- if re , err := regexp .Compile (pattern ); err == nil {
103- a .headerRegexs [header ] = re
252+ re , err := compileWithTimeout (pattern )
253+ if err != nil {
254+ slog .Error ("rule regex compilation failed (ReDoS protection)" , "rule" , rule .RuleName , "field" , "HeaderMatch" , "header" , header , "pattern" , pattern , "error" , err )
255+ continue
104256 }
257+ a .headerRegexs [header ] = re
105258 }
106259 for _ , pattern := range rule .BodyMatch {
107- if re , err := regexp .Compile (pattern ); err == nil {
108- a .bodyRegexs = append (a .bodyRegexs , re )
260+ re , err := compileWithTimeout (pattern )
261+ if err != nil {
262+ slog .Error ("rule regex compilation failed (ReDoS protection)" , "rule" , rule .RuleName , "field" , "BodyMatch" , "pattern" , pattern , "error" , err )
263+ continue
109264 }
265+ a .bodyRegexs = append (a .bodyRegexs , re )
110266 }
111267 for _ , pattern := range rule .BodyMissing {
112- if re , err := regexp .Compile (pattern ); err == nil {
113- a .bodyMissing = append (a .bodyMissing , re )
268+ re , err := compileWithTimeout (pattern )
269+ if err != nil {
270+ slog .Error ("rule regex compilation failed (ReDoS protection)" , "rule" , rule .RuleName , "field" , "BodyMissing" , "pattern" , pattern , "error" , err )
271+ continue
114272 }
273+ a .bodyMissing = append (a .bodyMissing , re )
115274 }
116275 return a
117276}
@@ -128,7 +287,7 @@ func (r *ruleCheckAdapter) Run(ctx context.Context, pages []*crawler.Page) []che
128287 if len (r .rule .StatusCodes ) > 0 && ! intIn (page .StatusCode , r .rule .StatusCodes ) {
129288 continue
130289 }
131- if r .urlRegex != nil && ! r .urlRegex . MatchString ( page .URL ) {
290+ if r .urlRegex != nil && ! matchWithTimeout ( r .urlRegex , page .URL ) {
132291 continue
133292 }
134293
@@ -150,7 +309,7 @@ func (r *ruleCheckAdapter) Run(ctx context.Context, pages []*crawler.Page) []che
150309 if val == "" {
151310 continue
152311 }
153- if re . MatchString ( val ) {
312+ if matchWithTimeout ( re , val ) {
154313 findings = append (findings , check.Finding {
155314 Severity : check .Severity (r .rule .RuleSeverity ),
156315 URL : page .URL ,
@@ -165,7 +324,7 @@ func (r *ruleCheckAdapter) Run(ctx context.Context, pages []*crawler.Page) []che
165324
166325 // Body match checks (regex match = bad)
167326 for _ , re := range r .bodyRegexs {
168- if loc := re . FindString ( body ); loc != "" {
327+ if loc := findWithTimeout ( re , body ); loc != "" {
169328 findings = append (findings , check.Finding {
170329 Severity : check .Severity (r .rule .RuleSeverity ),
171330 URL : page .URL ,
@@ -179,7 +338,7 @@ func (r *ruleCheckAdapter) Run(ctx context.Context, pages []*crawler.Page) []che
179338
180339 // Body missing checks (pattern should be present but isn't)
181340 for _ , re := range r .bodyMissing {
182- if ! re . MatchString ( body ) {
341+ if ! matchWithTimeout ( re , body ) {
183342 findings = append (findings , check.Finding {
184343 Severity : check .Severity (r .rule .RuleSeverity ),
185344 URL : page .URL ,
@@ -237,20 +396,28 @@ func (a *customCheckAdapter) Run(ctx context.Context, pages []*crawler.Page) []c
237396 return internal
238397}
239398
240- func getCustomInternalChecks () []check. Checker {
241- customChecksMu . RLock ()
242- defer customChecksMu . RUnlock ()
243-
399+ // getCustomInternalChecks converts public Checker and RuleCheck slices into
400+ // internal check.Checker adapters. This accepts explicit slices so that
401+ // per-Scanner custom checks can be passed in without relying on global state.
402+ func getCustomInternalChecks ( checks [] Checker , rules [] RuleCheck ) []check. Checker {
244403 var result []check.Checker
245- for _ , c := range customChecks {
404+ for _ , c := range checks {
246405 result = append (result , & customCheckAdapter {checker : c })
247406 }
248- for _ , r := range customRules {
407+ for _ , r := range rules {
249408 result = append (result , newRuleCheckAdapter (r ))
250409 }
251410 return result
252411}
253412
413+ // getGlobalCustomInternalChecks returns checks registered via the global
414+ // RegisterCheck/RegisterRule functions. Kept for backward compatibility.
415+ func getGlobalCustomInternalChecks () []check.Checker {
416+ customChecksMu .RLock ()
417+ defer customChecksMu .RUnlock ()
418+ return getCustomInternalChecks (customChecks , customRules )
419+ }
420+
254421func intIn (n int , list []int ) bool {
255422 for _ , x := range list {
256423 if n == x {
@@ -262,8 +429,9 @@ func intIn(n int, list []int) bool {
262429
// truncateEvidence flattens s onto one line by replacing every newline with a
// space, then limits it to at most max runes. When the flattened text is
// longer than max runes, the first max runes are returned followed by "...";
// otherwise the flattened text is returned unchanged.
//
// The cut is made on a rune boundary so multi-byte UTF-8 characters are never
// split. A negative max is treated as 0 instead of panicking.
func truncateEvidence(s string, max int) string {
	s = strings.ReplaceAll(s, "\n", " ")
	if max < 0 {
		max = 0
	}

	// Walk rune start offsets rather than converting the whole string to
	// []rune: this stops as soon as the limit is reached and allocates nothing
	// when no truncation is needed.
	seen := 0
	for offset := range s {
		if seen == max {
			return s[:offset] + "..."
		}
		seen++
	}
	return s
}
0 commit comments