Skip to content

Commit a1d036d

Browse files
committed
fix: add SSRF protection, SARIF output, per-check timeouts, and crawler hardening
- Add SSRF protection (scheme validation + private IP blocking) - Add SARIF 2.1.0 report format output - Add per-check timeout (30s default) - Add redirect loop detection and auth-required categorization - Fix HTML parse error handling (return errors instead of nil) - Fix rate limiter context cancellation - Cache compiled regexes in rule adapter constructor - Add Allow precedence in robots.txt, Crawl-Delay parsing - Add WithBlockPrivateIPs() option (private IPs allowed by default)
1 parent e39426a commit a1d036d

9 files changed

Lines changed: 558 additions & 74 deletions

File tree

cmd/inspect-ci/main.go

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"time"
1414

1515
"github.com/GrayCodeAI/inspect"
16+
reportpkg "github.com/GrayCodeAI/inspect/internal/report"
1617
)
1718

1819
func main() {
@@ -32,7 +33,7 @@ func main() {
3233
flag.IntVar(&depth, "depth", 5, "Maximum crawl depth")
3334
flag.StringVar(&failOn, "fail-on", "high", "Minimum severity to fail")
3435
flag.IntVar(&concurrency, "concurrency", 10, "Concurrent workers")
35-
flag.StringVar(&format, "format", "terminal", "Output format: terminal, json, junit")
36+
flag.StringVar(&format, "format", "terminal", "Output format: terminal, json, junit, sarif")
3637
flag.StringVar(&timeout, "timeout", "5m", "Scan timeout")
3738
flag.StringVar(&outputFile, "output-file", "", "Write report to file")
3839
flag.Parse()
@@ -71,6 +72,14 @@ func main() {
7172
case "json":
7273
data, _ := json.MarshalIndent(report, "", " ")
7374
output = string(data)
75+
case "sarif":
76+
rd := toReportData(report)
77+
sarif, sErr := reportpkg.FormatSARIF(rd)
78+
if sErr != nil {
79+
fmt.Fprintf(os.Stderr, "error: sarif format: %v\n", sErr)
80+
os.Exit(1)
81+
}
82+
output = sarif
7483
default:
7584
output = formatTerminal(report)
7685
}
@@ -104,6 +113,30 @@ func main() {
104113
}
105114
}
106115

116+
func toReportData(r *inspect.Report) reportpkg.ReportData {
117+
var rd reportpkg.ReportData
118+
rd.Target = r.Target
119+
rd.CrawledURLs = r.CrawledURLs
120+
rd.Duration = r.Duration
121+
rd.Stats.BySeverity = make(map[string]int)
122+
for sev, count := range r.Stats.BySeverity {
123+
rd.Stats.BySeverity[sev.String()] = count
124+
}
125+
rd.Stats.ByCheck = r.Stats.ByCheck
126+
for _, f := range r.Findings {
127+
rd.Findings = append(rd.Findings, reportpkg.Finding{
128+
Check: f.Check,
129+
Severity: reportpkg.Severity(f.Severity),
130+
URL: f.URL,
131+
Element: f.Element,
132+
Message: f.Message,
133+
Fix: f.Fix,
134+
Evidence: f.Evidence,
135+
})
136+
}
137+
return rd
138+
}
139+
107140
func formatTerminal(r *inspect.Report) string {
108141
var b strings.Builder
109142
b.WriteString(fmt.Sprintf("Inspect: %s — %d pages, %d findings\n",

internal/crawler/crawler.go

Lines changed: 175 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -30,20 +30,22 @@ type Config struct {
3030
AuthHeader string
3131
AuthValue string
3232
CookieJar http.CookieJar
33+
AllowPrivateIPs bool // When true, skip SSRF protection for private IPs
3334
}
3435

3536
// Page represents a single crawled page with its metadata.
3637
type Page struct {
37-
URL string
38-
StatusCode int
39-
Headers http.Header
40-
Body []byte
41-
Links []Link
42-
Forms []Form
43-
Depth int
44-
ParentURL string
45-
Duration time.Duration
46-
Error error
38+
URL string
39+
StatusCode int
40+
Headers http.Header
41+
Body []byte
42+
Links []Link
43+
Forms []Form
44+
Depth int
45+
ParentURL string
46+
Duration time.Duration
47+
Error error
48+
AuthRequired bool // true when server returned 401/403
4749
}
4850

4951
// Link represents a hyperlink found on a page.
@@ -147,6 +149,10 @@ func (c *Crawler) Crawl(ctx context.Context, startURL string) ([]*Page, error) {
147149

148150
if c.cfg.RespectRobots {
149151
c.robots.Fetch(ctx, c.client, origin)
152+
// Apply Crawl-Delay from robots.txt if it's slower than current rate limit
153+
if delay := c.robots.CrawlDelay(origin); delay > 0 && delay > c.limiter.interval {
154+
c.limiter.interval = delay
155+
}
150156
}
151157

152158
// Seed with sitemap URLs if available
@@ -296,48 +302,181 @@ func (c *Crawler) fetch(ctx context.Context, targetURL string, depth int, parent
296302
}
297303

298304
func (c *Crawler) doFetch(ctx context.Context, page *Page, targetURL string) error {
305+
// SSRF protection: validate URL scheme and resolved IP
306+
if err := c.validateURL(targetURL); err != nil {
307+
page.Error = err
308+
return err
309+
}
310+
299311
pageCtx, cancel := context.WithTimeout(ctx, c.cfg.PageTimeout)
300312
defer cancel()
301313

302-
req, err := http.NewRequestWithContext(pageCtx, http.MethodGet, targetURL, nil)
303-
if err != nil {
304-
page.Error = err
305-
return err
314+
// Manual redirect handling with loop detection
315+
const maxRedirects = 10
316+
visited := make(map[string]bool)
317+
currentURL := targetURL
318+
319+
for redirectCount := 0; ; redirectCount++ {
320+
if redirectCount > maxRedirects {
321+
err := fmt.Errorf("too many redirects (max %d)", maxRedirects)
322+
page.Error = err
323+
return err
324+
}
325+
if visited[currentURL] {
326+
err := fmt.Errorf("redirect loop detected at %s", currentURL)
327+
page.Error = err
328+
return err
329+
}
330+
visited[currentURL] = true
331+
332+
req, err := http.NewRequestWithContext(pageCtx, http.MethodGet, currentURL, nil)
333+
if err != nil {
334+
page.Error = err
335+
return err
336+
}
337+
338+
req.Header.Set("User-Agent", c.cfg.UserAgent)
339+
if c.cfg.AuthHeader != "" {
340+
req.Header.Set(c.cfg.AuthHeader, c.cfg.AuthValue)
341+
}
342+
343+
// Use a client that does not follow redirects automatically
344+
resp, err := c.noRedirectClient().Do(req)
345+
if err != nil {
346+
page.Error = err
347+
return err
348+
}
349+
350+
// Handle auth-required responses as findings rather than errors
351+
if resp.StatusCode == 401 || resp.StatusCode == 403 {
352+
resp.Body.Close()
353+
page.StatusCode = resp.StatusCode
354+
page.Headers = resp.Header
355+
page.Error = nil
356+
page.AuthRequired = true
357+
return nil
358+
}
359+
360+
// Handle redirects manually
361+
if resp.StatusCode >= 300 && resp.StatusCode < 400 {
362+
resp.Body.Close()
363+
loc := resp.Header.Get("Location")
364+
if loc == "" {
365+
page.StatusCode = resp.StatusCode
366+
page.Headers = resp.Header
367+
return nil
368+
}
369+
resolved := resolveURL(currentURL, loc)
370+
if resolved == "" {
371+
page.StatusCode = resp.StatusCode
372+
page.Headers = resp.Header
373+
return nil
374+
}
375+
// Validate redirect target for SSRF
376+
if err := c.validateURL(resolved); err != nil {
377+
page.Error = err
378+
return err
379+
}
380+
currentURL = resolved
381+
continue
382+
}
383+
384+
// Non-redirect response
385+
page.StatusCode = resp.StatusCode
386+
page.Headers = resp.Header
387+
page.Error = nil
388+
389+
contentType := resp.Header.Get("Content-Type")
390+
if !strings.Contains(contentType, "text/html") {
391+
resp.Body.Close()
392+
return nil
393+
}
394+
395+
body, err := io.ReadAll(io.LimitReader(resp.Body, 10*1024*1024))
396+
resp.Body.Close()
397+
if err != nil {
398+
page.Error = err
399+
return err
400+
}
401+
page.Body = body
402+
403+
page.Links = extractLinks(currentURL, body)
404+
page.Forms = extractForms(body)
405+
return nil
306406
}
407+
}
307408

308-
req.Header.Set("User-Agent", c.cfg.UserAgent)
309-
if c.cfg.AuthHeader != "" {
310-
req.Header.Set(c.cfg.AuthHeader, c.cfg.AuthValue)
409+
// noRedirectClient returns an HTTP client that does not follow redirects.
410+
func (c *Crawler) noRedirectClient() *http.Client {
411+
return &http.Client{
412+
Transport: c.client.Transport,
413+
Timeout: c.client.Timeout,
414+
Jar: c.client.Jar,
415+
CheckRedirect: func(req *http.Request, via []*http.Request) error {
416+
return http.ErrUseLastResponse
417+
},
311418
}
419+
}
312420

313-
resp, err := c.client.Do(req)
421+
// validateURL checks the URL for SSRF risks: scheme must be http/https and
422+
// resolved IP must not be in private ranges (unless AllowPrivateIPs is set).
423+
func (c *Crawler) validateURL(rawURL string) error {
424+
parsed, err := url.Parse(rawURL)
314425
if err != nil {
315-
page.Error = err
316-
return err
426+
return fmt.Errorf("invalid URL: %w", err)
317427
}
318-
defer resp.Body.Close()
319-
320-
page.StatusCode = resp.StatusCode
321-
page.Headers = resp.Header
322-
page.Error = nil
323-
324-
contentType := resp.Header.Get("Content-Type")
325-
if !strings.Contains(contentType, "text/html") {
428+
if parsed.Scheme != "http" && parsed.Scheme != "https" {
429+
return fmt.Errorf("disallowed URL scheme %q (only http/https allowed)", parsed.Scheme)
430+
}
431+
if c.cfg.AllowPrivateIPs {
326432
return nil
327433
}
328-
329-
body, err := io.ReadAll(io.LimitReader(resp.Body, 10*1024*1024))
434+
host := parsed.Hostname()
435+
ips, err := net.LookupHost(host)
330436
if err != nil {
331-
page.Error = err
332-
return err
437+
// DNS resolution failure is not an SSRF issue; let the fetch handle it
438+
return nil
439+
}
440+
for _, ipStr := range ips {
441+
ip := net.ParseIP(ipStr)
442+
if ip == nil {
443+
continue
444+
}
445+
if isPrivateIP(ip) {
446+
return fmt.Errorf("SSRF protection: resolved IP %s for host %q is in a private range", ipStr, host)
447+
}
333448
}
334-
page.Body = body
335-
336-
page.Links = extractLinks(targetURL, body)
337-
page.Forms = extractForms(body)
338449
return nil
339450
}
340451

452+
// isPrivateIP checks if an IP is in a private/loopback range.
453+
func isPrivateIP(ip net.IP) bool {
454+
privateRanges := []struct {
455+
network *net.IPNet
456+
}{
457+
{mustParseCIDR("10.0.0.0/8")},
458+
{mustParseCIDR("172.16.0.0/12")},
459+
{mustParseCIDR("192.168.0.0/16")},
460+
{mustParseCIDR("127.0.0.0/8")},
461+
{mustParseCIDR("::1/128")},
462+
{mustParseCIDR("fc00::/7")},
463+
}
464+
for _, r := range privateRanges {
465+
if r.network.Contains(ip) {
466+
return true
467+
}
468+
}
469+
return false
470+
}
471+
472+
func mustParseCIDR(s string) *net.IPNet {
473+
_, network, err := net.ParseCIDR(s)
474+
if err != nil {
475+
panic(err)
476+
}
477+
return network
478+
}
479+
341480
func isRetryable(statusCode int) bool {
342481
return statusCode == 429 || statusCode == 500 || statusCode == 502 ||
343482
statusCode == 503 || statusCode == 504 || statusCode == 0

internal/crawler/crawler_test.go

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ func TestCrawl_Basic(t *testing.T) {
3131
RateLimit: 100,
3232
UserAgent: "test-bot",
3333
FollowRedirects: 3,
34+
AllowPrivateIPs: true,
3435
})
3536

3637
pages, err := c.Crawl(context.Background(), srv.URL)
@@ -65,12 +66,13 @@ func TestCrawl_DepthLimit(t *testing.T) {
6566
defer srv.Close()
6667

6768
c := New(Config{
68-
MaxDepth: 2,
69-
Concurrency: 1,
70-
Timeout: 10 * time.Second,
71-
PageTimeout: 5 * time.Second,
72-
RateLimit: 100,
73-
UserAgent: "test-bot",
69+
MaxDepth: 2,
70+
Concurrency: 1,
71+
Timeout: 10 * time.Second,
72+
PageTimeout: 5 * time.Second,
73+
RateLimit: 100,
74+
UserAgent: "test-bot",
75+
AllowPrivateIPs: true,
7476
})
7577

7678
pages, err := c.Crawl(context.Background(), srv.URL)
@@ -99,7 +101,7 @@ func TestCrawl_ExternalLinksNotFollowed(t *testing.T) {
99101
srv := httptest.NewServer(mux)
100102
defer srv.Close()
101103

102-
c := New(Config{MaxDepth: 3, Concurrency: 1, Timeout: 10 * time.Second, PageTimeout: 5 * time.Second, RateLimit: 100, UserAgent: "test"})
104+
c := New(Config{MaxDepth: 3, Concurrency: 1, Timeout: 10 * time.Second, PageTimeout: 5 * time.Second, RateLimit: 100, UserAgent: "test", AllowPrivateIPs: true})
103105
pages, err := c.Crawl(context.Background(), srv.URL)
104106
if err != nil {
105107
t.Fatalf("Crawl failed: %v", err)
@@ -139,7 +141,7 @@ func TestCrawl_ContextCancellation(t *testing.T) {
139141
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
140142
defer cancel()
141143

142-
c := New(Config{MaxDepth: 3, Concurrency: 2, Timeout: 10 * time.Second, PageTimeout: 5 * time.Second, RateLimit: 100, UserAgent: "test"})
144+
c := New(Config{MaxDepth: 3, Concurrency: 2, Timeout: 10 * time.Second, PageTimeout: 5 * time.Second, RateLimit: 100, UserAgent: "test", AllowPrivateIPs: true})
143145
_, err := c.Crawl(ctx, srv.URL)
144146
// Should not hang — should return within timeout
145147
if err != nil {
@@ -170,15 +172,16 @@ func TestCrawl_Retry(t *testing.T) {
170172
defer srv.Close()
171173

172174
c := New(Config{
173-
MaxDepth: 1,
174-
Concurrency: 1,
175-
Timeout: 10 * time.Second,
176-
PageTimeout: 5 * time.Second,
177-
RateLimit: 100,
178-
RetryAttempts: 3,
179-
RetryDelay: 10 * time.Millisecond,
180-
UserAgent: "test",
181-
RespectRobots: true,
175+
MaxDepth: 1,
176+
Concurrency: 1,
177+
Timeout: 10 * time.Second,
178+
PageTimeout: 5 * time.Second,
179+
RateLimit: 100,
180+
RetryAttempts: 3,
181+
RetryDelay: 10 * time.Millisecond,
182+
UserAgent: "test",
183+
RespectRobots: true,
184+
AllowPrivateIPs: true,
182185
})
183186

184187
pages, err := c.Crawl(context.Background(), srv.URL)

0 commit comments

Comments
 (0)