Skip to content

Commit 3264be9

Browse files
committed
security: comprehensive code review fixes - SSRF DNS rebinding, enable blockPrivateIPs, prevent auth header leakage, fix silent URL dropping, add MaxPages, move custom checks, parse HTML once, ReDoS protection, robots.txt parser, MCP server tests, LLM scanner tests
1 parent a05bba0 commit 3264be9

26 files changed

Lines changed: 3415 additions & 288 deletions

check.go

Lines changed: 187 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -2,14 +2,160 @@ package inspect
22

33
import (
44
"context"
5+
"fmt"
6+
"log/slog"
57
"regexp"
68
"strings"
79
"sync"
10+
"time"
811

912
"github.com/GrayCodeAI/inspect/internal/check"
1013
"github.com/GrayCodeAI/inspect/internal/crawler"
1114
)
1215

16+
// Time budgets for ReDoS protection.

// regexCompileTimeout is the longest a single pattern compilation may take
// before it is abandoned.
const regexCompileTimeout = time.Second

// regexMatchTimeout is the longest a single match or find operation may take
// before its result is discarded.
const regexMatchTimeout = 100 * time.Millisecond
21+
22+
// checkRegexComplexity rejects patterns likely to cause ReDoS. It returns an
// error when a quantifier is applied to a group that contains a quantifier at
// any nesting depth (e.g. (a+)+, (a*)*, ((a+))+), or when groups are nested
// more than maxDepth levels deep. It returns nil otherwise.
//
// The scan is lexical only. Escaped characters (\(, \+), \Q...\E literal runs
// and character classes ([+*]) are treated as plain atoms, so metacharacters
// inside them neither open groups nor count as quantifiers. Syntax errors are
// not reported here; an unmatched ')' is ignored.
func checkRegexComplexity(pattern string) error {
	const maxDepth = 5

	// tokenKind classifies the previously scanned token. It decides whether a
	// '?' is a real quantifier or only a group/laziness modifier.
	type tokenKind int
	const (
		tokStart       tokenKind = iota // nothing scanned yet
		tokGroupOpen                    // unescaped '('
		tokAlternation                  // unescaped '|'
		tokQuantifier                   // '*', '+', '{' or any '?'
		tokAtom                         // everything else, including ')'
	)

	var (
		// quantified has one entry per currently open group; an entry is true
		// once that group (or any group nested inside it) holds a quantifier.
		quantified []bool
		// afterQuantifiedGroup is true only directly after the ')' of a group
		// that holds a quantifier.
		afterQuantifiedGroup bool
		prev                 = tokStart
	)

	// quantifierAt records a quantifier at byte offset pos in the innermost
	// open group and reports an error if it repeats a quantified group.
	quantifierAt := func(pos int) error {
		if n := len(quantified); n > 0 {
			quantified[n-1] = true
		}
		if afterQuantifiedGroup {
			return fmt.Errorf("nested quantifier detected near position %d: quantifier after group containing a quantifier (pattern may cause ReDoS)", pos)
		}
		return nil
	}

	for i := 0; i < len(pattern); i++ {
		switch pattern[i] {
		case '\\':
			if strings.HasPrefix(pattern[i:], `\Q`) {
				// Literal run: everything up to \E (or end of pattern) is text.
				if end := strings.Index(pattern[i+2:], `\E`); end >= 0 {
					i += 2 + end + 1 // land on the 'E'
				} else {
					i = len(pattern)
				}
			} else {
				i++ // skip the escaped byte so it is never read as a metacharacter
			}
			afterQuantifiedGroup = false
			prev = tokAtom
		case '[':
			// Metacharacters inside a character class are literals.
			i = skipCharClass(pattern, i)
			afterQuantifiedGroup = false
			prev = tokAtom
		case '(':
			if len(quantified) >= maxDepth {
				return fmt.Errorf("regex group nesting depth exceeds maximum %d", maxDepth)
			}
			quantified = append(quantified, false)
			afterQuantifiedGroup = false
			prev = tokGroupOpen
		case ')':
			prev = tokAtom
			n := len(quantified)
			if n == 0 {
				// Unmatched ')': nothing to close.
				afterQuantifiedGroup = false
				continue
			}
			closed := quantified[n-1]
			quantified = quantified[:n-1]
			if closed && n > 1 {
				// The enclosing group now contains a quantifier too, so
				// ((a+))+ and ((a+)b)* are caught when the outer group closes.
				quantified[n-2] = true
			}
			afterQuantifiedGroup = closed
		case '*', '+', '{':
			// '{' starts a counted repetition like {n}, {n,}, {n,m}.
			if err := quantifierAt(i); err != nil {
				return err
			}
			afterQuantifiedGroup = false
			prev = tokQuantifier
		case '?':
			// '?' directly after '(' or '|' is a group modifier, and after
			// another quantifier it is a non-greedy modifier. Only a '?' that
			// follows an atom is the 0-or-1 quantifier.
			if prev == tokAtom {
				if err := quantifierAt(i); err != nil {
					return err
				}
			}
			afterQuantifiedGroup = false
			prev = tokQuantifier
		case '|':
			afterQuantifiedGroup = false
			prev = tokAlternation
		default:
			afterQuantifiedGroup = false
			prev = tokAtom
		}
	}
	return nil
}

// skipCharClass returns the index of the ']' closing the character class that
// opens at pattern[start], or len(pattern) if the class is never closed.
func skipCharClass(pattern string, start int) int {
	i := start + 1
	if i < len(pattern) && pattern[i] == '^' {
		i++
	}
	// A ']' directly after '[' or '[^' is a literal member, not the terminator.
	if i < len(pattern) && pattern[i] == ']' {
		i++
	}
	for i < len(pattern) {
		switch pattern[i] {
		case '\\':
			i++ // skip the escaped byte
		case '[':
			// POSIX class such as [:alpha:]; step over its own closing ":]".
			if strings.HasPrefix(pattern[i:], "[:") {
				if end := strings.Index(pattern[i+2:], ":]"); end >= 0 {
					i += 2 + end + 1 // land on the ']' of ":]"
				}
			}
		case ']':
			return i
		}
		i++
	}
	return len(pattern)
}
92+
93+
// compileWithTimeout compiles a regex pattern with a timeout to protect against
94+
// pathological compilation times. Returns nil and an error if the pattern is
95+
// rejected by the complexity check or if compilation times out.
96+
func compileWithTimeout(pattern string) (*regexp.Regexp, error) {
97+
if err := checkRegexComplexity(pattern); err != nil {
98+
return nil, err
99+
}
100+
101+
type result struct {
102+
re *regexp.Regexp
103+
err error
104+
}
105+
done := make(chan result, 1)
106+
go func() {
107+
re, err := regexp.Compile(pattern)
108+
done <- result{re, err}
109+
}()
110+
111+
select {
112+
case res := <-done:
113+
return res.re, res.err
114+
case <-time.After(regexCompileTimeout):
115+
return nil, fmt.Errorf("regex compilation timed out after %s", regexCompileTimeout)
116+
}
117+
}
118+
119+
// matchWithTimeout runs re.MatchString(s) with a timeout. Returns false if
120+
// the match does not complete in time, protecting against ReDoS at runtime.
121+
func matchWithTimeout(re *regexp.Regexp, s string) bool {
122+
type result struct {
123+
matched bool
124+
}
125+
done := make(chan result, 1)
126+
go func() {
127+
done <- result{matched: re.MatchString(s)}
128+
}()
129+
130+
select {
131+
case res := <-done:
132+
return res.matched
133+
case <-time.After(regexMatchTimeout):
134+
slog.Warn("regex match timed out, skipping", "pattern", re.String(), "timeout", regexMatchTimeout)
135+
return false
136+
}
137+
}
138+
139+
// findWithTimeout runs re.FindString(s) with a timeout. Returns "" if
140+
// the match does not complete in time.
141+
func findWithTimeout(re *regexp.Regexp, s string) string {
142+
type result struct {
143+
match string
144+
}
145+
done := make(chan result, 1)
146+
go func() {
147+
done <- result{match: re.FindString(s)}
148+
}()
149+
150+
select {
151+
case res := <-done:
152+
return res.match
153+
case <-time.After(regexMatchTimeout):
154+
slog.Warn("regex find timed out, skipping", "pattern", re.String(), "timeout", regexMatchTimeout)
155+
return ""
156+
}
157+
}
158+
13159
// Checker is the public interface for custom checks. Implement this to add
14160
// your own audit logic and register it with RegisterCheck.
15161
type Checker interface {
@@ -95,23 +241,36 @@ type ruleCheckAdapter struct {
95241
func newRuleCheckAdapter(rule RuleCheck) *ruleCheckAdapter {
96242
a := &ruleCheckAdapter{rule: rule}
97243
if rule.URLMatch != "" {
98-
a.urlRegex, _ = regexp.Compile(rule.URLMatch)
244+
re, err := compileWithTimeout(rule.URLMatch)
245+
if err != nil {
246+
slog.Error("rule regex compilation failed (ReDoS protection)", "rule", rule.RuleName, "field", "URLMatch", "pattern", rule.URLMatch, "error", err)
247+
}
248+
a.urlRegex = re
99249
}
100250
a.headerRegexs = make(map[string]*regexp.Regexp, len(rule.HeaderMatch))
101251
for header, pattern := range rule.HeaderMatch {
102-
if re, err := regexp.Compile(pattern); err == nil {
103-
a.headerRegexs[header] = re
252+
re, err := compileWithTimeout(pattern)
253+
if err != nil {
254+
slog.Error("rule regex compilation failed (ReDoS protection)", "rule", rule.RuleName, "field", "HeaderMatch", "header", header, "pattern", pattern, "error", err)
255+
continue
104256
}
257+
a.headerRegexs[header] = re
105258
}
106259
for _, pattern := range rule.BodyMatch {
107-
if re, err := regexp.Compile(pattern); err == nil {
108-
a.bodyRegexs = append(a.bodyRegexs, re)
260+
re, err := compileWithTimeout(pattern)
261+
if err != nil {
262+
slog.Error("rule regex compilation failed (ReDoS protection)", "rule", rule.RuleName, "field", "BodyMatch", "pattern", pattern, "error", err)
263+
continue
109264
}
265+
a.bodyRegexs = append(a.bodyRegexs, re)
110266
}
111267
for _, pattern := range rule.BodyMissing {
112-
if re, err := regexp.Compile(pattern); err == nil {
113-
a.bodyMissing = append(a.bodyMissing, re)
268+
re, err := compileWithTimeout(pattern)
269+
if err != nil {
270+
slog.Error("rule regex compilation failed (ReDoS protection)", "rule", rule.RuleName, "field", "BodyMissing", "pattern", pattern, "error", err)
271+
continue
114272
}
273+
a.bodyMissing = append(a.bodyMissing, re)
115274
}
116275
return a
117276
}
@@ -128,7 +287,7 @@ func (r *ruleCheckAdapter) Run(ctx context.Context, pages []*crawler.Page) []che
128287
if len(r.rule.StatusCodes) > 0 && !intIn(page.StatusCode, r.rule.StatusCodes) {
129288
continue
130289
}
131-
if r.urlRegex != nil && !r.urlRegex.MatchString(page.URL) {
290+
if r.urlRegex != nil && !matchWithTimeout(r.urlRegex, page.URL) {
132291
continue
133292
}
134293

@@ -150,7 +309,7 @@ func (r *ruleCheckAdapter) Run(ctx context.Context, pages []*crawler.Page) []che
150309
if val == "" {
151310
continue
152311
}
153-
if re.MatchString(val) {
312+
if matchWithTimeout(re, val) {
154313
findings = append(findings, check.Finding{
155314
Severity: check.Severity(r.rule.RuleSeverity),
156315
URL: page.URL,
@@ -165,7 +324,7 @@ func (r *ruleCheckAdapter) Run(ctx context.Context, pages []*crawler.Page) []che
165324

166325
// Body match checks (regex match = bad)
167326
for _, re := range r.bodyRegexs {
168-
if loc := re.FindString(body); loc != "" {
327+
if loc := findWithTimeout(re, body); loc != "" {
169328
findings = append(findings, check.Finding{
170329
Severity: check.Severity(r.rule.RuleSeverity),
171330
URL: page.URL,
@@ -179,7 +338,7 @@ func (r *ruleCheckAdapter) Run(ctx context.Context, pages []*crawler.Page) []che
179338

180339
// Body missing checks (pattern should be present but isn't)
181340
for _, re := range r.bodyMissing {
182-
if !re.MatchString(body) {
341+
if !matchWithTimeout(re, body) {
183342
findings = append(findings, check.Finding{
184343
Severity: check.Severity(r.rule.RuleSeverity),
185344
URL: page.URL,
@@ -237,20 +396,28 @@ func (a *customCheckAdapter) Run(ctx context.Context, pages []*crawler.Page) []c
237396
return internal
238397
}
239398

240-
func getCustomInternalChecks() []check.Checker {
241-
customChecksMu.RLock()
242-
defer customChecksMu.RUnlock()
243-
399+
// getCustomInternalChecks converts public Checker and RuleCheck slices into
400+
// internal check.Checker adapters. This accepts explicit slices so that
401+
// per-Scanner custom checks can be passed in without relying on global state.
402+
func getCustomInternalChecks(checks []Checker, rules []RuleCheck) []check.Checker {
244403
var result []check.Checker
245-
for _, c := range customChecks {
404+
for _, c := range checks {
246405
result = append(result, &customCheckAdapter{checker: c})
247406
}
248-
for _, r := range customRules {
407+
for _, r := range rules {
249408
result = append(result, newRuleCheckAdapter(r))
250409
}
251410
return result
252411
}
253412

413+
// getGlobalCustomInternalChecks returns checks registered via the global
414+
// RegisterCheck/RegisterRule functions. Kept for backward compatibility.
415+
func getGlobalCustomInternalChecks() []check.Checker {
416+
customChecksMu.RLock()
417+
defer customChecksMu.RUnlock()
418+
return getCustomInternalChecks(customChecks, customRules)
419+
}
420+
254421
func intIn(n int, list []int) bool {
255422
for _, x := range list {
256423
if n == x {
@@ -262,8 +429,9 @@ func intIn(n int, list []int) bool {
262429

263430
// truncateEvidence flattens s onto one line by replacing each "\n" with a
// space, then limits it to at most max runes. A string that already fits is
// returned without a suffix; otherwise the first max runes are returned
// followed by "...". A negative max is treated as 0 instead of panicking.
//
// Truncation always falls on a rune boundary, so multi-byte characters are
// never split. Each invalid UTF-8 byte counts as one rune and is copied
// through unchanged.
func truncateEvidence(s string, max int) string {
	s = strings.ReplaceAll(s, "\n", " ")
	if max < 0 {
		max = 0
	}

	// Walk rune start offsets and stop at the first rune past the limit; this
	// avoids converting the whole string to []rune just to cut a prefix.
	seen := 0
	for i := range s {
		if seen == max {
			return s[:i] + "..."
		}
		seen++
	}
	return s
}

0 commit comments

Comments
 (0)