From ce31f6d52c0dd736d6232a6f7ba444ab0c7d45a0 Mon Sep 17 00:00:00 2001 From: Lakshman Patel Date: Fri, 19 Jun 2026 20:59:51 +0530 Subject: [PATCH 1/4] test(inspect): split check_test.go into focused test files Mechanical refactor for code clarity. check_test.go (2997 LOC) split into three same-package files at top-level function boundaries with no behavior, API, or exported-symbol changes: - check_test.go (1130 LOC): shared makePage helper + core check tests - check_more_test.go (972 LOC): registry and per-check unit tests - check_extra_test.go (920 LOC): interface conformance, edge cases, helpers --- internal/check/check_extra_test.go | 920 ++++++++++++++ internal/check/check_more_test.go | 972 +++++++++++++++ internal/check/check_test.go | 1871 +--------------------------- 3 files changed, 1894 insertions(+), 1869 deletions(-) create mode 100644 internal/check/check_extra_test.go create mode 100644 internal/check/check_more_test.go diff --git a/internal/check/check_extra_test.go b/internal/check/check_extra_test.go new file mode 100644 index 0000000..a1ed36c --- /dev/null +++ b/internal/check/check_extra_test.go @@ -0,0 +1,920 @@ +package check + +import ( + "context" + "strings" + "testing" + + "github.com/GrayCodeAI/inspect/internal/crawler" +) + +// This file was split out of check_test.go for readability (mechanical move; no behavior change). + +// ===================================================================== +// Checker interface compliance tests +// ===================================================================== + +func TestCheckerInterface_AllChecksImplement(t *testing.T) { + // Verify each concrete check type satisfies the Checker interface at compile time. + // This test exercises Name() and verifies the interface is implemented. + var _ Checker = &LinksCheck{} + var _ Checker = &SecurityCheck{} + var _ Checker = &FormsCheck{} + var _ Checker = &A11yCheck{} + var _ Checker = &PerfCheck{} + var _ Checker = &SEOCheck{} + var _ Checker = &SRICheck{} + var _ Checker = &AIReadyCheck{} + var _ Checker = &ReachabilityCheck{} +} + +func TestCheckerInterface_Names(t *testing.T) { + checks := map[string]Checker{ + "links": &LinksCheck{}, + "security": &SecurityCheck{}, + "forms": &FormsCheck{}, + "a11y": &A11yCheck{}, + "perf": &PerfCheck{}, + "seo": &SEOCheck{}, + "sri": &SRICheck{}, + "aiready": &AIReadyCheck{}, + "reachability": &ReachabilityCheck{}, + } + for expected, chk := range checks { + if chk.Name() != expected { + t.Errorf("%T.Name() = %q, want %q", chk, chk.Name(), expected) + } + } +} + +func TestCheckerInterface_RunReturnsSlice(t *testing.T) { + // Verify Run() returns a non-nil slice for an empty input (no panic). + ctx := context.Background() + checks := []Checker{ + &LinksCheck{}, + &SecurityCheck{}, + &FormsCheck{}, + &A11yCheck{}, + &PerfCheck{}, + &SEOCheck{}, + &SRICheck{}, + &AIReadyCheck{}, + &ReachabilityCheck{}, + } + for _, chk := range checks { + findings := chk.Run(ctx, nil) + // findings may be nil for empty input, that's OK + _ = findings + } +} + +// ===================================================================== +// Registry comprehensive tests +// ===================================================================== + +func TestDefaultRegistry_AllNames(t *testing.T) { + r := DefaultRegistry() + expectedNames := []string{ + "links", "security", "forms", "a11y", "perf", "seo", "sri", "aiready", "reachability", + } + all := r.All() + nameSet := make(map[string]bool) + for _, c := range all { + nameSet[c.Name()] = true + } + for _, name := range expectedNames { + if !nameSet[name] { + t.Errorf("DefaultRegistry missing expected check %q", name) + } + } + if len(all) != len(expectedNames) { + t.Errorf("DefaultRegistry has %d checks, expected %d", len(all), len(expectedNames)) + } +} + +func TestRegistry_Filter_MultipleSelections(t *testing.T) { + r := DefaultRegistry() + filtered := r.Filter([]string{"perf", "seo", "sri"}) + if len(filtered) != 3 { + t.Errorf("expected 3 checks, got %d", len(filtered)) + } + names := make(map[string]bool) + for _, c := range filtered { + names[c.Name()] = true + } + for _, want := range []string{"perf", "seo", "sri"} { + if !names[want] { + t.Errorf("missing expected check %q in filtered results", want) + } + } +} + +func TestRegistry_Filter_NonExistentName(t *testing.T) { + r := DefaultRegistry() + filtered := r.Filter([]string{"nonexistent"}) + if len(filtered) != 0 { + t.Errorf("expected 0 checks for nonexistent name, got %d", len(filtered)) + } +} + +func TestRegistry_Filter_Mix(t *testing.T) { + r := DefaultRegistry() + filtered := r.Filter([]string{"links", "nonexistent", "seo"}) + if len(filtered) != 2 { + t.Errorf("expected 2 checks (links + seo), got %d", len(filtered)) + } +} + +// ===================================================================== +// LinksCheck additional tests +// ===================================================================== + +func TestLinksCheck_RelativeURLs(t *testing.T) { + page := makePage("https://example.com/page", 200, map[string]string{"Content-Type": "text/html"}, + ` + About + Parent + Sibling + `) + aboutPage := makePage("https://example.com/about", 200, map[string]string{"Content-Type": "text/html"}, + `About`) + parentPage := makePage("https://example.com/parent", 200, map[string]string{"Content-Type": "text/html"}, + `Parent`) + siblingPage := makePage("https://example.com/sibling", 200, map[string]string{"Content-Type": "text/html"}, + `Sibling`) + + chk := &LinksCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page, aboutPage, parentPage, siblingPage}) + + // All pages are reachable so no broken-link findings expected + for _, f := range findings { + if strings.Contains(f.Message, "HTTP 404") || strings.Contains(f.Message, "HTTP 500") { + t.Errorf("unexpected broken link finding: %s", f.Message) + } + } +} + +func TestLinksCheck_EmptyPage(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "") + chk := &LinksCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + // Should not panic, may or may not have findings + _ = findings +} + +func TestLinksCheck_ErrorPage(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{}, "") + page.Error = context.Canceled + + chk := &LinksCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + if len(findings) != 0 { + t.Error("should skip error pages") + } +} + +func TestLinksCheck_NoPages(t *testing.T) { + chk := &LinksCheck{} + findings := chk.Run(context.Background(), nil) + if len(findings) != 0 { + t.Errorf("expected no findings for nil pages, got %d", len(findings)) + } +} + +func TestResolveLink(t *testing.T) { + tests := []struct { + base, href, want string + }{ + {"https://example.com/page", "/about", "https://example.com/about"}, + {"https://example.com/page", "https://other.com/x", "https://other.com/x"}, + {"https://example.com/page", "", ""}, + {"https://example.com/a/b", "../c", "https://example.com/c"}, + {"https://example.com/a/b", "c", "https://example.com/a/c"}, + } + for _, tt := range tests { + got := resolveLink(tt.base, tt.href) + if got != tt.want { + t.Errorf("resolveLink(%q, %q) = %q, want %q", tt.base, tt.href, got, tt.want) + } + } +} + +// ===================================================================== +// SecurityCheck additional tests +// ===================================================================== + +func TestSecurityCheck_CSPUnsafeInline(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{ + "Content-Type": "text/html", + "Content-Security-Policy": "default-src 'self'; script-src 'unsafe-inline'", + "X-Content-Type-Options": "nosniff", + "X-Frame-Options": "DENY", + "Strict-Transport-Security": "max-age=31536000", + "Referrer-Policy": "strict-origin", + "Permissions-Policy": "camera=()", + }, `test`) + + chk := &SecurityCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "unsafe-inline") { + found = true + } + } + if !found { + t.Error("expected finding for unsafe-inline in script-src") + } +} + +func TestSecurityCheck_CSPUnsafeEval(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{ + "Content-Type": "text/html", + "Content-Security-Policy": "script-src 'unsafe-eval'", + "X-Content-Type-Options": "nosniff", + "X-Frame-Options": "DENY", + "Strict-Transport-Security": "max-age=31536000", + "Referrer-Policy": "strict-origin", + "Permissions-Policy": "camera=()", + }, `test`) + + chk := &SecurityCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "unsafe-eval") { + found = true + } + } + if !found { + t.Error("expected finding for unsafe-eval in script-src") + } +} + +func TestSecurityCheck_CSPWildcard(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{ + "Content-Type": "text/html", + "Content-Security-Policy": "default-src *", + "X-Content-Type-Options": "nosniff", + "X-Frame-Options": "DENY", + "Strict-Transport-Security": "max-age=31536000", + "Referrer-Policy": "strict-origin", + "Permissions-Policy": "camera=()", + }, `test`) + + chk := &SecurityCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "wildcard") { + found = true + } + } + if !found { + t.Error("expected finding for wildcard in CSP") + } +} + +func TestSecurityCheck_CSPMissingFrameAncestors(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{ + "Content-Type": "text/html", + "Content-Security-Policy": "default-src 'self'", + "X-Content-Type-Options": "nosniff", + "X-Frame-Options": "DENY", + "Strict-Transport-Security": "max-age=31536000", + "Referrer-Policy": "strict-origin", + "Permissions-Policy": "camera=()", + }, `test`) + + chk := &SecurityCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "frame-ancestors") { + found = true + } + } + if !found { + t.Error("expected finding for missing frame-ancestors directive") + } +} + +func TestSecurityCheck_ExposedSecrets(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{ + "Content-Type": "text/html", + "Content-Security-Policy": "default-src 'self'; frame-ancestors 'self'", + "X-Content-Type-Options": "nosniff", + "X-Frame-Options": "DENY", + "Strict-Transport-Security": "max-age=31536000", + "Referrer-Policy": "strict-origin", + "Permissions-Policy": "camera=()", + }, `var api_key = "sk-abcdefghijklmnopqrstuvwxyz1234567890"`) + + chk := &SecurityCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "secret") || strings.Contains(f.Message, "credential") { + found = true + if f.Severity != SeverityCritical { + t.Errorf("expected critical severity for exposed secrets, got %v", f.Severity) + } + } + } + if !found { + t.Error("expected finding for exposed API key/secret") + } +} + +func TestSecurityCheck_Name(t *testing.T) { + chk := &SecurityCheck{} + if chk.Name() != "security" { + t.Errorf("expected 'security', got %q", chk.Name()) + } +} + +// ===================================================================== +// A11yCheck additional tests +// ===================================================================== + +func TestA11yCheck_MissingLabels(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + `Test
+ +
`) + + chk := &A11yCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "label") { + found = true + } + } + if !found { + t.Error("expected finding for input missing label") + } +} + +func TestA11yCheck_LabelWithFor(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + `Test
+ + +
`) + + chk := &A11yCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + for _, f := range findings { + if strings.Contains(f.Message, "label") && strings.Contains(f.Element, "email") { + t.Errorf("should not flag input with associated label: %s", f.Message) + } + } +} + +func TestA11yCheck_MissingLang(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + `Test

Content

`) + + chk := &A11yCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "lang") { + found = true + } + } + if !found { + t.Error("expected finding for missing lang attribute") + } +} + +func TestA11yCheck_EmptyLinkText(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + `Test
+ +
`) + + chk := &A11yCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "no accessible text") { + found = true + } + } + if !found { + t.Error("expected finding for link with no accessible text") + } +} + +func TestA11yCheck_SkipsErrorPages(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + ``) + page.Error = context.Canceled + + chk := &A11yCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + // Advanced A11y also skips error pages, so total should be 0 + if len(findings) != 0 { + t.Error("should skip error pages") + } +} + +func TestA11yCheck_EmptyBody(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "") + + chk := &A11yCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + if len(findings) != 0 { + t.Error("should not produce findings for empty body") + } +} + +// ===================================================================== +// PerfCheck additional tests +// ===================================================================== + +func TestPerfCheck_NoPages(t *testing.T) { + chk := &PerfCheck{} + findings := chk.Run(context.Background(), nil) + if len(findings) != 0 { + t.Errorf("expected 0 findings for nil pages, got %d", len(findings)) + } +} + +func TestPerfCheck_EmptyBody(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "") + chk := &PerfCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + if len(findings) != 0 { + t.Error("should not produce findings for empty body") + } +} + +func TestPerfCheck_LargeImages(t *testing.T) { + // Image without dimensions should trigger a finding + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html", "Content-Encoding": "gzip", "Cache-Control": "max-age=3600"}, + `T`) + + chk := &PerfCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "missing width/height") { + found = true + } + } + if !found { + t.Error("expected finding for image missing dimensions") + } +} + +// ===================================================================== +// SEOCheck additional tests +// ===================================================================== + +func TestSEOCheck_MissingTitle(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + ``) + + chk := &SEOCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "missing ") || strings.Contains(f.Message, "missing <title> tag") || strings.Contains(f.Message, "Page missing <title>") { + found = true + } + } + if !found { + t.Error("expected finding for missing title tag") + } +} + +func TestSEOCheck_MissingDescription(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + `<html><head><title>Page`) + + chk := &SEOCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "Missing meta description") { + found = true + } + } + if !found { + t.Error("expected finding for missing meta description") + } +} + +func TestSEOCheck_NoPages(t *testing.T) { + chk := &SEOCheck{} + findings := chk.Run(context.Background(), nil) + if len(findings) != 0 { + t.Errorf("expected 0 findings for nil pages, got %d", len(findings)) + } +} + +// ===================================================================== +// FormsCheck additional tests +// ===================================================================== + +func TestFormsCheck_NoPages(t *testing.T) { + chk := &FormsCheck{} + findings := chk.Run(context.Background(), nil) + if len(findings) != 0 { + t.Errorf("expected 0 findings for nil pages, got %d", len(findings)) + } +} + +func TestFormsCheck_NoForms(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + `

No forms here

`) + + chk := &FormsCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + if len(findings) != 0 { + t.Errorf("expected 0 findings for page with no forms, got %d", len(findings)) + } +} + +func TestFormsCheck_ValidForm(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "") + page.Forms = []crawler.Form{ + {Action: "/submit", Method: "POST", HasCSRF: true}, + } + + chk := &FormsCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + for _, f := range findings { + if strings.Contains(f.Message, "CSRF") || strings.Contains(f.Message, "no action") { + t.Errorf("unexpected finding for valid form: %s", f.Message) + } + } +} + +// ===================================================================== +// SRICheck additional tests +// ===================================================================== + +func TestSRICheck_NoPages(t *testing.T) { + chk := &SRICheck{} + findings := chk.Run(context.Background(), nil) + if len(findings) != 0 { + t.Errorf("expected 0 findings for nil pages, got %d", len(findings)) + } +} + +// ===================================================================== +// AIReadyCheck additional tests +// ===================================================================== + +func TestAIReadyCheck_NoPages(t *testing.T) { + chk := &AIReadyCheck{} + findings := chk.Run(context.Background(), nil) + // Should return llms.txt and sitemap findings based on empty page list + if len(findings) == 0 { + t.Log("AIReadyCheck with nil pages returns findings about missing llms.txt and sitemap") + } +} + +func TestAIReadyCheck_ErrorPage(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "") + page.Error = context.Canceled + + chk := &AIReadyCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + // Error pages are skipped for per-page checks, but llms.txt/sitemap findings may still appear + for _, f := range findings { + if strings.Contains(f.Message, "markdown alternate") || strings.Contains(f.Message, "structured data") { + t.Errorf("should skip error page for per-page checks: %s", f.Message) + } + } +} + +// ===================================================================== +// ReachabilityCheck tests +// ===================================================================== + +func TestReachabilityCheck_Name(t *testing.T) { + chk := &ReachabilityCheck{} + if chk.Name() != "reachability" { + t.Errorf("expected 'reachability', got %q", chk.Name()) + } +} + +func TestReachabilityCheck_EmptyBody(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "") + + chk := &ReachabilityCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + if len(findings) != 0 { + t.Error("should not produce findings for empty body") + } +} + +func TestReachabilityCheck_ErrorPage(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + ``) + page.Error = context.Canceled + + chk := &ReachabilityCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + if len(findings) != 0 { + t.Error("should skip error pages") + } +} + +func TestReachabilityCheck_DataURI(t *testing.T) { + // data: URIs should be skipped (not checked for reachability) + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + ``) + + chk := &ReachabilityCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + for _, f := range findings { + if strings.Contains(f.Message, "data:") { + t.Error("should not check reachability of data: URIs") + } + } +} + +func TestReachabilityCheck_NoPages(t *testing.T) { + chk := &ReachabilityCheck{} + findings := chk.Run(context.Background(), nil) + if len(findings) != 0 { + t.Errorf("expected 0 findings for nil pages, got %d", len(findings)) + } +} + +func TestExtractResourceRefs(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + ` + + + + + + + `) + + refs := extractResourceRefs(page) + + typeCounts := map[string]int{} + for _, ref := range refs { + typeCounts[ref.Resource]++ + } + + if typeCounts["script"] != 1 { + t.Errorf("expected 1 script ref, got %d", typeCounts["script"]) + } + if typeCounts["stylesheet"] != 1 { + t.Errorf("expected 1 stylesheet ref, got %d", typeCounts["stylesheet"]) + } + if typeCounts["image"] != 1 { + t.Errorf("expected 1 image ref, got %d", typeCounts["image"]) + } + if typeCounts["media"] != 1 { + t.Errorf("expected 1 media ref (video), got %d", typeCounts["media"]) + } + if typeCounts["iframe"] != 1 { + t.Errorf("expected 1 iframe ref, got %d", typeCounts["iframe"]) + } +} + +func TestExtractResourceRefs_EmptyBody(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "") + refs := extractResourceRefs(page) + if len(refs) != 0 { + t.Errorf("expected 0 refs for empty body, got %d", len(refs)) + } +} + +func TestResolveURL(t *testing.T) { + tests := []struct { + base, href, want string + }{ + {"https://example.com/page", "/img/photo.jpg", "https://example.com/img/photo.jpg"}, + {"https://example.com/page", "https://cdn.example.com/lib.js", "https://cdn.example.com/lib.js"}, + {"https://example.com/page", "", ""}, + {"https://example.com/a/b", "../img.jpg", "https://example.com/img.jpg"}, + } + for _, tt := range tests { + got := resolveURL(tt.base, tt.href) + if got != tt.want { + t.Errorf("resolveURL(%q, %q) = %q, want %q", tt.base, tt.href, got, tt.want) + } + } +} + +func TestSeverityForResourceStatus(t *testing.T) { + if severityForResourceStatus(404) != SeverityHigh { + t.Error("404 should be high") + } + if severityForResourceStatus(500) != SeverityCritical { + t.Error("500 should be critical") + } + if severityForResourceStatus(403) != SeverityMedium { + t.Error("403 should be medium") + } + if severityForResourceStatus(301) != SeverityLow { + t.Error("301 should be low") + } +} + +// ===================================================================== +// Edge cases: empty pages, no HTML body, error pages +// ===================================================================== + +func TestAllChecks_EmptyPagesSlice(t *testing.T) { + ctx := context.Background() + checks := []Checker{ + &LinksCheck{}, + &SecurityCheck{}, + &FormsCheck{}, + &A11yCheck{}, + &PerfCheck{}, + &SEOCheck{}, + &SRICheck{}, + &AIReadyCheck{}, + &ReachabilityCheck{}, + } + for _, chk := range checks { + // Should not panic with nil pages + findings := chk.Run(ctx, nil) + if findings == nil { + findings = []Finding{} + } + t.Logf("%s: %d findings on nil pages", chk.Name(), len(findings)) + } +} + +func TestAllChecks_ErrorPages(t *testing.T) { + ctx := context.Background() + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + `
`) + page.Error = context.Canceled + + checks := []Checker{ + &LinksCheck{}, + &SecurityCheck{}, + &FormsCheck{}, + &A11yCheck{}, + &PerfCheck{}, + &SEOCheck{}, + &SRICheck{}, + &AIReadyCheck{}, + &ReachabilityCheck{}, + } + for _, chk := range checks { + findings := chk.Run(ctx, []*crawler.Page{page}) + // Most checks should skip error pages (links may still report status) + t.Logf("%s: %d findings on error page", chk.Name(), len(findings)) + } +} + +func TestAllChecks_NonHTMLBody(t *testing.T) { + ctx := context.Background() + page := makePage("https://example.com/data.json", 200, + map[string]string{"Content-Type": "application/json"}, + `{"key": "value"}`) + + checks := []Checker{ + &SecurityCheck{}, + &FormsCheck{}, + &PerfCheck{}, + &SEOCheck{}, + &SRICheck{}, + } + for _, chk := range checks { + findings := chk.Run(ctx, []*crawler.Page{page}) + t.Logf("%s: %d findings on JSON page", chk.Name(), len(findings)) + } +} + +func TestAllChecks_MultiplePages(t *testing.T) { + ctx := context.Background() + pages := []*crawler.Page{ + makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + `Home

Home

`), + makePage("https://example.com/about", 200, map[string]string{"Content-Type": "text/html"}, + `About

About

`), + makePage("https://example.com/contact", 200, map[string]string{"Content-Type": "text/html"}, + `Contact

Contact

`), + } + + checks := []Checker{ + &A11yCheck{}, + &PerfCheck{}, + &SEOCheck{}, + &SRICheck{}, + } + for _, chk := range checks { + findings := chk.Run(ctx, pages) + t.Logf("%s: %d findings across %d pages", chk.Name(), len(findings), len(pages)) + } +} + +// ===================================================================== +// Helper function tests +// ===================================================================== + +func TestContainsVersion(t *testing.T) { + if !containsVersion("Apache/2.4.51") { + t.Error("should detect version in Apache/2.4.51") + } + if !containsVersion("nginx/1.21.0") { + t.Error("should detect version in nginx/1.21.0") + } + if containsVersion("Apache") { + t.Error("should not detect version in plain Apache") + } + if containsVersion("") { + t.Error("should not detect version in empty string") + } +} + +func TestTruncate(t *testing.T) { + if truncate("short", 10) != "short" { + t.Error("short string should not be truncated") + } + if truncate("a long string here", 6) != "a long..." { + t.Errorf("expected truncation, got %q", truncate("a long string here", 6)) + } + if truncate("", 5) != "" { + t.Error("empty string should remain empty") + } +} + +func TestTruncateResRef(t *testing.T) { + if truncateResRef("/short.js", 80) != "/short.js" { + t.Error("short ref should not be truncated") + } + long := make([]byte, 100) + for i := range long { + long[i] = 'a' + } + got := truncateResRef(string(long), 80) + if len(got) != 83 { // 80 + "..." + t.Errorf("expected truncation to 83 chars, got %d", len(got)) + } +} + +func TestNormalizeForLookup(t *testing.T) { + tests := []struct { + input, want string + }{ + {"https://example.com/path/", "https://example.com/path"}, + {"https://example.com/path#frag", "https://example.com/path"}, + {"https://example.com", "https://example.com"}, + } + for _, tt := range tests { + got := normalizeForLookup(tt.input) + if got != tt.want { + t.Errorf("normalizeForLookup(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func TestIsSessionCookieName(t *testing.T) { + tests := []struct { + name string + want bool + }{ + {"session_id", true}, + {"JSESSIONID", true}, + {"PHPSESSID", true}, + {"connect.sid", true}, + {"auth_token", true}, + {"theme", false}, + {"lang", false}, + {"preferences", false}, + } + for _, tt := range tests { + got := isSessionCookieName(tt.name) + if got != tt.want { + t.Errorf("isSessionCookieName(%q) = %v, want %v", tt.name, got, tt.want) + } + } +} diff --git a/internal/check/check_more_test.go b/internal/check/check_more_test.go new file mode 100644 index 0000000..ab15d55 --- /dev/null +++ b/internal/check/check_more_test.go @@ -0,0 +1,972 @@ +package check + +import ( + "context" + "encoding/json" + "strings" + "testing" + + "github.com/GrayCodeAI/inspect/internal/crawler" +) + +// This file was split out of check_test.go for readability (mechanical move; no behavior change). + +// --- Registry Additional Tests --- + +func TestRegistry_Register(t *testing.T) { + r := &Registry{checks: make(map[string]Checker)} + r.Register(&SecurityCheck{}) + r.Register(&LinksCheck{}) + + all := r.All() + if len(all) != 2 { + t.Errorf("expected 2 checks, got %d", len(all)) + } +} + +func TestRegistry_FilterEmpty(t *testing.T) { + r := DefaultRegistry() + // Empty filter should return all checks + all := r.Filter(nil) + if len(all) != 9 { + t.Errorf("expected 9 checks for empty filter, got %d", len(all)) + } +} + +func TestRegistry_RegisterOverwrite(t *testing.T) { + r := &Registry{checks: make(map[string]Checker)} + r.Register(&LinksCheck{}) + r.Register(&LinksCheck{AcceptedStatusCodes: []int{200, 201}}) + + all := r.All() + if len(all) != 1 { + t.Errorf("expected 1 check after overwrite, got %d", len(all)) + } +} + +// --- Links Check Tests --- + +func TestLinksCheck_Name(t *testing.T) { + chk := &LinksCheck{} + if chk.Name() != "links" { + t.Errorf("expected 'links', got %q", chk.Name()) + } +} + +func TestLinksCheck_IsAcceptedStatus(t *testing.T) { + // Default range (200-399) + chk := &LinksCheck{} + if !chk.isAcceptedStatus(200) { + t.Error("200 should be accepted by default") + } + if !chk.isAcceptedStatus(301) { + t.Error("301 should be accepted by default") + } + if chk.isAcceptedStatus(404) { + t.Error("404 should not be accepted by default") + } + if chk.isAcceptedStatus(500) { + t.Error("500 should not be accepted by default") + } + + // Custom status codes + chk2 := &LinksCheck{AcceptedStatusCodes: []int{200, 201, 404}} + if !chk2.isAcceptedStatus(200) { + t.Error("200 should be accepted in custom list") + } + if !chk2.isAcceptedStatus(404) { + t.Error("404 should be accepted in custom list") + } + if chk2.isAcceptedStatus(500) { + t.Error("500 should not be accepted in custom list") + } +} + +func TestLinksCheck_BrokenInternalPage(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + `broken`) + brokenPage := makePage("https://example.com/broken", 404, map[string]string{}, "") + brokenPage.ParentURL = "https://example.com" + + chk := &LinksCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page, brokenPage}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "HTTP 404") { + found = true + if f.Severity != SeverityHigh { + t.Errorf("expected high severity for 404, got %v", f.Severity) + } + } + } + if !found { + t.Error("expected finding for broken 404 page") + } +} + +func TestLinksCheck_ServerError(t *testing.T) { + page := makePage("https://example.com/error", 500, map[string]string{}, "") + page.ParentURL = "https://example.com" + + chk := &LinksCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "HTTP 500") { + found = true + if f.Severity != SeverityCritical { + t.Errorf("expected critical severity for 500, got %v", f.Severity) + } + } + } + if !found { + t.Error("expected finding for 500 error page") + } +} + +func TestLinksCheck_FragmentValidation(t *testing.T) { + target := makePage("https://example.com/target", 200, map[string]string{"Content-Type": "text/html"}, + `
Hello
`) + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + ` + Good + Bad + `) + page.Links = []crawler.Link{ + {Href: "https://example.com/target#exists", Tag: "a"}, + {Href: "https://example.com/target#missing", Tag: "a"}, + } + + chk := &LinksCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page, target}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "Fragment #missing not found") { + found = true + } + } + if !found { + t.Error("expected finding for missing fragment #missing") + } +} + +func TestSeverityForStatus(t *testing.T) { + if severityForStatus(404) != SeverityHigh { + t.Error("404 should be high") + } + if severityForStatus(500) != SeverityCritical { + t.Error("500 should be critical") + } + if severityForStatus(403) != SeverityMedium { + t.Error("403 should be medium") + } + if severityForStatus(301) != SeverityLow { + t.Error("301 should be low") + } +} + +func TestExtractElementIDs(t *testing.T) { + page := &crawler.Page{ + URL: "https://example.com", + Body: []byte(`

Hello

No id
`), + } + ids := extractElementIDs(page) + if !ids["main"] { + t.Error("should find id=main") + } + if !ids["intro"] { + t.Error("should find id=intro") + } + if len(ids) != 2 { + t.Errorf("expected 2 IDs, got %d", len(ids)) + } +} + +func TestExtractFragment(t *testing.T) { + if extractFragment("/page#section") != "section" { + t.Error("expected 'section'") + } + if extractFragment("/page") != "" { + t.Error("expected empty string for no fragment") + } + if extractFragment("") != "" { + t.Error("expected empty string for empty href") + } +} + +// --- Performance Check Tests --- + +func TestPerfCheck_Name(t *testing.T) { + chk := &PerfCheck{} + if chk.Name() != "perf" { + t.Errorf("expected 'perf', got %q", chk.Name()) + } +} + +func TestPerfCheck_MissingCompression(t *testing.T) { + body := `T` + string(make([]byte, 2000)) + `` + page := makePage("https://example.com", 200, map[string]string{ + "Content-Type": "text/html", + }, body) + + chk := &PerfCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "not compressed") { + found = true + } + } + if !found { + t.Error("expected finding for missing compression") + } +} + +func TestPerfCheck_WithCompression(t *testing.T) { + body := `T` + string(make([]byte, 2000)) + `` + page := makePage("https://example.com", 200, map[string]string{ + "Content-Type": "text/html", + "Content-Encoding": "gzip", + }, body) + + chk := &PerfCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + for _, f := range findings { + if strings.Contains(f.Message, "not compressed") { + t.Error("should not flag compression when Content-Encoding is set") + } + } +} + +func TestPerfCheck_MissingCacheControl(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{ + "Content-Type": "text/html", + }, `Ttest`) + + chk := &PerfCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "Missing Cache-Control") { + found = true + } + } + if !found { + t.Error("expected finding for missing Cache-Control") + } +} + +func TestPerfCheck_RenderBlockingScript(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + ``) + + chk := &PerfCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "Render-blocking script") { + found = true + } + } + if !found { + t.Error("expected finding for render-blocking script in head") + } +} + +func TestPerfCheck_AsyncScriptOK(t *testing.T) { + // async="async" (explicit value) works with hasAttr which requires non-empty Val + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + ``) + + chk := &PerfCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + for _, f := range findings { + if strings.Contains(f.Message, "Render-blocking script") { + t.Error("should not flag async script") + } + } +} + +func TestPerfCheck_DeferScriptOK(t *testing.T) { + // defer="defer" (explicit value) works with hasAttr which requires non-empty Val + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + ``) + + chk := &PerfCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + for _, f := range findings { + if strings.Contains(f.Message, "Render-blocking script") { + t.Error("should not flag deferred script") + } + } +} + +func TestPerfCheck_ImageMissingDimensions(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + `T`) + + chk := &PerfCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "missing width/height") { + found = true + } + } + if !found { + t.Error("expected finding for image missing dimensions") + } +} + +func TestPerfCheck_ImageMissingLazy(t *testing.T) { + // Place 3 above-fold images first, then the 4th image should be flagged for missing lazy loading + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + `T`+ + ``+ + ``+ + ``+ + ``+ + ``) + + chk := &PerfCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, `loading="lazy"`) { + found = true + } + } + if !found { + t.Error("expected finding for image missing loading=lazy") + } +} + +func TestPerfCheck_RenderBlockingStylesheet(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + ``) + + chk := &PerfCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "Render-blocking stylesheet") { + found = true + } + } + if !found { + t.Error("expected finding for render-blocking stylesheet") + } +} + +func TestPerfCheck_SkipsErrorPages(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "test") + page.Error = context.Canceled + + chk := &PerfCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + if len(findings) != 0 { + t.Error("should skip error pages") + } +} + +func TestFormatBytes(t *testing.T) { + tests := []struct { + input int + expected string + }{ + {500, "500 B"}, + {1024, "1.0 KB"}, + {1048576, "1.0 MB"}, + } + for _, tt := range tests { + got := formatBytes(tt.input) + if got != tt.expected { + t.Errorf("formatBytes(%d) = %q, want %q", tt.input, got, tt.expected) + } + } +} + +// --- SEO Check Tests --- + +func TestSEOCheck_Name(t *testing.T) { + chk := &SEOCheck{} + if chk.Name() != "seo" { + t.Errorf("expected 'seo', got %q", chk.Name()) + } +} + +func TestSEOCheck_MissingAllMeta(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + `No meta`) + + chk := &SEOCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + messages := map[string]bool{} + for _, f := range findings { + messages[f.Message] = true + } + + if !messages["Page missing tag"] { + t.Error("expected missing title finding") + } + if !messages["Missing meta description"] { + t.Error("expected missing description finding") + } + if !messages["Missing canonical URL"] { + t.Error("expected missing canonical finding") + } + if !messages["Missing viewport meta tag"] { + t.Error("expected missing viewport finding") + } + if !messages["Missing charset declaration"] { + t.Error("expected missing charset finding") + } +} + +func TestSEOCheck_AllMetaPresent(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + `<html><head> + <meta charset="UTF-8"> + <title>My Page + + + + + + Content`) + + chk := &SEOCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + for _, f := range findings { + if f.Severity > SeverityLow { + t.Errorf("unexpected issue above low severity: %s", f.Message) + } + } +} + +func TestSEOCheck_TitleTooLong(t *testing.T) { + longTitle := string(make([]byte, 80)) + for i := range longTitle { + longTitle = longTitle[:i] + "a" + longTitle[i+1:] + } + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + ``+longTitle+` + + + + `) + + chk := &SEOCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "Title too long") { + found = true + } + } + if !found { + t.Error("expected finding for title too long") + } +} + +func TestSEOCheck_DescriptionTooLong(t *testing.T) { + longDesc := string(make([]byte, 200)) + for i := range longDesc { + longDesc = longDesc[:i] + "a" + longDesc[i+1:] + } + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + `OK + + + + + `) + + chk := &SEOCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "Meta description too long") { + found = true + } + } + if !found { + t.Error("expected finding for description too long") + } +} + +func TestSEOCheck_DuplicateTitle(t *testing.T) { + page1 := makePage("https://example.com/page1", 200, map[string]string{"Content-Type": "text/html"}, + `Same TitleP1`) + page2 := makePage("https://example.com/page2", 200, map[string]string{"Content-Type": "text/html"}, + `Same TitleP2`) + + chk := &SEOCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page1, page2}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "Duplicate title") { + found = true + } + } + if !found { + t.Error("expected finding for duplicate titles") + } +} + +func TestSEOCheck_SkipsErrorPages(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "test") + page.Error = context.Canceled + + chk := &SEOCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + if len(findings) != 0 { + t.Error("should skip error pages") + } +} + +// --- Forms Check Additional Tests --- + +func TestFormsCheck_Name(t *testing.T) { + chk := &FormsCheck{} + if chk.Name() != "forms" { + t.Errorf("expected 'forms', got %q", chk.Name()) + } +} + +func TestFormsCheck_MissingAction(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "") + page.Forms = []crawler.Form{ + {Action: "", Method: "POST"}, + } + + chk := &FormsCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "no action attribute") { + found = true + } + } + if !found { + t.Error("expected finding for missing form action") + } +} + +func TestFormsCheck_HTTPSPageHTTPAction(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "") + page.Forms = []crawler.Form{ + {Action: "http://example.com/submit", Method: "POST", HasCSRF: true}, + } + + chk := &FormsCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "submits form to HTTP endpoint") { + found = true + } + } + if !found { + t.Error("expected finding for HTTPS page with HTTP form action") + } +} + +func TestFormsCheck_AutocompleteIssue(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "") + page.Forms = []crawler.Form{ + { + Action: "/login", + Method: "POST", + HasCSRF: true, + Inputs: []crawler.FormInput{ + {Name: "password", Type: "password"}, + }, + }, + } + + chk := &FormsCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "autocomplete") { + found = true + } + } + if !found { + t.Error("expected finding for sensitive field missing autocomplete=off") + } +} + +func TestFormsCheck_FormWithID(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "") + page.Forms = []crawler.Form{ + {Action: "/submit", Method: "POST", ID: "login-form", HasCSRF: false}, + } + + chk := &FormsCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Element, "form#login-form") { + found = true + } + } + if !found { + t.Error("expected finding element to reference form by ID") + } +} + +func TestFormsCheck_SkipsErrorPages(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "") + page.Forms = []crawler.Form{{Action: "/submit", Method: "POST"}} + page.Error = context.Canceled + + chk := &FormsCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + if len(findings) != 0 { + t.Error("should skip error pages") + } +} + +func TestHasPasswordField(t *testing.T) { + if hasPasswordField(nil) { + t.Error("nil should return false") + } + if hasPasswordField([]crawler.FormInput{{Name: "email", Type: "email"}}) { + t.Error("email field should not be password") + } + if !hasPasswordField([]crawler.FormInput{{Name: "pass", Type: "password"}}) { + t.Error("password field should be detected") + } +} + +func TestHasAutocompleteIssue(t *testing.T) { + if hasAutocompleteIssue(nil) { + t.Error("nil should return false") + } + if !hasAutocompleteIssue([]crawler.FormInput{{Name: "pass", Type: "password"}}) { + t.Error("password type should be flagged") + } + if !hasAutocompleteIssue([]crawler.FormInput{{Name: "credit_card_number", Type: "text"}}) { + t.Error("credit card name should be flagged") + } + if hasAutocompleteIssue([]crawler.FormInput{{Name: "email", Type: "email"}}) { + t.Error("email should not be flagged") + } +} + +// --- Security Check Additional Tests --- + +func TestSecurityCheck_MixedContent(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + ``) + + chk := &SecurityCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "Mixed content") { + found = true + if f.Severity != SeverityHigh { + t.Errorf("expected high severity for mixed content, got %v", f.Severity) + } + } + } + if !found { + t.Error("expected mixed content finding") + } +} + +func TestSecurityCheck_ServerVersionExposed(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{ + "Content-Type": "text/html", + "Server": "Apache/2.4.51", + }, `test`) + + chk := &SecurityCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "Server header exposes version") { + found = true + } + } + if !found { + t.Error("expected finding for server version exposure") + } +} + +func TestSecurityCheck_XPoweredBy(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{ + "Content-Type": "text/html", + "X-Powered-By": "PHP/8.1", + }, `test`) + + chk := &SecurityCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "X-Powered-By header exposes technology") { + found = true + } + } + if !found { + t.Error("expected finding for X-Powered-By header") + } +} + +func TestSecurityCheck_CookieMissingSecure(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{ + "Content-Type": "text/html", + }, `test`) + page.Headers.Set("Set-Cookie", "session_id=abc123; HttpOnly; SameSite=Lax") + + chk := &SecurityCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "missing Secure flag") { + found = true + } + } + if !found { + t.Error("expected finding for cookie missing Secure flag on HTTPS") + } +} + +func TestSecurityCheck_SessionCookieMissingHttpOnly(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{ + "Content-Type": "text/html", + }, `test`) + page.Headers.Set("Set-Cookie", "session=abc123; Secure; SameSite=Lax") + + chk := &SecurityCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "missing HttpOnly flag") { + found = true + } + } + if !found { + t.Error("expected finding for session cookie missing HttpOnly") + } +} + +func TestSecurityCheck_CookieMissingSameSite(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{ + "Content-Type": "text/html", + }, `test`) + page.Headers.Set("Set-Cookie", "pref=dark; Secure; HttpOnly") + + chk := &SecurityCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "missing SameSite") { + found = true + } + } + if !found { + t.Error("expected finding for cookie missing SameSite") + } +} + +func TestSecurityCheck_SameSiteNoneWithoutSecure(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{ + "Content-Type": "text/html", + }, `test`) + page.Headers.Set("Set-Cookie", "tracker=xyz; SameSite=None; HttpOnly") + + chk := &SecurityCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "SameSite=None without Secure") { + found = true + } + } + if !found { + t.Error("expected finding for SameSite=None without Secure") + } +} + +func TestSecurityCheck_HTTPPage(t *testing.T) { + page := makePage("http://example.com", 200, map[string]string{ + "Content-Type": "text/html", + }, `test`) + + chk := &SecurityCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "HTTP instead of HTTPS") { + found = true + } + } + if !found { + t.Error("expected finding for page served over HTTP") + } +} + +func TestSecurityCheck_SkipsErrorPages(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "test") + page.Error = context.Canceled + + chk := &SecurityCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + if len(findings) != 0 { + t.Error("should skip error pages") + } +} + +func TestSecurityCheck_Skips4xx(t *testing.T) { + page := makePage("https://example.com", 404, map[string]string{"Content-Type": "text/html"}, "not found") + + chk := &SecurityCheck{} + findings := chk.Run(context.Background(), []*crawler.Page{page}) + + if len(findings) != 0 { + t.Error("should skip pages with 4xx status") + } +} + +// ===================================================================== +// Finding struct tests +// ===================================================================== + +func TestFinding_SeverityLevels(t *testing.T) { + tests := []struct { + name string + severity Severity + want int + }{ + {"info", SeverityInfo, 0}, + {"low", SeverityLow, 1}, + {"medium", SeverityMedium, 2}, + {"high", SeverityHigh, 3}, + {"critical", SeverityCritical, 4}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if int(tt.severity) != tt.want { + t.Errorf("Severity %s = %d, want %d", tt.name, int(tt.severity), tt.want) + } + }) + } + + // Severity ordering + if SeverityInfo >= SeverityLow { + t.Error("SeverityInfo should be less than SeverityLow") + } + if SeverityLow >= SeverityMedium { + t.Error("SeverityLow should be less than SeverityMedium") + } + if SeverityMedium >= SeverityHigh { + t.Error("SeverityMedium should be less than SeverityHigh") + } + if SeverityHigh >= SeverityCritical { + t.Error("SeverityHigh should be less than SeverityCritical") + } +} + +func TestFinding_JSONMarshaling(t *testing.T) { + f := Finding{ + Severity: SeverityHigh, + URL: "https://example.com", + Element: "", + Message: "Image missing alt attribute", + Fix: "Add descriptive alt text", + Evidence: "src=photo.jpg", + } + + data, err := json.Marshal(f) + if err != nil { + t.Fatalf("json.Marshal failed: %v", err) + } + + var roundtrip Finding + if err := json.Unmarshal(data, &roundtrip); err != nil { + t.Fatalf("json.Unmarshal failed: %v", err) + } + + if roundtrip.Severity != f.Severity { + t.Errorf("severity: got %v, want %v", roundtrip.Severity, f.Severity) + } + if roundtrip.URL != f.URL { + t.Errorf("URL: got %q, want %q", roundtrip.URL, f.URL) + } + if roundtrip.Message != f.Message { + t.Errorf("Message: got %q, want %q", roundtrip.Message, f.Message) + } + if roundtrip.Fix != f.Fix { + t.Errorf("Fix: got %q, want %q", roundtrip.Fix, f.Fix) + } + if roundtrip.Evidence != f.Evidence { + t.Errorf("Evidence: got %q, want %q", roundtrip.Evidence, f.Evidence) + } +} + +func TestFinding_JSONMarshaling_OmitsEmpty(t *testing.T) { + f := Finding{ + Severity: SeverityInfo, + URL: "https://example.com", + Message: "test", + } + + data, err := json.Marshal(f) + if err != nil { + t.Fatalf("json.Marshal failed: %v", err) + } + + var m map[string]interface{} + if err := json.Unmarshal(data, &m); err != nil { + t.Fatalf("json.Unmarshal into map failed: %v", err) + } + + // Element, Fix, Evidence should be empty strings (still present in JSON but empty) + // Note: Finding has no JSON tags, so keys are capitalized + if m["Element"] != "" { + t.Errorf("expected empty Element, got %v", m["Element"]) + } + if m["Fix"] != "" { + t.Errorf("expected empty Fix, got %v", m["Fix"]) + } + if m["Evidence"] != "" { + t.Errorf("expected empty Evidence, got %v", m["Evidence"]) + } +} diff --git a/internal/check/check_test.go b/internal/check/check_test.go index eb7c107..f40594b 100644 --- a/internal/check/check_test.go +++ b/internal/check/check_test.go @@ -2,7 +2,6 @@ package check import ( "context" - "encoding/json" "net/http" "strings" "testing" @@ -1127,1871 +1126,5 @@ func TestIsFocusable(t *testing.T) { } } -// --- Registry Additional Tests --- - -func TestRegistry_Register(t *testing.T) { - r := &Registry{checks: make(map[string]Checker)} - r.Register(&SecurityCheck{}) - r.Register(&LinksCheck{}) - - all := r.All() - if len(all) != 2 { - t.Errorf("expected 2 checks, got %d", len(all)) - } -} - -func TestRegistry_FilterEmpty(t *testing.T) { - r := DefaultRegistry() - // Empty filter should return all checks - all := r.Filter(nil) - if len(all) != 9 { - t.Errorf("expected 9 checks for empty filter, got %d", len(all)) - } -} - -func TestRegistry_RegisterOverwrite(t *testing.T) { - r := &Registry{checks: make(map[string]Checker)} - r.Register(&LinksCheck{}) - r.Register(&LinksCheck{AcceptedStatusCodes: []int{200, 201}}) - - all := r.All() - if len(all) != 1 { - t.Errorf("expected 1 check after overwrite, got %d", len(all)) - } -} - -// --- Links Check Tests --- - -func TestLinksCheck_Name(t *testing.T) { - chk := &LinksCheck{} - if chk.Name() != "links" { - t.Errorf("expected 'links', got %q", chk.Name()) - } -} - -func TestLinksCheck_IsAcceptedStatus(t *testing.T) { - // Default range (200-399) - chk := &LinksCheck{} - if !chk.isAcceptedStatus(200) { - t.Error("200 should be accepted by default") - } - if !chk.isAcceptedStatus(301) { - t.Error("301 should be accepted by default") - } - if chk.isAcceptedStatus(404) { - t.Error("404 should not be accepted by default") - } - if chk.isAcceptedStatus(500) { - t.Error("500 should not be accepted by default") - } - - // Custom status codes - chk2 := &LinksCheck{AcceptedStatusCodes: []int{200, 201, 404}} - if !chk2.isAcceptedStatus(200) { - t.Error("200 should be accepted in custom list") - } - if !chk2.isAcceptedStatus(404) { - t.Error("404 should be accepted in custom list") - } - if chk2.isAcceptedStatus(500) { - t.Error("500 should not be accepted in custom list") - } -} - -func TestLinksCheck_BrokenInternalPage(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, - `broken`) - brokenPage := makePage("https://example.com/broken", 404, map[string]string{}, "") - brokenPage.ParentURL = "https://example.com" - - chk := &LinksCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page, brokenPage}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "HTTP 404") { - found = true - if f.Severity != SeverityHigh { - t.Errorf("expected high severity for 404, got %v", f.Severity) - } - } - } - if !found { - t.Error("expected finding for broken 404 page") - } -} - -func TestLinksCheck_ServerError(t *testing.T) { - page := makePage("https://example.com/error", 500, map[string]string{}, "") - page.ParentURL = "https://example.com" - - chk := &LinksCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "HTTP 500") { - found = true - if f.Severity != SeverityCritical { - t.Errorf("expected critical severity for 500, got %v", f.Severity) - } - } - } - if !found { - t.Error("expected finding for 500 error page") - } -} - -func TestLinksCheck_FragmentValidation(t *testing.T) { - target := makePage("https://example.com/target", 200, map[string]string{"Content-Type": "text/html"}, - `
Hello
`) - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, - ` - Good - Bad - `) - page.Links = []crawler.Link{ - {Href: "https://example.com/target#exists", Tag: "a"}, - {Href: "https://example.com/target#missing", Tag: "a"}, - } - - chk := &LinksCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page, target}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "Fragment #missing not found") { - found = true - } - } - if !found { - t.Error("expected finding for missing fragment #missing") - } -} - -func TestSeverityForStatus(t *testing.T) { - if severityForStatus(404) != SeverityHigh { - t.Error("404 should be high") - } - if severityForStatus(500) != SeverityCritical { - t.Error("500 should be critical") - } - if severityForStatus(403) != SeverityMedium { - t.Error("403 should be medium") - } - if severityForStatus(301) != SeverityLow { - t.Error("301 should be low") - } -} - -func TestExtractElementIDs(t *testing.T) { - page := &crawler.Page{ - URL: "https://example.com", - Body: []byte(`

Hello

No id
`), - } - ids := extractElementIDs(page) - if !ids["main"] { - t.Error("should find id=main") - } - if !ids["intro"] { - t.Error("should find id=intro") - } - if len(ids) != 2 { - t.Errorf("expected 2 IDs, got %d", len(ids)) - } -} - -func TestExtractFragment(t *testing.T) { - if extractFragment("/page#section") != "section" { - t.Error("expected 'section'") - } - if extractFragment("/page") != "" { - t.Error("expected empty string for no fragment") - } - if extractFragment("") != "" { - t.Error("expected empty string for empty href") - } -} - -// --- Performance Check Tests --- - -func TestPerfCheck_Name(t *testing.T) { - chk := &PerfCheck{} - if chk.Name() != "perf" { - t.Errorf("expected 'perf', got %q", chk.Name()) - } -} - -func TestPerfCheck_MissingCompression(t *testing.T) { - body := `T` + string(make([]byte, 2000)) + `` - page := makePage("https://example.com", 200, map[string]string{ - "Content-Type": "text/html", - }, body) - - chk := &PerfCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "not compressed") { - found = true - } - } - if !found { - t.Error("expected finding for missing compression") - } -} - -func TestPerfCheck_WithCompression(t *testing.T) { - body := `T` + string(make([]byte, 2000)) + `` - page := makePage("https://example.com", 200, map[string]string{ - "Content-Type": "text/html", - "Content-Encoding": "gzip", - }, body) - - chk := &PerfCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - for _, f := range findings { - if strings.Contains(f.Message, "not compressed") { - t.Error("should not flag compression when Content-Encoding is set") - } - } -} - -func TestPerfCheck_MissingCacheControl(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{ - "Content-Type": "text/html", - }, `Ttest`) - - chk := &PerfCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "Missing Cache-Control") { - found = true - } - } - if !found { - t.Error("expected finding for missing Cache-Control") - } -} - -func TestPerfCheck_RenderBlockingScript(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, - ``) - - chk := &PerfCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "Render-blocking script") { - found = true - } - } - if !found { - t.Error("expected finding for render-blocking script in head") - } -} - -func TestPerfCheck_AsyncScriptOK(t *testing.T) { - // async="async" (explicit value) works with hasAttr which requires non-empty Val - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, - ``) - - chk := &PerfCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - for _, f := range findings { - if strings.Contains(f.Message, "Render-blocking script") { - t.Error("should not flag async script") - } - } -} - -func TestPerfCheck_DeferScriptOK(t *testing.T) { - // defer="defer" (explicit value) works with hasAttr which requires non-empty Val - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, - ``) - - chk := &PerfCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - for _, f := range findings { - if strings.Contains(f.Message, "Render-blocking script") { - t.Error("should not flag deferred script") - } - } -} - -func TestPerfCheck_ImageMissingDimensions(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, - `T`) - - chk := &PerfCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "missing width/height") { - found = true - } - } - if !found { - t.Error("expected finding for image missing dimensions") - } -} - -func TestPerfCheck_ImageMissingLazy(t *testing.T) { - // Place 3 above-fold images first, then the 4th image should be flagged for missing lazy loading - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, - `T`+ - ``+ - ``+ - ``+ - ``+ - ``) - - chk := &PerfCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, `loading="lazy"`) { - found = true - } - } - if !found { - t.Error("expected finding for image missing loading=lazy") - } -} - -func TestPerfCheck_RenderBlockingStylesheet(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, - ``) - - chk := &PerfCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "Render-blocking stylesheet") { - found = true - } - } - if !found { - t.Error("expected finding for render-blocking stylesheet") - } -} - -func TestPerfCheck_SkipsErrorPages(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "test") - page.Error = context.Canceled - - chk := &PerfCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - if len(findings) != 0 { - t.Error("should skip error pages") - } -} - -func TestFormatBytes(t *testing.T) { - tests := []struct { - input int - expected string - }{ - {500, "500 B"}, - {1024, "1.0 KB"}, - {1048576, "1.0 MB"}, - } - for _, tt := range tests { - got := formatBytes(tt.input) - if got != tt.expected { - t.Errorf("formatBytes(%d) = %q, want %q", tt.input, got, tt.expected) - } - } -} - -// --- SEO Check Tests --- - -func TestSEOCheck_Name(t *testing.T) { - chk := &SEOCheck{} - if chk.Name() != "seo" { - t.Errorf("expected 'seo', got %q", chk.Name()) - } -} - -func TestSEOCheck_MissingAllMeta(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, - `No meta`) - - chk := &SEOCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - messages := map[string]bool{} - for _, f := range findings { - messages[f.Message] = true - } - - if !messages["Page missing tag"] { - t.Error("expected missing title finding") - } - if !messages["Missing meta description"] { - t.Error("expected missing description finding") - } - if !messages["Missing canonical URL"] { - t.Error("expected missing canonical finding") - } - if !messages["Missing viewport meta tag"] { - t.Error("expected missing viewport finding") - } - if !messages["Missing charset declaration"] { - t.Error("expected missing charset finding") - } -} - -func TestSEOCheck_AllMetaPresent(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, - `<html><head> - <meta charset="UTF-8"> - <title>My Page - - - - - - Content`) - - chk := &SEOCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - for _, f := range findings { - if f.Severity > SeverityLow { - t.Errorf("unexpected issue above low severity: %s", f.Message) - } - } -} - -func TestSEOCheck_TitleTooLong(t *testing.T) { - longTitle := string(make([]byte, 80)) - for i := range longTitle { - longTitle = longTitle[:i] + "a" + longTitle[i+1:] - } - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, - ``+longTitle+` - - - - `) - - chk := &SEOCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "Title too long") { - found = true - } - } - if !found { - t.Error("expected finding for title too long") - } -} - -func TestSEOCheck_DescriptionTooLong(t *testing.T) { - longDesc := string(make([]byte, 200)) - for i := range longDesc { - longDesc = longDesc[:i] + "a" + longDesc[i+1:] - } - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, - `OK - - - - - `) - - chk := &SEOCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "Meta description too long") { - found = true - } - } - if !found { - t.Error("expected finding for description too long") - } -} - -func TestSEOCheck_DuplicateTitle(t *testing.T) { - page1 := makePage("https://example.com/page1", 200, map[string]string{"Content-Type": "text/html"}, - `Same TitleP1`) - page2 := makePage("https://example.com/page2", 200, map[string]string{"Content-Type": "text/html"}, - `Same TitleP2`) - - chk := &SEOCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page1, page2}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "Duplicate title") { - found = true - } - } - if !found { - t.Error("expected finding for duplicate titles") - } -} - -func TestSEOCheck_SkipsErrorPages(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "test") - page.Error = context.Canceled - - chk := &SEOCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - if len(findings) != 0 { - t.Error("should skip error pages") - } -} - -// --- Forms Check Additional Tests --- - -func TestFormsCheck_Name(t *testing.T) { - chk := &FormsCheck{} - if chk.Name() != "forms" { - t.Errorf("expected 'forms', got %q", chk.Name()) - } -} - -func TestFormsCheck_MissingAction(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "") - page.Forms = []crawler.Form{ - {Action: "", Method: "POST"}, - } - - chk := &FormsCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "no action attribute") { - found = true - } - } - if !found { - t.Error("expected finding for missing form action") - } -} - -func TestFormsCheck_HTTPSPageHTTPAction(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "") - page.Forms = []crawler.Form{ - {Action: "http://example.com/submit", Method: "POST", HasCSRF: true}, - } - - chk := &FormsCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "submits form to HTTP endpoint") { - found = true - } - } - if !found { - t.Error("expected finding for HTTPS page with HTTP form action") - } -} - -func TestFormsCheck_AutocompleteIssue(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "") - page.Forms = []crawler.Form{ - { - Action: "/login", - Method: "POST", - HasCSRF: true, - Inputs: []crawler.FormInput{ - {Name: "password", Type: "password"}, - }, - }, - } - - chk := &FormsCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "autocomplete") { - found = true - } - } - if !found { - t.Error("expected finding for sensitive field missing autocomplete=off") - } -} - -func TestFormsCheck_FormWithID(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "") - page.Forms = []crawler.Form{ - {Action: "/submit", Method: "POST", ID: "login-form", HasCSRF: false}, - } - - chk := &FormsCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Element, "form#login-form") { - found = true - } - } - if !found { - t.Error("expected finding element to reference form by ID") - } -} - -func TestFormsCheck_SkipsErrorPages(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "") - page.Forms = []crawler.Form{{Action: "/submit", Method: "POST"}} - page.Error = context.Canceled - - chk := &FormsCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - if len(findings) != 0 { - t.Error("should skip error pages") - } -} - -func TestHasPasswordField(t *testing.T) { - if hasPasswordField(nil) { - t.Error("nil should return false") - } - if hasPasswordField([]crawler.FormInput{{Name: "email", Type: "email"}}) { - t.Error("email field should not be password") - } - if !hasPasswordField([]crawler.FormInput{{Name: "pass", Type: "password"}}) { - t.Error("password field should be detected") - } -} - -func TestHasAutocompleteIssue(t *testing.T) { - if hasAutocompleteIssue(nil) { - t.Error("nil should return false") - } - if !hasAutocompleteIssue([]crawler.FormInput{{Name: "pass", Type: "password"}}) { - t.Error("password type should be flagged") - } - if !hasAutocompleteIssue([]crawler.FormInput{{Name: "credit_card_number", Type: "text"}}) { - t.Error("credit card name should be flagged") - } - if hasAutocompleteIssue([]crawler.FormInput{{Name: "email", Type: "email"}}) { - t.Error("email should not be flagged") - } -} - -// --- Security Check Additional Tests --- - -func TestSecurityCheck_MixedContent(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, - ``) - - chk := &SecurityCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "Mixed content") { - found = true - if f.Severity != SeverityHigh { - t.Errorf("expected high severity for mixed content, got %v", f.Severity) - } - } - } - if !found { - t.Error("expected mixed content finding") - } -} - -func TestSecurityCheck_ServerVersionExposed(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{ - "Content-Type": "text/html", - "Server": "Apache/2.4.51", - }, `test`) - - chk := &SecurityCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "Server header exposes version") { - found = true - } - } - if !found { - t.Error("expected finding for server version exposure") - } -} - -func TestSecurityCheck_XPoweredBy(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{ - "Content-Type": "text/html", - "X-Powered-By": "PHP/8.1", - }, `test`) - - chk := &SecurityCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "X-Powered-By header exposes technology") { - found = true - } - } - if !found { - t.Error("expected finding for X-Powered-By header") - } -} - -func TestSecurityCheck_CookieMissingSecure(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{ - "Content-Type": "text/html", - }, `test`) - page.Headers.Set("Set-Cookie", "session_id=abc123; HttpOnly; SameSite=Lax") - - chk := &SecurityCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "missing Secure flag") { - found = true - } - } - if !found { - t.Error("expected finding for cookie missing Secure flag on HTTPS") - } -} - -func TestSecurityCheck_SessionCookieMissingHttpOnly(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{ - "Content-Type": "text/html", - }, `test`) - page.Headers.Set("Set-Cookie", "session=abc123; Secure; SameSite=Lax") - - chk := &SecurityCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "missing HttpOnly flag") { - found = true - } - } - if !found { - t.Error("expected finding for session cookie missing HttpOnly") - } -} - -func TestSecurityCheck_CookieMissingSameSite(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{ - "Content-Type": "text/html", - }, `test`) - page.Headers.Set("Set-Cookie", "pref=dark; Secure; HttpOnly") - - chk := &SecurityCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "missing SameSite") { - found = true - } - } - if !found { - t.Error("expected finding for cookie missing SameSite") - } -} - -func TestSecurityCheck_SameSiteNoneWithoutSecure(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{ - "Content-Type": "text/html", - }, `test`) - page.Headers.Set("Set-Cookie", "tracker=xyz; SameSite=None; HttpOnly") - - chk := &SecurityCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "SameSite=None without Secure") { - found = true - } - } - if !found { - t.Error("expected finding for SameSite=None without Secure") - } -} - -func TestSecurityCheck_HTTPPage(t *testing.T) { - page := makePage("http://example.com", 200, map[string]string{ - "Content-Type": "text/html", - }, `test`) - - chk := &SecurityCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "HTTP instead of HTTPS") { - found = true - } - } - if !found { - t.Error("expected finding for page served over HTTP") - } -} - -func TestSecurityCheck_SkipsErrorPages(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "test") - page.Error = context.Canceled - - chk := &SecurityCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - if len(findings) != 0 { - t.Error("should skip error pages") - } -} - -func TestSecurityCheck_Skips4xx(t *testing.T) { - page := makePage("https://example.com", 404, map[string]string{"Content-Type": "text/html"}, "not found") - - chk := &SecurityCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - if len(findings) != 0 { - t.Error("should skip pages with 4xx status") - } -} - -// ===================================================================== -// Finding struct tests -// ===================================================================== - -func TestFinding_SeverityLevels(t *testing.T) { - tests := []struct { - name string - severity Severity - want int - }{ - {"info", SeverityInfo, 0}, - {"low", SeverityLow, 1}, - {"medium", SeverityMedium, 2}, - {"high", SeverityHigh, 3}, - {"critical", SeverityCritical, 4}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if int(tt.severity) != tt.want { - t.Errorf("Severity %s = %d, want %d", tt.name, int(tt.severity), tt.want) - } - }) - } - - // Severity ordering - if SeverityInfo >= SeverityLow { - t.Error("SeverityInfo should be less than SeverityLow") - } - if SeverityLow >= SeverityMedium { - t.Error("SeverityLow should be less than SeverityMedium") - } - if SeverityMedium >= SeverityHigh { - t.Error("SeverityMedium should be less than SeverityHigh") - } - if SeverityHigh >= SeverityCritical { - t.Error("SeverityHigh should be less than SeverityCritical") - } -} - -func TestFinding_JSONMarshaling(t *testing.T) { - f := Finding{ - Severity: SeverityHigh, - URL: "https://example.com", - Element: "", - Message: "Image missing alt attribute", - Fix: "Add descriptive alt text", - Evidence: "src=photo.jpg", - } - - data, err := json.Marshal(f) - if err != nil { - t.Fatalf("json.Marshal failed: %v", err) - } - - var roundtrip Finding - if err := json.Unmarshal(data, &roundtrip); err != nil { - t.Fatalf("json.Unmarshal failed: %v", err) - } - - if roundtrip.Severity != f.Severity { - t.Errorf("severity: got %v, want %v", roundtrip.Severity, f.Severity) - } - if roundtrip.URL != f.URL { - t.Errorf("URL: got %q, want %q", roundtrip.URL, f.URL) - } - if roundtrip.Message != f.Message { - t.Errorf("Message: got %q, want %q", roundtrip.Message, f.Message) - } - if roundtrip.Fix != f.Fix { - t.Errorf("Fix: got %q, want %q", roundtrip.Fix, f.Fix) - } - if roundtrip.Evidence != f.Evidence { - t.Errorf("Evidence: got %q, want %q", roundtrip.Evidence, f.Evidence) - } -} - -func TestFinding_JSONMarshaling_OmitsEmpty(t *testing.T) { - f := Finding{ - Severity: SeverityInfo, - URL: "https://example.com", - Message: "test", - } - - data, err := json.Marshal(f) - if err != nil { - t.Fatalf("json.Marshal failed: %v", err) - } - - var m map[string]interface{} - if err := json.Unmarshal(data, &m); err != nil { - t.Fatalf("json.Unmarshal into map failed: %v", err) - } - - // Element, Fix, Evidence should be empty strings (still present in JSON but empty) - // Note: Finding has no JSON tags, so keys are capitalized - if m["Element"] != "" { - t.Errorf("expected empty Element, got %v", m["Element"]) - } - if m["Fix"] != "" { - t.Errorf("expected empty Fix, got %v", m["Fix"]) - } - if m["Evidence"] != "" { - t.Errorf("expected empty Evidence, got %v", m["Evidence"]) - } -} - -// ===================================================================== -// Checker interface compliance tests -// ===================================================================== - -func TestCheckerInterface_AllChecksImplement(t *testing.T) { - // Verify each concrete check type satisfies the Checker interface at compile time. - // This test exercises Name() and verifies the interface is implemented. - var _ Checker = &LinksCheck{} - var _ Checker = &SecurityCheck{} - var _ Checker = &FormsCheck{} - var _ Checker = &A11yCheck{} - var _ Checker = &PerfCheck{} - var _ Checker = &SEOCheck{} - var _ Checker = &SRICheck{} - var _ Checker = &AIReadyCheck{} - var _ Checker = &ReachabilityCheck{} -} - -func TestCheckerInterface_Names(t *testing.T) { - checks := map[string]Checker{ - "links": &LinksCheck{}, - "security": &SecurityCheck{}, - "forms": &FormsCheck{}, - "a11y": &A11yCheck{}, - "perf": &PerfCheck{}, - "seo": &SEOCheck{}, - "sri": &SRICheck{}, - "aiready": &AIReadyCheck{}, - "reachability": &ReachabilityCheck{}, - } - for expected, chk := range checks { - if chk.Name() != expected { - t.Errorf("%T.Name() = %q, want %q", chk, chk.Name(), expected) - } - } -} - -func TestCheckerInterface_RunReturnsSlice(t *testing.T) { - // Verify Run() returns a non-nil slice for an empty input (no panic). - ctx := context.Background() - checks := []Checker{ - &LinksCheck{}, - &SecurityCheck{}, - &FormsCheck{}, - &A11yCheck{}, - &PerfCheck{}, - &SEOCheck{}, - &SRICheck{}, - &AIReadyCheck{}, - &ReachabilityCheck{}, - } - for _, chk := range checks { - findings := chk.Run(ctx, nil) - // findings may be nil for empty input, that's OK - _ = findings - } -} - -// ===================================================================== -// Registry comprehensive tests -// ===================================================================== - -func TestDefaultRegistry_AllNames(t *testing.T) { - r := DefaultRegistry() - expectedNames := []string{ - "links", "security", "forms", "a11y", "perf", "seo", "sri", "aiready", "reachability", - } - all := r.All() - nameSet := make(map[string]bool) - for _, c := range all { - nameSet[c.Name()] = true - } - for _, name := range expectedNames { - if !nameSet[name] { - t.Errorf("DefaultRegistry missing expected check %q", name) - } - } - if len(all) != len(expectedNames) { - t.Errorf("DefaultRegistry has %d checks, expected %d", len(all), len(expectedNames)) - } -} - -func TestRegistry_Filter_MultipleSelections(t *testing.T) { - r := DefaultRegistry() - filtered := r.Filter([]string{"perf", "seo", "sri"}) - if len(filtered) != 3 { - t.Errorf("expected 3 checks, got %d", len(filtered)) - } - names := make(map[string]bool) - for _, c := range filtered { - names[c.Name()] = true - } - for _, want := range []string{"perf", "seo", "sri"} { - if !names[want] { - t.Errorf("missing expected check %q in filtered results", want) - } - } -} - -func TestRegistry_Filter_NonExistentName(t *testing.T) { - r := DefaultRegistry() - filtered := r.Filter([]string{"nonexistent"}) - if len(filtered) != 0 { - t.Errorf("expected 0 checks for nonexistent name, got %d", len(filtered)) - } -} - -func TestRegistry_Filter_Mix(t *testing.T) { - r := DefaultRegistry() - filtered := r.Filter([]string{"links", "nonexistent", "seo"}) - if len(filtered) != 2 { - t.Errorf("expected 2 checks (links + seo), got %d", len(filtered)) - } -} - -// ===================================================================== -// LinksCheck additional tests -// ===================================================================== - -func TestLinksCheck_RelativeURLs(t *testing.T) { - page := makePage("https://example.com/page", 200, map[string]string{"Content-Type": "text/html"}, - ` - About - Parent - Sibling - `) - aboutPage := makePage("https://example.com/about", 200, map[string]string{"Content-Type": "text/html"}, - `About`) - parentPage := makePage("https://example.com/parent", 200, map[string]string{"Content-Type": "text/html"}, - `Parent`) - siblingPage := makePage("https://example.com/sibling", 200, map[string]string{"Content-Type": "text/html"}, - `Sibling`) - - chk := &LinksCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page, aboutPage, parentPage, siblingPage}) - - // All pages are reachable so no broken-link findings expected - for _, f := range findings { - if strings.Contains(f.Message, "HTTP 404") || strings.Contains(f.Message, "HTTP 500") { - t.Errorf("unexpected broken link finding: %s", f.Message) - } - } -} - -func TestLinksCheck_EmptyPage(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "") - chk := &LinksCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - // Should not panic, may or may not have findings - _ = findings -} - -func TestLinksCheck_ErrorPage(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{}, "") - page.Error = context.Canceled - - chk := &LinksCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - if len(findings) != 0 { - t.Error("should skip error pages") - } -} - -func TestLinksCheck_NoPages(t *testing.T) { - chk := &LinksCheck{} - findings := chk.Run(context.Background(), nil) - if len(findings) != 0 { - t.Errorf("expected no findings for nil pages, got %d", len(findings)) - } -} - -func TestResolveLink(t *testing.T) { - tests := []struct { - base, href, want string - }{ - {"https://example.com/page", "/about", "https://example.com/about"}, - {"https://example.com/page", "https://other.com/x", "https://other.com/x"}, - {"https://example.com/page", "", ""}, - {"https://example.com/a/b", "../c", "https://example.com/c"}, - {"https://example.com/a/b", "c", "https://example.com/a/c"}, - } - for _, tt := range tests { - got := resolveLink(tt.base, tt.href) - if got != tt.want { - t.Errorf("resolveLink(%q, %q) = %q, want %q", tt.base, tt.href, got, tt.want) - } - } -} - -// ===================================================================== -// SecurityCheck additional tests -// ===================================================================== - -func TestSecurityCheck_CSPUnsafeInline(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{ - "Content-Type": "text/html", - "Content-Security-Policy": "default-src 'self'; script-src 'unsafe-inline'", - "X-Content-Type-Options": "nosniff", - "X-Frame-Options": "DENY", - "Strict-Transport-Security": "max-age=31536000", - "Referrer-Policy": "strict-origin", - "Permissions-Policy": "camera=()", - }, `test`) - - chk := &SecurityCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "unsafe-inline") { - found = true - } - } - if !found { - t.Error("expected finding for unsafe-inline in script-src") - } -} - -func TestSecurityCheck_CSPUnsafeEval(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{ - "Content-Type": "text/html", - "Content-Security-Policy": "script-src 'unsafe-eval'", - "X-Content-Type-Options": "nosniff", - "X-Frame-Options": "DENY", - "Strict-Transport-Security": "max-age=31536000", - "Referrer-Policy": "strict-origin", - "Permissions-Policy": "camera=()", - }, `test`) - - chk := &SecurityCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "unsafe-eval") { - found = true - } - } - if !found { - t.Error("expected finding for unsafe-eval in script-src") - } -} - -func TestSecurityCheck_CSPWildcard(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{ - "Content-Type": "text/html", - "Content-Security-Policy": "default-src *", - "X-Content-Type-Options": "nosniff", - "X-Frame-Options": "DENY", - "Strict-Transport-Security": "max-age=31536000", - "Referrer-Policy": "strict-origin", - "Permissions-Policy": "camera=()", - }, `test`) - - chk := &SecurityCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "wildcard") { - found = true - } - } - if !found { - t.Error("expected finding for wildcard in CSP") - } -} - -func TestSecurityCheck_CSPMissingFrameAncestors(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{ - "Content-Type": "text/html", - "Content-Security-Policy": "default-src 'self'", - "X-Content-Type-Options": "nosniff", - "X-Frame-Options": "DENY", - "Strict-Transport-Security": "max-age=31536000", - "Referrer-Policy": "strict-origin", - "Permissions-Policy": "camera=()", - }, `test`) - - chk := &SecurityCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "frame-ancestors") { - found = true - } - } - if !found { - t.Error("expected finding for missing frame-ancestors directive") - } -} - -func TestSecurityCheck_ExposedSecrets(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{ - "Content-Type": "text/html", - "Content-Security-Policy": "default-src 'self'; frame-ancestors 'self'", - "X-Content-Type-Options": "nosniff", - "X-Frame-Options": "DENY", - "Strict-Transport-Security": "max-age=31536000", - "Referrer-Policy": "strict-origin", - "Permissions-Policy": "camera=()", - }, `var api_key = "sk-abcdefghijklmnopqrstuvwxyz1234567890"`) - - chk := &SecurityCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "secret") || strings.Contains(f.Message, "credential") { - found = true - if f.Severity != SeverityCritical { - t.Errorf("expected critical severity for exposed secrets, got %v", f.Severity) - } - } - } - if !found { - t.Error("expected finding for exposed API key/secret") - } -} - -func TestSecurityCheck_Name(t *testing.T) { - chk := &SecurityCheck{} - if chk.Name() != "security" { - t.Errorf("expected 'security', got %q", chk.Name()) - } -} - -// ===================================================================== -// A11yCheck additional tests -// ===================================================================== - -func TestA11yCheck_MissingLabels(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, - `Test
- -
`) - - chk := &A11yCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "label") { - found = true - } - } - if !found { - t.Error("expected finding for input missing label") - } -} - -func TestA11yCheck_LabelWithFor(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, - `Test
- - -
`) - - chk := &A11yCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - for _, f := range findings { - if strings.Contains(f.Message, "label") && strings.Contains(f.Element, "email") { - t.Errorf("should not flag input with associated label: %s", f.Message) - } - } -} - -func TestA11yCheck_MissingLang(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, - `Test

Content

`) - - chk := &A11yCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "lang") { - found = true - } - } - if !found { - t.Error("expected finding for missing lang attribute") - } -} - -func TestA11yCheck_EmptyLinkText(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, - `Test
- -
`) - - chk := &A11yCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "no accessible text") { - found = true - } - } - if !found { - t.Error("expected finding for link with no accessible text") - } -} - -func TestA11yCheck_SkipsErrorPages(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, - ``) - page.Error = context.Canceled - - chk := &A11yCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - // Advanced A11y also skips error pages, so total should be 0 - if len(findings) != 0 { - t.Error("should skip error pages") - } -} - -func TestA11yCheck_EmptyBody(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "") - - chk := &A11yCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - if len(findings) != 0 { - t.Error("should not produce findings for empty body") - } -} - -// ===================================================================== -// PerfCheck additional tests -// ===================================================================== - -func TestPerfCheck_NoPages(t *testing.T) { - chk := &PerfCheck{} - findings := chk.Run(context.Background(), nil) - if len(findings) != 0 { - t.Errorf("expected 0 findings for nil pages, got %d", len(findings)) - } -} - -func TestPerfCheck_EmptyBody(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "") - chk := &PerfCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - if len(findings) != 0 { - t.Error("should not produce findings for empty body") - } -} - -func TestPerfCheck_LargeImages(t *testing.T) { - // Image without dimensions should trigger a finding - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html", "Content-Encoding": "gzip", "Cache-Control": "max-age=3600"}, - `T`) - - chk := &PerfCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "missing width/height") { - found = true - } - } - if !found { - t.Error("expected finding for image missing dimensions") - } -} - -// ===================================================================== -// SEOCheck additional tests -// ===================================================================== - -func TestSEOCheck_MissingTitle(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, - ``) - - chk := &SEOCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "missing ") || strings.Contains(f.Message, "missing <title> tag") || strings.Contains(f.Message, "Page missing <title>") { - found = true - } - } - if !found { - t.Error("expected finding for missing title tag") - } -} - -func TestSEOCheck_MissingDescription(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, - `<html><head><title>Page`) - - chk := &SEOCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - found := false - for _, f := range findings { - if strings.Contains(f.Message, "Missing meta description") { - found = true - } - } - if !found { - t.Error("expected finding for missing meta description") - } -} - -func TestSEOCheck_NoPages(t *testing.T) { - chk := &SEOCheck{} - findings := chk.Run(context.Background(), nil) - if len(findings) != 0 { - t.Errorf("expected 0 findings for nil pages, got %d", len(findings)) - } -} - -// ===================================================================== -// FormsCheck additional tests -// ===================================================================== - -func TestFormsCheck_NoPages(t *testing.T) { - chk := &FormsCheck{} - findings := chk.Run(context.Background(), nil) - if len(findings) != 0 { - t.Errorf("expected 0 findings for nil pages, got %d", len(findings)) - } -} - -func TestFormsCheck_NoForms(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, - `

No forms here

`) - - chk := &FormsCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - if len(findings) != 0 { - t.Errorf("expected 0 findings for page with no forms, got %d", len(findings)) - } -} - -func TestFormsCheck_ValidForm(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "") - page.Forms = []crawler.Form{ - {Action: "/submit", Method: "POST", HasCSRF: true}, - } - - chk := &FormsCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - for _, f := range findings { - if strings.Contains(f.Message, "CSRF") || strings.Contains(f.Message, "no action") { - t.Errorf("unexpected finding for valid form: %s", f.Message) - } - } -} - -// ===================================================================== -// SRICheck additional tests -// ===================================================================== - -func TestSRICheck_NoPages(t *testing.T) { - chk := &SRICheck{} - findings := chk.Run(context.Background(), nil) - if len(findings) != 0 { - t.Errorf("expected 0 findings for nil pages, got %d", len(findings)) - } -} - -// ===================================================================== -// AIReadyCheck additional tests -// ===================================================================== - -func TestAIReadyCheck_NoPages(t *testing.T) { - chk := &AIReadyCheck{} - findings := chk.Run(context.Background(), nil) - // Should return llms.txt and sitemap findings based on empty page list - if len(findings) == 0 { - t.Log("AIReadyCheck with nil pages returns findings about missing llms.txt and sitemap") - } -} - -func TestAIReadyCheck_ErrorPage(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "") - page.Error = context.Canceled - - chk := &AIReadyCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - // Error pages are skipped for per-page checks, but llms.txt/sitemap findings may still appear - for _, f := range findings { - if strings.Contains(f.Message, "markdown alternate") || strings.Contains(f.Message, "structured data") { - t.Errorf("should skip error page for per-page checks: %s", f.Message) - } - } -} - -// ===================================================================== -// ReachabilityCheck tests -// ===================================================================== - -func TestReachabilityCheck_Name(t *testing.T) { - chk := &ReachabilityCheck{} - if chk.Name() != "reachability" { - t.Errorf("expected 'reachability', got %q", chk.Name()) - } -} - -func TestReachabilityCheck_EmptyBody(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "") - - chk := &ReachabilityCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - if len(findings) != 0 { - t.Error("should not produce findings for empty body") - } -} - -func TestReachabilityCheck_ErrorPage(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, - ``) - page.Error = context.Canceled - - chk := &ReachabilityCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - if len(findings) != 0 { - t.Error("should skip error pages") - } -} - -func TestReachabilityCheck_DataURI(t *testing.T) { - // data: URIs should be skipped (not checked for reachability) - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, - ``) - - chk := &ReachabilityCheck{} - findings := chk.Run(context.Background(), []*crawler.Page{page}) - - for _, f := range findings { - if strings.Contains(f.Message, "data:") { - t.Error("should not check reachability of data: URIs") - } - } -} - -func TestReachabilityCheck_NoPages(t *testing.T) { - chk := &ReachabilityCheck{} - findings := chk.Run(context.Background(), nil) - if len(findings) != 0 { - t.Errorf("expected 0 findings for nil pages, got %d", len(findings)) - } -} - -func TestExtractResourceRefs(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, - ` - - - - - - - `) - - refs := extractResourceRefs(page) - - typeCounts := map[string]int{} - for _, ref := range refs { - typeCounts[ref.Resource]++ - } - - if typeCounts["script"] != 1 { - t.Errorf("expected 1 script ref, got %d", typeCounts["script"]) - } - if typeCounts["stylesheet"] != 1 { - t.Errorf("expected 1 stylesheet ref, got %d", typeCounts["stylesheet"]) - } - if typeCounts["image"] != 1 { - t.Errorf("expected 1 image ref, got %d", typeCounts["image"]) - } - if typeCounts["media"] != 1 { - t.Errorf("expected 1 media ref (video), got %d", typeCounts["media"]) - } - if typeCounts["iframe"] != 1 { - t.Errorf("expected 1 iframe ref, got %d", typeCounts["iframe"]) - } -} - -func TestExtractResourceRefs_EmptyBody(t *testing.T) { - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "") - refs := extractResourceRefs(page) - if len(refs) != 0 { - t.Errorf("expected 0 refs for empty body, got %d", len(refs)) - } -} - -func TestResolveURL(t *testing.T) { - tests := []struct { - base, href, want string - }{ - {"https://example.com/page", "/img/photo.jpg", "https://example.com/img/photo.jpg"}, - {"https://example.com/page", "https://cdn.example.com/lib.js", "https://cdn.example.com/lib.js"}, - {"https://example.com/page", "", ""}, - {"https://example.com/a/b", "../img.jpg", "https://example.com/img.jpg"}, - } - for _, tt := range tests { - got := resolveURL(tt.base, tt.href) - if got != tt.want { - t.Errorf("resolveURL(%q, %q) = %q, want %q", tt.base, tt.href, got, tt.want) - } - } -} - -func TestSeverityForResourceStatus(t *testing.T) { - if severityForResourceStatus(404) != SeverityHigh { - t.Error("404 should be high") - } - if severityForResourceStatus(500) != SeverityCritical { - t.Error("500 should be critical") - } - if severityForResourceStatus(403) != SeverityMedium { - t.Error("403 should be medium") - } - if severityForResourceStatus(301) != SeverityLow { - t.Error("301 should be low") - } -} - -// ===================================================================== -// Edge cases: empty pages, no HTML body, error pages -// ===================================================================== - -func TestAllChecks_EmptyPagesSlice(t *testing.T) { - ctx := context.Background() - checks := []Checker{ - &LinksCheck{}, - &SecurityCheck{}, - &FormsCheck{}, - &A11yCheck{}, - &PerfCheck{}, - &SEOCheck{}, - &SRICheck{}, - &AIReadyCheck{}, - &ReachabilityCheck{}, - } - for _, chk := range checks { - // Should not panic with nil pages - findings := chk.Run(ctx, nil) - if findings == nil { - findings = []Finding{} - } - t.Logf("%s: %d findings on nil pages", chk.Name(), len(findings)) - } -} - -func TestAllChecks_ErrorPages(t *testing.T) { - ctx := context.Background() - page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, - `
`) - page.Error = context.Canceled - - checks := []Checker{ - &LinksCheck{}, - &SecurityCheck{}, - &FormsCheck{}, - &A11yCheck{}, - &PerfCheck{}, - &SEOCheck{}, - &SRICheck{}, - &AIReadyCheck{}, - &ReachabilityCheck{}, - } - for _, chk := range checks { - findings := chk.Run(ctx, []*crawler.Page{page}) - // Most checks should skip error pages (links may still report status) - t.Logf("%s: %d findings on error page", chk.Name(), len(findings)) - } -} - -func TestAllChecks_NonHTMLBody(t *testing.T) { - ctx := context.Background() - page := makePage("https://example.com/data.json", 200, - map[string]string{"Content-Type": "application/json"}, - `{"key": "value"}`) - - checks := []Checker{ - &SecurityCheck{}, - &FormsCheck{}, - &PerfCheck{}, - &SEOCheck{}, - &SRICheck{}, - } - for _, chk := range checks { - findings := chk.Run(ctx, []*crawler.Page{page}) - t.Logf("%s: %d findings on JSON page", chk.Name(), len(findings)) - } -} - -func TestAllChecks_MultiplePages(t *testing.T) { - ctx := context.Background() - pages := []*crawler.Page{ - makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, - `Home

Home

`), - makePage("https://example.com/about", 200, map[string]string{"Content-Type": "text/html"}, - `About

About

`), - makePage("https://example.com/contact", 200, map[string]string{"Content-Type": "text/html"}, - `Contact

Contact

`), - } - - checks := []Checker{ - &A11yCheck{}, - &PerfCheck{}, - &SEOCheck{}, - &SRICheck{}, - } - for _, chk := range checks { - findings := chk.Run(ctx, pages) - t.Logf("%s: %d findings across %d pages", chk.Name(), len(findings), len(pages)) - } -} - -// ===================================================================== -// Helper function tests -// ===================================================================== - -func TestContainsVersion(t *testing.T) { - if !containsVersion("Apache/2.4.51") { - t.Error("should detect version in Apache/2.4.51") - } - if !containsVersion("nginx/1.21.0") { - t.Error("should detect version in nginx/1.21.0") - } - if containsVersion("Apache") { - t.Error("should not detect version in plain Apache") - } - if containsVersion("") { - t.Error("should not detect version in empty string") - } -} - -func TestTruncate(t *testing.T) { - if truncate("short", 10) != "short" { - t.Error("short string should not be truncated") - } - if truncate("a long string here", 6) != "a long..." { - t.Errorf("expected truncation, got %q", truncate("a long string here", 6)) - } - if truncate("", 5) != "" { - t.Error("empty string should remain empty") - } -} - -func TestTruncateResRef(t *testing.T) { - if truncateResRef("/short.js", 80) != "/short.js" { - t.Error("short ref should not be truncated") - } - long := make([]byte, 100) - for i := range long { - long[i] = 'a' - } - got := truncateResRef(string(long), 80) - if len(got) != 83 { // 80 + "..." - t.Errorf("expected truncation to 83 chars, got %d", len(got)) - } -} - -func TestNormalizeForLookup(t *testing.T) { - tests := []struct { - input, want string - }{ - {"https://example.com/path/", "https://example.com/path"}, - {"https://example.com/path#frag", "https://example.com/path"}, - {"https://example.com", "https://example.com"}, - } - for _, tt := range tests { - got := normalizeForLookup(tt.input) - if got != tt.want { - t.Errorf("normalizeForLookup(%q) = %q, want %q", tt.input, got, tt.want) - } - } -} - -func TestIsSessionCookieName(t *testing.T) { - tests := []struct { - name string - want bool - }{ - {"session_id", true}, - {"JSESSIONID", true}, - {"PHPSESSID", true}, - {"connect.sid", true}, - {"auth_token", true}, - {"theme", false}, - {"lang", false}, - {"preferences", false}, - } - for _, tt := range tests { - got := isSessionCookieName(tt.name) - if got != tt.want { - t.Errorf("isSessionCookieName(%q) = %v, want %v", tt.name, got, tt.want) - } - } -} +// Note: additional tests for the check package live in check_more_test.go +// and check_extra_test.go (split out for file size/clarity). From 262c182ae52fb6b231855c5a163cf159cd43c3fd Mon Sep 17 00:00:00 2001 From: Lakshman Patel Date: Fri, 19 Jun 2026 21:00:34 +0530 Subject: [PATCH 2/4] test(inspect): split crawler_test.go into focused test files Mechanical refactor for code clarity. crawler_test.go (1549 LOC) split into two same-package files at a top-level function boundary with no behavior, API, or exported-symbol changes: - crawler_test.go (774 LOC): crawl, rate limiter, robots, resolve tests - crawler_more_test.go (790 LOC): auth, redirect, SSRF, page, ctor tests --- internal/crawler/crawler_more_test.go | 790 ++++++++++++++++++++++++++ internal/crawler/crawler_test.go | 779 +------------------------ 2 files changed, 792 insertions(+), 777 deletions(-) create mode 100644 internal/crawler/crawler_more_test.go diff --git a/internal/crawler/crawler_more_test.go b/internal/crawler/crawler_more_test.go new file mode 100644 index 0000000..bd261be --- /dev/null +++ b/internal/crawler/crawler_more_test.go @@ -0,0 +1,790 @@ +package crawler + +import ( + "context" + "fmt" + "net" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +// This file was split out of crawler_test.go for readability (mechanical move; no behavior change). + +// --- Auth required tests --- + +func TestCrawl_AuthRequired_401(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusUnauthorized) + fmt.Fprint(w, `Unauthorized`) + }) + mux.HandleFunc("/robots.txt", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + }) + mux.HandleFunc("/sitemap.xml", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + c := New(Config{ + MaxDepth: 1, + Concurrency: 1, + Timeout: 10 * time.Second, + PageTimeout: 5 * time.Second, + RateLimit: 100, + UserAgent: "test", + AllowPrivateIPs: true, + }) + + pages, err := c.Crawl(context.Background(), srv.URL) + if err != nil { + t.Fatalf("Crawl failed: %v", err) + } + + if len(pages) == 0 { + t.Fatal("expected at least 1 page") + } + if pages[0].StatusCode != 401 { + t.Errorf("expected status 401, got %d", pages[0].StatusCode) + } + if !pages[0].AuthRequired { + t.Error("expected AuthRequired=true for 401 response") + } + if pages[0].Error != nil { + t.Errorf("expected no error for auth-required page, got %v", pages[0].Error) + } +} + +func TestCrawl_AuthRequired_403(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusForbidden) + fmt.Fprint(w, `Forbidden`) + }) + mux.HandleFunc("/robots.txt", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + }) + mux.HandleFunc("/sitemap.xml", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + c := New(Config{ + MaxDepth: 1, + Concurrency: 1, + Timeout: 10 * time.Second, + PageTimeout: 5 * time.Second, + RateLimit: 100, + UserAgent: "test", + AllowPrivateIPs: true, + }) + + pages, err := c.Crawl(context.Background(), srv.URL) + if err != nil { + t.Fatalf("Crawl failed: %v", err) + } + + if len(pages) == 0 { + t.Fatal("expected at least 1 page") + } + if pages[0].StatusCode != 403 { + t.Errorf("expected status 403, got %d", pages[0].StatusCode) + } + if !pages[0].AuthRequired { + t.Error("expected AuthRequired=true for 403 response") + } +} + +// --- Non-HTML content type --- + +func TestCrawl_NonHTMLContentType(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"key":"value"}`) + }) + mux.HandleFunc("/robots.txt", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + }) + mux.HandleFunc("/sitemap.xml", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + c := New(Config{ + MaxDepth: 1, + Concurrency: 1, + Timeout: 10 * time.Second, + PageTimeout: 5 * time.Second, + RateLimit: 100, + UserAgent: "test", + AllowPrivateIPs: true, + }) + + pages, err := c.Crawl(context.Background(), srv.URL) + if err != nil { + t.Fatalf("Crawl failed: %v", err) + } + + if len(pages) == 0 { + t.Fatal("expected at least 1 page") + } + // Body should be empty for non-HTML content + if len(pages[0].Body) != 0 { + t.Errorf("expected empty body for non-HTML, got %d bytes", len(pages[0].Body)) + } + if len(pages[0].Links) != 0 { + t.Errorf("expected no links for non-HTML, got %d", len(pages[0].Links)) + } +} + +// --- Exclude patterns --- + +func TestCrawl_ExcludePatterns(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, ` + public + admin + `) + }) + mux.HandleFunc("/public", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, `public`) + }) + mux.HandleFunc("/admin/secret", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, `secret`) + }) + mux.HandleFunc("/robots.txt", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + }) + mux.HandleFunc("/sitemap.xml", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + c := New(Config{ + MaxDepth: 2, + Concurrency: 1, + Timeout: 10 * time.Second, + PageTimeout: 5 * time.Second, + RateLimit: 100, + UserAgent: "test", + Exclude: []string{"/admin"}, + AllowPrivateIPs: true, + }) + + pages, err := c.Crawl(context.Background(), srv.URL) + if err != nil { + t.Fatalf("Crawl failed: %v", err) + } + + for _, p := range pages { + if strings.Contains(p.URL, "/admin") { + t.Errorf("should not have crawled excluded URL: %s", p.URL) + } + } +} + +// --- SSRF protection tests --- + +func TestIsPrivateIP(t *testing.T) { + tests := []struct { + ip string + expected bool + }{ + {"10.0.0.1", true}, + {"10.255.255.255", true}, + {"172.16.0.1", true}, + {"172.31.255.255", true}, + {"192.168.1.1", true}, + {"127.0.0.1", true}, + {"::1", true}, + {"8.8.8.8", false}, + {"1.1.1.1", false}, + {"93.184.216.34", false}, // example.com + } + + for _, tt := range tests { + ip := net.ParseIP(tt.ip) + if ip == nil { + t.Fatalf("failed to parse IP: %s", tt.ip) + } + got := isPrivateIP(ip) + if got != tt.expected { + t.Errorf("isPrivateIP(%s) = %v, want %v", tt.ip, got, tt.expected) + } + } +} + +func TestValidateURL_SSRFProtection(t *testing.T) { + c := New(Config{ + AllowPrivateIPs: false, + UserAgent: "test", + }) + + // Non-http/https scheme should be rejected + err := c.validateURL("ftp://example.com/file") + if err == nil { + t.Error("expected error for ftp scheme") + } + + err = c.validateURL("file:///etc/passwd") + if err == nil { + t.Error("expected error for file scheme") + } + + // AllowPrivateIPs=true should skip IP checks + c2 := New(Config{ + AllowPrivateIPs: true, + UserAgent: "test", + }) + if err := c2.validateURL("http://127.0.0.1/admin"); err != nil { + t.Errorf("expected no error with AllowPrivateIPs=true, got %v", err) + } +} + +// --- isRetryable tests --- + +func TestIsRetryable(t *testing.T) { + tests := []struct { + status int + expected bool + }{ + {200, false}, + {301, false}, + {400, false}, + {401, false}, + {403, false}, + {404, false}, + {429, true}, + {500, true}, + {502, true}, + {503, true}, + {504, true}, + {0, true}, + } + for _, tt := range tests { + got := isRetryable(tt.status) + if got != tt.expected { + t.Errorf("isRetryable(%d) = %v, want %v", tt.status, got, tt.expected) + } + } +} + +// --- tryMarkSeen deduplication --- + +func TestTryMarkSeen_Dedup(t *testing.T) { + c := New(Config{UserAgent: "test"}) + + if !c.tryMarkSeen("https://example.com/page") { + t.Error("first call should return true") + } + if c.tryMarkSeen("https://example.com/page") { + t.Error("second call should return false (duplicate)") + } + // Same URL with different fragment should still be deduplicated + if c.tryMarkSeen("https://example.com/page#section") { + t.Error("same URL with fragment should be deduplicated") + } +} + +// --- isExcluded tests --- + +func TestIsExcluded(t *testing.T) { + c := New(Config{ + Exclude: []string{"/admin", ".pdf", "logout"}, + }) + + tests := []struct { + url string + expected bool + }{ + {"https://example.com/page", false}, + {"https://example.com/admin/users", true}, + {"https://example.com/doc.pdf", true}, + {"https://example.com/logout", true}, + {"https://example.com/public", false}, + } + for _, tt := range tests { + got := c.isExcluded(tt.url) + if got != tt.expected { + t.Errorf("isExcluded(%q) = %v, want %v", tt.url, got, tt.expected) + } + } +} + +// --- Crawl with robots.txt compliance --- + +func TestCrawl_RespectRobots(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, ` + ok + blocked + `) + }) + mux.HandleFunc("/allowed", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, `allowed`) + }) + mux.HandleFunc("/blocked", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, `blocked`) + }) + mux.HandleFunc("/robots.txt", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + fmt.Fprint(w, "User-agent: *\nDisallow: /blocked\n") + }) + mux.HandleFunc("/sitemap.xml", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + c := New(Config{ + MaxDepth: 2, + Concurrency: 1, + Timeout: 10 * time.Second, + PageTimeout: 5 * time.Second, + RateLimit: 100, + UserAgent: "test", + RespectRobots: true, + AllowPrivateIPs: true, + }) + + pages, err := c.Crawl(context.Background(), srv.URL) + if err != nil { + t.Fatalf("Crawl failed: %v", err) + } + + for _, p := range pages { + if strings.Contains(p.URL, "/blocked") { + t.Errorf("should not have crawled robots-disallowed URL: %s", p.URL) + } + } +} + +// --- Crawl with auth header --- + +func TestCrawl_AuthHeader(t *testing.T) { + var gotHeader string + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + gotHeader = r.Header.Get("Authorization") + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, `ok`) + }) + mux.HandleFunc("/robots.txt", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + }) + mux.HandleFunc("/sitemap.xml", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + c := New(Config{ + MaxDepth: 1, + Concurrency: 1, + Timeout: 10 * time.Second, + PageTimeout: 5 * time.Second, + RateLimit: 100, + UserAgent: "test", + AuthHeader: "Authorization", + AuthValue: "Bearer secret-token", + AllowPrivateIPs: true, + }) + + pages, err := c.Crawl(context.Background(), srv.URL) + if err != nil { + t.Fatalf("Crawl failed: %v", err) + } + + if len(pages) == 0 { + t.Fatal("expected at least 1 page") + } + if gotHeader != "Bearer secret-token" { + t.Errorf("expected auth header 'Bearer secret-token', got %q", gotHeader) + } +} + +// --- Crawl with server error --- + +func TestCrawl_ServerError(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + }) + mux.HandleFunc("/robots.txt", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + }) + mux.HandleFunc("/sitemap.xml", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + c := New(Config{ + MaxDepth: 1, + Concurrency: 1, + Timeout: 10 * time.Second, + PageTimeout: 5 * time.Second, + RateLimit: 100, + RetryAttempts: 0, // no retries for speed + UserAgent: "test", + AllowPrivateIPs: true, + }) + + pages, err := c.Crawl(context.Background(), srv.URL) + if err != nil { + t.Fatalf("Crawl failed: %v", err) + } + + if len(pages) == 0 { + t.Fatal("expected at least 1 page") + } + if pages[0].StatusCode != 500 { + t.Errorf("expected status 500, got %d", pages[0].StatusCode) + } +} + +// --- Page struct field initialization --- + +func TestPage_FieldsInitialized(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.Header().Set("X-Custom", "value") + fmt.Fprint(w, `link`) + }) + mux.HandleFunc("/child", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, `child`) + }) + mux.HandleFunc("/robots.txt", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + }) + mux.HandleFunc("/sitemap.xml", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + c := New(Config{ + MaxDepth: 2, + Concurrency: 1, + Timeout: 10 * time.Second, + PageTimeout: 5 * time.Second, + RateLimit: 100, + UserAgent: "test", + AllowPrivateIPs: true, + }) + + pages, err := c.Crawl(context.Background(), srv.URL) + if err != nil { + t.Fatalf("Crawl failed: %v", err) + } + + // Find the root page + var root *Page + for _, p := range pages { + if p.URL == srv.URL+"/" || p.URL == srv.URL { + root = p + break + } + } + if root == nil { + t.Fatal("could not find root page") + } + + if root.StatusCode != 200 { + t.Errorf("expected StatusCode 200, got %d", root.StatusCode) + } + if root.Depth != 0 { + t.Errorf("expected Depth 0, got %d", root.Depth) + } + if root.ParentURL != "" { + t.Errorf("expected empty ParentURL, got %q", root.ParentURL) + } + if root.Duration <= 0 { + t.Error("expected positive Duration") + } + if root.Error != nil { + t.Errorf("expected no error, got %v", root.Error) + } + if root.Body == nil { + t.Error("expected non-nil Body") + } + if root.Headers == nil { + t.Error("expected non-nil Headers") + } + if root.Headers.Get("X-Custom") != "value" { + t.Error("expected X-Custom header") + } + if root.AuthRequired { + t.Error("expected AuthRequired=false for 200 response") + } +} + +func TestPage_ChildHasParentURL(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, `child`) + }) + mux.HandleFunc("/child", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, `child`) + }) + mux.HandleFunc("/robots.txt", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + }) + mux.HandleFunc("/sitemap.xml", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + c := New(Config{ + MaxDepth: 2, + Concurrency: 1, + Timeout: 10 * time.Second, + PageTimeout: 5 * time.Second, + RateLimit: 100, + UserAgent: "test", + AllowPrivateIPs: true, + }) + + pages, err := c.Crawl(context.Background(), srv.URL) + if err != nil { + t.Fatalf("Crawl failed: %v", err) + } + + var child *Page + for _, p := range pages { + if strings.HasSuffix(p.URL, "/child") { + child = p + break + } + } + if child == nil { + t.Fatal("could not find /child page") + } + + if child.Depth != 1 { + t.Errorf("expected child Depth 1, got %d", child.Depth) + } + if child.ParentURL == "" { + t.Error("expected non-empty ParentURL for child page") + } +} + +// --- New creates sensible defaults --- + +func TestNew_Defaults(t *testing.T) { + c := New(Config{UserAgent: "test"}) + if c.cfg.PageTimeout != 15*time.Second { + t.Errorf("expected default PageTimeout 15s, got %v", c.cfg.PageTimeout) + } + if c.cfg.RetryAttempts != 2 { + t.Errorf("expected default RetryAttempts 2, got %d", c.cfg.RetryAttempts) + } + if c.cfg.RetryDelay != 500*time.Millisecond { + t.Errorf("expected default RetryDelay 500ms, got %v", c.cfg.RetryDelay) + } + if c.client == nil { + t.Error("expected non-nil client") + } + if c.seen == nil { + t.Error("expected non-nil seen map") + } + if c.robots == nil { + t.Error("expected non-nil robots cache") + } + if c.limiter == nil { + t.Error("expected non-nil limiter") + } +} + +func TestNew_CustomValues(t *testing.T) { + c := New(Config{ + PageTimeout: 30 * time.Second, + RetryAttempts: 5, + RetryDelay: 1 * time.Second, + UserAgent: "custom-bot", + }) + if c.cfg.PageTimeout != 30*time.Second { + t.Errorf("expected PageTimeout 30s, got %v", c.cfg.PageTimeout) + } + if c.cfg.RetryAttempts != 5 { + t.Errorf("expected RetryAttempts 5, got %d", c.cfg.RetryAttempts) + } + if c.cfg.RetryDelay != 1*time.Second { + t.Errorf("expected RetryDelay 1s, got %v", c.cfg.RetryDelay) + } +} + +// --- Invalid URL for Crawl --- + +func TestCrawl_InvalidURL(t *testing.T) { + c := New(Config{UserAgent: "test"}) + _, err := c.Crawl(context.Background(), "://bad") + if err == nil { + t.Error("expected error for invalid URL") + } +} + +func TestCrawl_NoHost(t *testing.T) { + c := New(Config{UserAgent: "test"}) + _, err := c.Crawl(context.Background(), "http:///path") + if err == nil { + t.Error("expected error for URL with no host") + } +} + +// --- Crawl with zero MaxDepth (unlimited) --- + +func TestCrawl_ZeroMaxDepthMeansUnlimited(t *testing.T) { + // With MaxDepth=0, the crawler should follow links without depth limit + // (until it runs out of links or hits other limits) + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, `p2`) + }) + mux.HandleFunc("/p2", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, `p3`) + }) + mux.HandleFunc("/p3", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, `leaf`) + }) + mux.HandleFunc("/robots.txt", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + }) + mux.HandleFunc("/sitemap.xml", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + c := New(Config{ + MaxDepth: 0, // unlimited + Concurrency: 1, + Timeout: 10 * time.Second, + PageTimeout: 5 * time.Second, + RateLimit: 100, + UserAgent: "test", + AllowPrivateIPs: true, + }) + + pages, err := c.Crawl(context.Background(), srv.URL) + if err != nil { + t.Fatalf("Crawl failed: %v", err) + } + + if len(pages) < 3 { + t.Errorf("expected at least 3 pages with unlimited depth, got %d", len(pages)) + } +} + +// --- Crawl with link extraction from HTML --- + +func TestCrawl_ExtractsMailtoAndJavaScriptLinks(t *testing.T) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, ` + email + js + anchor + real + `) + }) + mux.HandleFunc("/real", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, `real page`) + }) + mux.HandleFunc("/robots.txt", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + }) + mux.HandleFunc("/sitemap.xml", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(404) + }) + + srv := httptest.NewServer(mux) + defer srv.Close() + + c := New(Config{ + MaxDepth: 2, + Concurrency: 1, + Timeout: 10 * time.Second, + PageTimeout: 5 * time.Second, + RateLimit: 100, + UserAgent: "test", + AllowPrivateIPs: true, + }) + + pages, err := c.Crawl(context.Background(), srv.URL) + if err != nil { + t.Fatalf("Crawl failed: %v", err) + } + + // Should have extracted mailto and javascript links on the page + var root *Page + for _, p := range pages { + if strings.HasSuffix(p.URL, "/") || p.URL == srv.URL || p.URL == srv.URL+"/" { + root = p + break + } + } + if root == nil { + t.Fatal("could not find root page") + } + + // Should not have crawled mailto: or javascript: URLs + for _, p := range pages { + if strings.HasPrefix(p.URL, "mailto:") || strings.HasPrefix(p.URL, "javascript:") { + t.Errorf("should not crawl mailto/javascript URL: %s", p.URL) + } + } + + // Root page links should include the mailto, javascript, and anchor links + linkHrefs := map[string]bool{} + for _, l := range root.Links { + linkHrefs[l.Href] = true + } + if !linkHrefs["mailto:user@example.com"] { + t.Error("expected mailto link in page links") + } + if !linkHrefs["javascript:alert(1)"] { + t.Error("expected javascript link in page links") + } + if !linkHrefs["#fragment"] { + t.Error("expected anchor link in page links") + } +} diff --git a/internal/crawler/crawler_test.go b/internal/crawler/crawler_test.go index 4e7ca7e..95dbde6 100644 --- a/internal/crawler/crawler_test.go +++ b/internal/crawler/crawler_test.go @@ -3,10 +3,8 @@ package crawler import ( "context" "fmt" - "net" "net/http" "net/http/httptest" - "strings" "testing" "time" ) @@ -772,778 +770,5 @@ func TestCrawl_RedirectLoop(t *testing.T) { } } -// --- Auth required tests --- - -func TestCrawl_AuthRequired_401(t *testing.T) { - mux := http.NewServeMux() - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html") - w.WriteHeader(http.StatusUnauthorized) - fmt.Fprint(w, `Unauthorized`) - }) - mux.HandleFunc("/robots.txt", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(404) - }) - mux.HandleFunc("/sitemap.xml", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(404) - }) - - srv := httptest.NewServer(mux) - defer srv.Close() - - c := New(Config{ - MaxDepth: 1, - Concurrency: 1, - Timeout: 10 * time.Second, - PageTimeout: 5 * time.Second, - RateLimit: 100, - UserAgent: "test", - AllowPrivateIPs: true, - }) - - pages, err := c.Crawl(context.Background(), srv.URL) - if err != nil { - t.Fatalf("Crawl failed: %v", err) - } - - if len(pages) == 0 { - t.Fatal("expected at least 1 page") - } - if pages[0].StatusCode != 401 { - t.Errorf("expected status 401, got %d", pages[0].StatusCode) - } - if !pages[0].AuthRequired { - t.Error("expected AuthRequired=true for 401 response") - } - if pages[0].Error != nil { - t.Errorf("expected no error for auth-required page, got %v", pages[0].Error) - } -} - -func TestCrawl_AuthRequired_403(t *testing.T) { - mux := http.NewServeMux() - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html") - w.WriteHeader(http.StatusForbidden) - fmt.Fprint(w, `Forbidden`) - }) - mux.HandleFunc("/robots.txt", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(404) - }) - mux.HandleFunc("/sitemap.xml", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(404) - }) - - srv := httptest.NewServer(mux) - defer srv.Close() - - c := New(Config{ - MaxDepth: 1, - Concurrency: 1, - Timeout: 10 * time.Second, - PageTimeout: 5 * time.Second, - RateLimit: 100, - UserAgent: "test", - AllowPrivateIPs: true, - }) - - pages, err := c.Crawl(context.Background(), srv.URL) - if err != nil { - t.Fatalf("Crawl failed: %v", err) - } - - if len(pages) == 0 { - t.Fatal("expected at least 1 page") - } - if pages[0].StatusCode != 403 { - t.Errorf("expected status 403, got %d", pages[0].StatusCode) - } - if !pages[0].AuthRequired { - t.Error("expected AuthRequired=true for 403 response") - } -} - -// --- Non-HTML content type --- - -func TestCrawl_NonHTMLContentType(t *testing.T) { - mux := http.NewServeMux() - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - fmt.Fprint(w, `{"key":"value"}`) - }) - mux.HandleFunc("/robots.txt", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(404) - }) - mux.HandleFunc("/sitemap.xml", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(404) - }) - - srv := httptest.NewServer(mux) - defer srv.Close() - - c := New(Config{ - MaxDepth: 1, - Concurrency: 1, - Timeout: 10 * time.Second, - PageTimeout: 5 * time.Second, - RateLimit: 100, - UserAgent: "test", - AllowPrivateIPs: true, - }) - - pages, err := c.Crawl(context.Background(), srv.URL) - if err != nil { - t.Fatalf("Crawl failed: %v", err) - } - - if len(pages) == 0 { - t.Fatal("expected at least 1 page") - } - // Body should be empty for non-HTML content - if len(pages[0].Body) != 0 { - t.Errorf("expected empty body for non-HTML, got %d bytes", len(pages[0].Body)) - } - if len(pages[0].Links) != 0 { - t.Errorf("expected no links for non-HTML, got %d", len(pages[0].Links)) - } -} - -// --- Exclude patterns --- - -func TestCrawl_ExcludePatterns(t *testing.T) { - mux := http.NewServeMux() - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html") - fmt.Fprint(w, ` - public - admin - `) - }) - mux.HandleFunc("/public", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html") - fmt.Fprint(w, `public`) - }) - mux.HandleFunc("/admin/secret", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html") - fmt.Fprint(w, `secret`) - }) - mux.HandleFunc("/robots.txt", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(404) - }) - mux.HandleFunc("/sitemap.xml", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(404) - }) - - srv := httptest.NewServer(mux) - defer srv.Close() - - c := New(Config{ - MaxDepth: 2, - Concurrency: 1, - Timeout: 10 * time.Second, - PageTimeout: 5 * time.Second, - RateLimit: 100, - UserAgent: "test", - Exclude: []string{"/admin"}, - AllowPrivateIPs: true, - }) - - pages, err := c.Crawl(context.Background(), srv.URL) - if err != nil { - t.Fatalf("Crawl failed: %v", err) - } - - for _, p := range pages { - if strings.Contains(p.URL, "/admin") { - t.Errorf("should not have crawled excluded URL: %s", p.URL) - } - } -} - -// --- SSRF protection tests --- - -func TestIsPrivateIP(t *testing.T) { - tests := []struct { - ip string - expected bool - }{ - {"10.0.0.1", true}, - {"10.255.255.255", true}, - {"172.16.0.1", true}, - {"172.31.255.255", true}, - {"192.168.1.1", true}, - {"127.0.0.1", true}, - {"::1", true}, - {"8.8.8.8", false}, - {"1.1.1.1", false}, - {"93.184.216.34", false}, // example.com - } - - for _, tt := range tests { - ip := net.ParseIP(tt.ip) - if ip == nil { - t.Fatalf("failed to parse IP: %s", tt.ip) - } - got := isPrivateIP(ip) - if got != tt.expected { - t.Errorf("isPrivateIP(%s) = %v, want %v", tt.ip, got, tt.expected) - } - } -} - -func TestValidateURL_SSRFProtection(t *testing.T) { - c := New(Config{ - AllowPrivateIPs: false, - UserAgent: "test", - }) - - // Non-http/https scheme should be rejected - err := c.validateURL("ftp://example.com/file") - if err == nil { - t.Error("expected error for ftp scheme") - } - - err = c.validateURL("file:///etc/passwd") - if err == nil { - t.Error("expected error for file scheme") - } - - // AllowPrivateIPs=true should skip IP checks - c2 := New(Config{ - AllowPrivateIPs: true, - UserAgent: "test", - }) - if err := c2.validateURL("http://127.0.0.1/admin"); err != nil { - t.Errorf("expected no error with AllowPrivateIPs=true, got %v", err) - } -} - -// --- isRetryable tests --- - -func TestIsRetryable(t *testing.T) { - tests := []struct { - status int - expected bool - }{ - {200, false}, - {301, false}, - {400, false}, - {401, false}, - {403, false}, - {404, false}, - {429, true}, - {500, true}, - {502, true}, - {503, true}, - {504, true}, - {0, true}, - } - for _, tt := range tests { - got := isRetryable(tt.status) - if got != tt.expected { - t.Errorf("isRetryable(%d) = %v, want %v", tt.status, got, tt.expected) - } - } -} - -// --- tryMarkSeen deduplication --- - -func TestTryMarkSeen_Dedup(t *testing.T) { - c := New(Config{UserAgent: "test"}) - - if !c.tryMarkSeen("https://example.com/page") { - t.Error("first call should return true") - } - if c.tryMarkSeen("https://example.com/page") { - t.Error("second call should return false (duplicate)") - } - // Same URL with different fragment should still be deduplicated - if c.tryMarkSeen("https://example.com/page#section") { - t.Error("same URL with fragment should be deduplicated") - } -} - -// --- isExcluded tests --- - -func TestIsExcluded(t *testing.T) { - c := New(Config{ - Exclude: []string{"/admin", ".pdf", "logout"}, - }) - - tests := []struct { - url string - expected bool - }{ - {"https://example.com/page", false}, - {"https://example.com/admin/users", true}, - {"https://example.com/doc.pdf", true}, - {"https://example.com/logout", true}, - {"https://example.com/public", false}, - } - for _, tt := range tests { - got := c.isExcluded(tt.url) - if got != tt.expected { - t.Errorf("isExcluded(%q) = %v, want %v", tt.url, got, tt.expected) - } - } -} - -// --- Crawl with robots.txt compliance --- - -func TestCrawl_RespectRobots(t *testing.T) { - mux := http.NewServeMux() - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html") - fmt.Fprint(w, ` - ok - blocked - `) - }) - mux.HandleFunc("/allowed", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html") - fmt.Fprint(w, `allowed`) - }) - mux.HandleFunc("/blocked", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html") - fmt.Fprint(w, `blocked`) - }) - mux.HandleFunc("/robots.txt", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/plain") - fmt.Fprint(w, "User-agent: *\nDisallow: /blocked\n") - }) - mux.HandleFunc("/sitemap.xml", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(404) - }) - - srv := httptest.NewServer(mux) - defer srv.Close() - - c := New(Config{ - MaxDepth: 2, - Concurrency: 1, - Timeout: 10 * time.Second, - PageTimeout: 5 * time.Second, - RateLimit: 100, - UserAgent: "test", - RespectRobots: true, - AllowPrivateIPs: true, - }) - - pages, err := c.Crawl(context.Background(), srv.URL) - if err != nil { - t.Fatalf("Crawl failed: %v", err) - } - - for _, p := range pages { - if strings.Contains(p.URL, "/blocked") { - t.Errorf("should not have crawled robots-disallowed URL: %s", p.URL) - } - } -} - -// --- Crawl with auth header --- - -func TestCrawl_AuthHeader(t *testing.T) { - var gotHeader string - mux := http.NewServeMux() - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - gotHeader = r.Header.Get("Authorization") - w.Header().Set("Content-Type", "text/html") - fmt.Fprint(w, `ok`) - }) - mux.HandleFunc("/robots.txt", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(404) - }) - mux.HandleFunc("/sitemap.xml", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(404) - }) - - srv := httptest.NewServer(mux) - defer srv.Close() - - c := New(Config{ - MaxDepth: 1, - Concurrency: 1, - Timeout: 10 * time.Second, - PageTimeout: 5 * time.Second, - RateLimit: 100, - UserAgent: "test", - AuthHeader: "Authorization", - AuthValue: "Bearer secret-token", - AllowPrivateIPs: true, - }) - - pages, err := c.Crawl(context.Background(), srv.URL) - if err != nil { - t.Fatalf("Crawl failed: %v", err) - } - - if len(pages) == 0 { - t.Fatal("expected at least 1 page") - } - if gotHeader != "Bearer secret-token" { - t.Errorf("expected auth header 'Bearer secret-token', got %q", gotHeader) - } -} - -// --- Crawl with server error --- - -func TestCrawl_ServerError(t *testing.T) { - mux := http.NewServeMux() - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - }) - mux.HandleFunc("/robots.txt", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(404) - }) - mux.HandleFunc("/sitemap.xml", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(404) - }) - - srv := httptest.NewServer(mux) - defer srv.Close() - - c := New(Config{ - MaxDepth: 1, - Concurrency: 1, - Timeout: 10 * time.Second, - PageTimeout: 5 * time.Second, - RateLimit: 100, - RetryAttempts: 0, // no retries for speed - UserAgent: "test", - AllowPrivateIPs: true, - }) - - pages, err := c.Crawl(context.Background(), srv.URL) - if err != nil { - t.Fatalf("Crawl failed: %v", err) - } - - if len(pages) == 0 { - t.Fatal("expected at least 1 page") - } - if pages[0].StatusCode != 500 { - t.Errorf("expected status 500, got %d", pages[0].StatusCode) - } -} - -// --- Page struct field initialization --- - -func TestPage_FieldsInitialized(t *testing.T) { - mux := http.NewServeMux() - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html") - w.Header().Set("X-Custom", "value") - fmt.Fprint(w, `link`) - }) - mux.HandleFunc("/child", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html") - fmt.Fprint(w, `child`) - }) - mux.HandleFunc("/robots.txt", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(404) - }) - mux.HandleFunc("/sitemap.xml", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(404) - }) - - srv := httptest.NewServer(mux) - defer srv.Close() - - c := New(Config{ - MaxDepth: 2, - Concurrency: 1, - Timeout: 10 * time.Second, - PageTimeout: 5 * time.Second, - RateLimit: 100, - UserAgent: "test", - AllowPrivateIPs: true, - }) - - pages, err := c.Crawl(context.Background(), srv.URL) - if err != nil { - t.Fatalf("Crawl failed: %v", err) - } - - // Find the root page - var root *Page - for _, p := range pages { - if p.URL == srv.URL+"/" || p.URL == srv.URL { - root = p - break - } - } - if root == nil { - t.Fatal("could not find root page") - } - - if root.StatusCode != 200 { - t.Errorf("expected StatusCode 200, got %d", root.StatusCode) - } - if root.Depth != 0 { - t.Errorf("expected Depth 0, got %d", root.Depth) - } - if root.ParentURL != "" { - t.Errorf("expected empty ParentURL, got %q", root.ParentURL) - } - if root.Duration <= 0 { - t.Error("expected positive Duration") - } - if root.Error != nil { - t.Errorf("expected no error, got %v", root.Error) - } - if root.Body == nil { - t.Error("expected non-nil Body") - } - if root.Headers == nil { - t.Error("expected non-nil Headers") - } - if root.Headers.Get("X-Custom") != "value" { - t.Error("expected X-Custom header") - } - if root.AuthRequired { - t.Error("expected AuthRequired=false for 200 response") - } -} - -func TestPage_ChildHasParentURL(t *testing.T) { - mux := http.NewServeMux() - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html") - fmt.Fprint(w, `child`) - }) - mux.HandleFunc("/child", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html") - fmt.Fprint(w, `child`) - }) - mux.HandleFunc("/robots.txt", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(404) - }) - mux.HandleFunc("/sitemap.xml", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(404) - }) - - srv := httptest.NewServer(mux) - defer srv.Close() - - c := New(Config{ - MaxDepth: 2, - Concurrency: 1, - Timeout: 10 * time.Second, - PageTimeout: 5 * time.Second, - RateLimit: 100, - UserAgent: "test", - AllowPrivateIPs: true, - }) - - pages, err := c.Crawl(context.Background(), srv.URL) - if err != nil { - t.Fatalf("Crawl failed: %v", err) - } - - var child *Page - for _, p := range pages { - if strings.HasSuffix(p.URL, "/child") { - child = p - break - } - } - if child == nil { - t.Fatal("could not find /child page") - } - - if child.Depth != 1 { - t.Errorf("expected child Depth 1, got %d", child.Depth) - } - if child.ParentURL == "" { - t.Error("expected non-empty ParentURL for child page") - } -} - -// --- New creates sensible defaults --- - -func TestNew_Defaults(t *testing.T) { - c := New(Config{UserAgent: "test"}) - if c.cfg.PageTimeout != 15*time.Second { - t.Errorf("expected default PageTimeout 15s, got %v", c.cfg.PageTimeout) - } - if c.cfg.RetryAttempts != 2 { - t.Errorf("expected default RetryAttempts 2, got %d", c.cfg.RetryAttempts) - } - if c.cfg.RetryDelay != 500*time.Millisecond { - t.Errorf("expected default RetryDelay 500ms, got %v", c.cfg.RetryDelay) - } - if c.client == nil { - t.Error("expected non-nil client") - } - if c.seen == nil { - t.Error("expected non-nil seen map") - } - if c.robots == nil { - t.Error("expected non-nil robots cache") - } - if c.limiter == nil { - t.Error("expected non-nil limiter") - } -} - -func TestNew_CustomValues(t *testing.T) { - c := New(Config{ - PageTimeout: 30 * time.Second, - RetryAttempts: 5, - RetryDelay: 1 * time.Second, - UserAgent: "custom-bot", - }) - if c.cfg.PageTimeout != 30*time.Second { - t.Errorf("expected PageTimeout 30s, got %v", c.cfg.PageTimeout) - } - if c.cfg.RetryAttempts != 5 { - t.Errorf("expected RetryAttempts 5, got %d", c.cfg.RetryAttempts) - } - if c.cfg.RetryDelay != 1*time.Second { - t.Errorf("expected RetryDelay 1s, got %v", c.cfg.RetryDelay) - } -} - -// --- Invalid URL for Crawl --- - -func TestCrawl_InvalidURL(t *testing.T) { - c := New(Config{UserAgent: "test"}) - _, err := c.Crawl(context.Background(), "://bad") - if err == nil { - t.Error("expected error for invalid URL") - } -} - -func TestCrawl_NoHost(t *testing.T) { - c := New(Config{UserAgent: "test"}) - _, err := c.Crawl(context.Background(), "http:///path") - if err == nil { - t.Error("expected error for URL with no host") - } -} - -// --- Crawl with zero MaxDepth (unlimited) --- - -func TestCrawl_ZeroMaxDepthMeansUnlimited(t *testing.T) { - // With MaxDepth=0, the crawler should follow links without depth limit - // (until it runs out of links or hits other limits) - mux := http.NewServeMux() - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html") - fmt.Fprint(w, `p2`) - }) - mux.HandleFunc("/p2", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html") - fmt.Fprint(w, `p3`) - }) - mux.HandleFunc("/p3", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html") - fmt.Fprint(w, `leaf`) - }) - mux.HandleFunc("/robots.txt", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(404) - }) - mux.HandleFunc("/sitemap.xml", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(404) - }) - - srv := httptest.NewServer(mux) - defer srv.Close() - - c := New(Config{ - MaxDepth: 0, // unlimited - Concurrency: 1, - Timeout: 10 * time.Second, - PageTimeout: 5 * time.Second, - RateLimit: 100, - UserAgent: "test", - AllowPrivateIPs: true, - }) - - pages, err := c.Crawl(context.Background(), srv.URL) - if err != nil { - t.Fatalf("Crawl failed: %v", err) - } - - if len(pages) < 3 { - t.Errorf("expected at least 3 pages with unlimited depth, got %d", len(pages)) - } -} - -// --- Crawl with link extraction from HTML --- - -func TestCrawl_ExtractsMailtoAndJavaScriptLinks(t *testing.T) { - mux := http.NewServeMux() - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html") - fmt.Fprint(w, ` - email - js - anchor - real - `) - }) - mux.HandleFunc("/real", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html") - fmt.Fprint(w, `real page`) - }) - mux.HandleFunc("/robots.txt", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(404) - }) - mux.HandleFunc("/sitemap.xml", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(404) - }) - - srv := httptest.NewServer(mux) - defer srv.Close() - - c := New(Config{ - MaxDepth: 2, - Concurrency: 1, - Timeout: 10 * time.Second, - PageTimeout: 5 * time.Second, - RateLimit: 100, - UserAgent: "test", - AllowPrivateIPs: true, - }) - - pages, err := c.Crawl(context.Background(), srv.URL) - if err != nil { - t.Fatalf("Crawl failed: %v", err) - } - - // Should have extracted mailto and javascript links on the page - var root *Page - for _, p := range pages { - if strings.HasSuffix(p.URL, "/") || p.URL == srv.URL || p.URL == srv.URL+"/" { - root = p - break - } - } - if root == nil { - t.Fatal("could not find root page") - } - - // Should not have crawled mailto: or javascript: URLs - for _, p := range pages { - if strings.HasPrefix(p.URL, "mailto:") || strings.HasPrefix(p.URL, "javascript:") { - t.Errorf("should not crawl mailto/javascript URL: %s", p.URL) - } - } - - // Root page links should include the mailto, javascript, and anchor links - linkHrefs := map[string]bool{} - for _, l := range root.Links { - linkHrefs[l.Href] = true - } - if !linkHrefs["mailto:user@example.com"] { - t.Error("expected mailto link in page links") - } - if !linkHrefs["javascript:alert(1)"] { - t.Error("expected javascript link in page links") - } - if !linkHrefs["#fragment"] { - t.Error("expected anchor link in page links") - } -} +// Note: additional tests for the crawler package live in crawler_more_test.go +// (split out for file size/clarity). From 3f92e9c3054bbd293abe3976aeeb2d4e28762af3 Mon Sep 17 00:00:00 2001 From: Lakshman Patel Date: Fri, 19 Jun 2026 21:01:20 +0530 Subject: [PATCH 3/4] test(inspect): split server_test.go into focused test files Mechanical refactor for code clarity. mcp/server_test.go (1314 LOC) split into two same-package files at a top-level function boundary with no behavior, API, or exported-symbol changes. Shared test helpers remain in server_test.go: - server_test.go (648 LOC): helpers, protocol, scan, scandir, error tests - server_more_test.go (689 LOC): session, concurrency, auth, timeout tests --- mcp/server_more_test.go | 689 ++++++++++++++++++++++++++++++++++++++++ mcp/server_test.go | 670 +------------------------------------- 2 files changed, 691 insertions(+), 668 deletions(-) create mode 100644 mcp/server_more_test.go diff --git a/mcp/server_more_test.go b/mcp/server_more_test.go new file mode 100644 index 0000000..d7a69ec --- /dev/null +++ b/mcp/server_more_test.go @@ -0,0 +1,689 @@ +package mcp + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/GrayCodeAI/inspect" + mcplib "github.com/mark3labs/mcp-go/mcp" + mcpserver "github.com/mark3labs/mcp-go/server" +) + +// This file was split out of server_test.go for readability (mechanical move; no behavior change). + +func TestErrorHandling_MissingSessionID(t *testing.T) { + ts, _ := newTestHTTPServer(t, inspect.Quick, inspect.WithAllowPrivateIPs()) + initAndSession(t, ts) // create at least one session + + // Send a request without session header. + resp, err := postJSON(ts.URL, jsonRPCRequest(5, "tools/list", map[string]any{})) + if err != nil { + t.Fatalf("request: %v", err) + } + defer resp.Body.Close() + + // The server should reject the request (400 or 404). + if resp.StatusCode == http.StatusOK { + t.Error("expected non-200 status for missing session ID") + } +} + +func TestErrorHandling_InvalidSessionID(t *testing.T) { + ts, _ := newTestHTTPServer(t, inspect.Quick, inspect.WithAllowPrivateIPs()) + + resp, err := postSessionJSON(ts.URL, "fake-session-id-12345", jsonRPCRequest(5, "tools/list", map[string]any{})) + if err != nil { + t.Fatalf("request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + t.Error("expected non-200 status for invalid session ID") + } +} + +func TestErrorHandling_ToolsCallMissingToolName(t *testing.T) { + ts, _ := newTestHTTPServer(t, inspect.Quick, inspect.WithAllowPrivateIPs()) + sid := initAndSession(t, ts) + + // tools/call without "name" field + resp, err := postSessionJSON(ts.URL, sid, jsonRPCRequest(6, "tools/call", map[string]any{ + "arguments": map[string]any{"url": "http://example.com"}, + })) + if err != nil { + t.Fatalf("request: %v", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + // Should get an error response (either JSON-RPC error or tool error). + if resp.StatusCode == http.StatusOK { + var rpcResp struct { + Error *struct { + Code int `json:"code"` + } `json:"error"` + Result *struct { + IsError bool `json:"isError"` + } `json:"result"` + } + if err := json.Unmarshal(body, &rpcResp); err == nil { + if rpcResp.Error == nil && rpcResp.Result != nil && !rpcResp.Result.IsError { + t.Error("expected error for missing tool name") + } + } + } +} + +func TestErrorHandling_ToolsCallUnknownTool(t *testing.T) { + ts, _ := newTestHTTPServer(t, inspect.Quick, inspect.WithAllowPrivateIPs()) + sid := initAndSession(t, ts) + + resp, err := postSessionJSON(ts.URL, sid, jsonRPCRequest(7, "tools/call", map[string]any{ + "name": "nonexistent_tool", + "arguments": map[string]any{}, + })) + if err != nil { + t.Fatalf("request: %v", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + var rpcResp struct { + Error *struct { + Code int `json:"code"` + Message string `json:"message"` + } `json:"error"` + } + if err := json.Unmarshal(body, &rpcResp); err == nil && rpcResp.Error != nil { + // Got a JSON-RPC level error for unknown tool -- expected. + if rpcResp.Error.Code == 0 { + t.Error("expected non-zero error code for unknown tool") + } + } else { + // If there was no JSON-RPC error, the HTTP status should indicate failure. + if resp.StatusCode == http.StatusOK { + t.Error("expected error response for unknown tool") + } + } +} + +func TestErrorHandling_WrongArgumentType(t *testing.T) { + s := New(inspect.Quick, inspect.WithAllowPrivateIPs()) + req := mcplib.CallToolRequest{} + // url should be a string, not an int + req.Params.Arguments = map[string]interface{}{"url": 12345} + + result, err := s.handleScan(context.Background(), req) + if err != nil { + t.Fatal(err) + } + if result == nil { + t.Fatal("expected result") + } + // The int won't match the string type assertion, so strArg returns "". + if !result.IsError { + t.Fatal("expected error result when url is not a string") + } +} + +// --------------------------------------------------------------------------- +// 4. Concurrent requests +// --------------------------------------------------------------------------- + +func TestConcurrent_Scans(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, `Concurrent

OK

`) + })) + defer ts.Close() + + s := New(inspect.Quick, inspect.WithAllowPrivateIPs()) + + const n = 5 + var wg sync.WaitGroup + errs := make([]error, n) + results := make([]*mcplib.CallToolResult, n) + + for i := range n { + wg.Add(1) + go func(idx int) { + defer wg.Done() + req := mcplib.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{"url": ts.URL} + results[idx], errs[idx] = s.handleScan(context.Background(), req) + }(i) + } + wg.Wait() + + for i := range n { + if errs[i] != nil { + t.Errorf("goroutine %d: unexpected Go error: %v", i, errs[i]) + continue + } + if results[i] == nil { + t.Errorf("goroutine %d: nil result", i) + continue + } + if results[i].IsError { + t.Errorf("goroutine %d: unexpected error result", i) + } + } +} + +func TestConcurrent_ScanAndScanDir(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, `Mixed

OK

`) + })) + defer ts.Close() + + dir := t.TempDir() + indexPath := filepath.Join(dir, "index.html") + if err := os.WriteFile(indexPath, []byte(`Dir

Dir

`), 0o644); err != nil { + t.Fatal(err) + } + + s := New(inspect.Quick, inspect.WithAllowPrivateIPs()) + + const n = 4 + var wg sync.WaitGroup + errCh := make(chan error, n*2) + + // Launch concurrent scan and scan_dir calls. + for range n { + wg.Add(2) + go func() { + defer wg.Done() + req := mcplib.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{"url": ts.URL} + _, err := s.handleScan(context.Background(), req) + if err != nil { + errCh <- fmt.Errorf("handleScan: %w", err) + } + }() + go func() { + defer wg.Done() + req := mcplib.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{"path": dir} + _, err := s.handleScanDir(context.Background(), req) + if err != nil { + errCh <- fmt.Errorf("handleScanDir: %w", err) + } + }() + } + wg.Wait() + close(errCh) + + for err := range errCh { + t.Error(err) + } +} + +func TestConcurrent_DirectHandlerCalls(t *testing.T) { + // Ensure handleScan is safe when called from multiple goroutines with + // the same Server instance. This exercises internal scanner/crawler + // concurrency safety. + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, `Race

OK

`) + })) + defer ts.Close() + + s := New(inspect.Quick, inspect.WithAllowPrivateIPs()) + + var wg sync.WaitGroup + var errCount atomic.Int64 + + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + req := mcplib.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{"url": ts.URL} + result, err := s.handleScan(context.Background(), req) + if err != nil { + errCount.Add(1) + return + } + if result != nil && result.IsError { + errCount.Add(1) + } + }() + } + wg.Wait() + if n := errCount.Load(); n > 0 { + t.Errorf("%d out of 10 concurrent scans failed", n) + } +} + +// --------------------------------------------------------------------------- +// 5. Authentication -- verify auth options flow through to scanner +// --------------------------------------------------------------------------- + +func TestWithAuth_OptionPassthrough(t *testing.T) { + // Verify that auth header/value configured via options are actually used + // when the scanner makes HTTP requests. We start a test server that + // checks for the Authorization header. + var gotAuth string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, `Auth

OK

`) + })) + defer ts.Close() + + s := New( + inspect.Quick, + inspect.WithAllowPrivateIPs(), + inspect.WithAuth("Authorization", "Bearer test-secret-token"), + ) + req := mcplib.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{"url": ts.URL} + + result, err := s.handleScan(context.Background(), req) + if err != nil { + t.Fatalf("handleScan: %v", err) + } + if result == nil { + t.Fatal("expected result") + } + if result.IsError { + t.Fatalf("unexpected error: %v", result.Content) + } + + if gotAuth != "Bearer test-secret-token" { + t.Errorf("expected auth header 'Bearer test-secret-token', got %q", gotAuth) + } +} + +func TestWithAuth_NoAuth(t *testing.T) { + // Verify that when no auth is configured, no Authorization header is sent. + var gotAuth string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, `NoAuth

OK

`) + })) + defer ts.Close() + + s := New(inspect.Quick, inspect.WithAllowPrivateIPs()) + req := mcplib.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{"url": ts.URL} + + result, err := s.handleScan(context.Background(), req) + if err != nil { + t.Fatalf("handleScan: %v", err) + } + if result == nil || result.IsError { + t.Fatal("expected successful result") + } + + if gotAuth != "" { + t.Errorf("expected no auth header, got %q", gotAuth) + } +} + +// --------------------------------------------------------------------------- +// 6. Timeout handling +// --------------------------------------------------------------------------- + +func TestTimeout_ScanContextCancelled(t *testing.T) { + // Start a server that delays its response beyond our context deadline. + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(5 * time.Second) + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, ``) + })) + defer ts.Close() + + s := New(inspect.Quick, inspect.WithAllowPrivateIPs()) + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + req := mcplib.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{"url": ts.URL} + + result, err := s.handleScan(ctx, req) + if err != nil { + t.Fatalf("handler should not return Go error on timeout: %v", err) + } + if result == nil { + t.Fatal("expected result even on timeout") + } + // The scan should fail gracefully (tool-level error, not a panic or Go error). + if !result.IsError { + // It is possible the scan completed fast enough if the crawler + // has its own shorter page timeout. In that case, just log it. + t.Log("scan completed before context deadline; timeout not tested effectively") + } +} + +func TestTimeout_ScanDirContextCancelled(t *testing.T) { + dir := t.TempDir() + indexPath := filepath.Join(dir, "index.html") + if err := os.WriteFile(indexPath, []byte(`Timeout Dir

OK

`), 0o644); err != nil { + t.Fatal(err) + } + + s := New(inspect.Quick, inspect.WithAllowPrivateIPs()) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) + defer cancel() + // Give the context a moment to actually expire. + time.Sleep(5 * time.Millisecond) + + req := mcplib.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{"path": dir} + + result, err := s.handleScanDir(ctx, req) + if err != nil { + t.Fatalf("handler should not return Go error on timeout: %v", err) + } + if result == nil { + t.Fatal("expected result even on timeout") + } +} + +func TestTimeout_WithTimeoutOption(t *testing.T) { + // Verify the WithTimeout option is accepted and the scanner is created + // without errors. We don't test an actual timeout here since the + // scanner's own timeout interacts with the context. + s := New( + inspect.Quick, + inspect.WithAllowPrivateIPs(), + inspect.WithTimeout(5*time.Second), + ) + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, `TimeoutOpt

OK

`) + })) + defer ts.Close() + + req := mcplib.CallToolRequest{} + req.Params.Arguments = map[string]interface{}{"url": ts.URL} + + result, err := s.handleScan(context.Background(), req) + if err != nil { + t.Fatalf("handleScan: %v", err) + } + if result == nil || result.IsError { + t.Fatal("expected successful result with custom timeout") + } +} + +// --------------------------------------------------------------------------- +// 7. strArg helper +// --------------------------------------------------------------------------- + +func TestStrArg(t *testing.T) { + tests := []struct { + name string + args map[string]any + key string + want string + }{ + { + name: "present string", + args: map[string]any{"url": "http://example.com"}, + key: "url", + want: "http://example.com", + }, + { + name: "missing key", + args: map[string]any{}, + key: "url", + want: "", + }, + { + name: "nil arguments", + args: nil, + key: "url", + want: "", + }, + { + name: "wrong type", + args: map[string]any{"url": 42}, + key: "url", + want: "", + }, + { + name: "empty string", + args: map[string]any{"url": ""}, + key: "url", + want: "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req := mcplib.CallToolRequest{} + req.Params.Arguments = tc.args + got := strArg(req, tc.key) + if got != tc.want { + t.Errorf("strArg(%q) = %q, want %q", tc.key, got, tc.want) + } + }) + } +} + +// --------------------------------------------------------------------------- +// 8. Server construction +// --------------------------------------------------------------------------- + +func TestNew_ReturnsServer(t *testing.T) { + s := New() + if s == nil { + t.Fatal("New() returned nil") + } + if s.server == nil { + t.Fatal("internal MCPServer is nil") + } + if s.scanner == nil { + t.Fatal("internal scanner is nil") + } +} + +func TestNew_WithOptions(t *testing.T) { + s := New( + inspect.Deep, + inspect.WithConcurrency(5), + inspect.WithTimeout(30*time.Second), + inspect.WithAllowPrivateIPs(), + ) + if s == nil { + t.Fatal("New() returned nil") + } +} + +func TestNew_Presets(t *testing.T) { + presets := []struct { + name string + opt inspect.Option + }{ + {"Quick", inspect.Quick}, + {"Standard", inspect.Standard}, + {"Deep", inspect.Deep}, + {"SecurityOnly", inspect.SecurityOnly}, + {"CI", inspect.CI}, + } + + for _, p := range presets { + t.Run(p.name, func(t *testing.T) { + s := New(p.opt, inspect.WithAllowPrivateIPs()) + if s == nil { + t.Fatalf("New(%s) returned nil", p.name) + } + }) + } +} + +// --------------------------------------------------------------------------- +// 9. Full HTTP round-trip: initialize -> tools/list -> tools/call +// --------------------------------------------------------------------------- + +func TestFullRoundTrip(t *testing.T) { + // End-to-end test: initialize, list tools, call scan, verify result. + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + fmt.Fprint(w, `E2E

E2E test

`) + })) + defer target.Close() + + ts, _ := newTestHTTPServer(t, inspect.Quick, inspect.WithAllowPrivateIPs()) + + // Step 1: initialize + sid := initAndSession(t, ts) + + // Step 2: tools/list + listResp, err := postSessionJSON(ts.URL, sid, jsonRPCRequest(10, "tools/list", map[string]any{})) + if err != nil { + t.Fatalf("tools/list: %v", err) + } + listBody := readBody(t, listResp) + if listResp.StatusCode != http.StatusOK { + t.Fatalf("tools/list: status %d", listResp.StatusCode) + } + var listResult struct { + Result struct { + Tools []struct { + Name string `json:"name"` + } `json:"tools"` + } `json:"result"` + } + if err := json.Unmarshal(listBody, &listResult); err != nil { + t.Fatalf("unmarshal tools/list: %v", err) + } + if len(listResult.Result.Tools) < 1 { + t.Fatal("expected at least 1 tool") + } + + // Step 3: tools/call on inspect_scan + callResp, err := postSessionJSON(ts.URL, sid, jsonRPCRequest(11, "tools/call", map[string]any{ + "name": "inspect_scan", + "arguments": map[string]any{"url": target.URL}, + })) + if err != nil { + t.Fatalf("tools/call: %v", err) + } + callBody := readBody(t, callResp) + if callResp.StatusCode != http.StatusOK { + t.Fatalf("tools/call: status %d, body: %s", callResp.StatusCode, callBody) + } + + var callResult struct { + Result struct { + Content []struct { + Type string `json:"type"` + Text string `json:"text"` + } `json:"content"` + IsError bool `json:"isError"` + } `json:"result"` + } + if err := json.Unmarshal(callBody, &callResult); err != nil { + t.Fatalf("unmarshal tools/call: %v", err) + } + if callResult.Result.IsError { + t.Fatalf("tools/call returned error: %s", callResult.Result.Content[0].Text) + } + if len(callResult.Result.Content) == 0 { + t.Fatal("expected content in tools/call result") + } + + var report inspect.Report + if err := json.Unmarshal([]byte(callResult.Result.Content[0].Text), &report); err != nil { + t.Fatalf("invalid report JSON: %v", err) + } + if report.Target != target.URL { + t.Errorf("report target: want %s, got %s", target.URL, report.Target) + } +} + +// --------------------------------------------------------------------------- +// 10. Register tools verification +// --------------------------------------------------------------------------- + +func TestRegisterTools_InspectScanSchema(t *testing.T) { + s := New(inspect.Quick, inspect.WithAllowPrivateIPs()) + + // Access the underlying MCPServer to verify tool registration. + mcpSrv := s.server + if mcpSrv == nil { + t.Fatal("internal MCPServer is nil") + } + + // We can verify tools are registered by calling tools/list through + // the server's handler. Use the HTTP test server for this. + ts := mcpserver.NewTestStreamableHTTPServer(mcpSrv, mcpserver.WithStateful(true)) + defer ts.Close() + + sid := initAndSession(t, ts) + + resp, err := postSessionJSON(ts.URL, sid, jsonRPCRequest(2, "tools/list", map[string]any{})) + if err != nil { + t.Fatalf("tools/list: %v", err) + } + body := readBody(t, resp) + + var result struct { + Result struct { + Tools []struct { + Name string `json:"name"` + Description string `json:"description"` + InputSchema struct { + Type string `json:"type"` + Properties map[string]any `json:"properties"` + Required []string `json:"required"` + } `json:"inputSchema"` + } `json:"tools"` + } `json:"result"` + } + if err := json.Unmarshal(body, &result); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + for _, tool := range result.Result.Tools { + if tool.Name == "inspect_scan" { + if tool.InputSchema.Type != "object" { + t.Errorf("inspect_scan input schema type: want object, got %s", tool.InputSchema.Type) + } + if _, ok := tool.InputSchema.Properties["url"]; !ok { + t.Error("inspect_scan missing 'url' property") + } + found := false + for _, r := range tool.InputSchema.Required { + if r == "url" { + found = true + } + } + if !found { + t.Error("inspect_scan 'url' property should be required") + } + } + if tool.Name == "inspect_scan_dir" { + if _, ok := tool.InputSchema.Properties["path"]; !ok { + t.Error("inspect_scan_dir missing 'path' property") + } + found := false + for _, r := range tool.InputSchema.Required { + if r == "path" { + found = true + } + } + if !found { + t.Error("inspect_scan_dir 'path' property should be required") + } + } + } +} diff --git a/mcp/server_test.go b/mcp/server_test.go index f85bc38..f2b2d74 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -11,8 +11,6 @@ import ( "os" "path/filepath" "strings" - "sync" - "sync/atomic" "testing" "time" @@ -646,669 +644,5 @@ func TestErrorHandling_UnknownMethod(t *testing.T) { } } -func TestErrorHandling_MissingSessionID(t *testing.T) { - ts, _ := newTestHTTPServer(t, inspect.Quick, inspect.WithAllowPrivateIPs()) - initAndSession(t, ts) // create at least one session - - // Send a request without session header. - resp, err := postJSON(ts.URL, jsonRPCRequest(5, "tools/list", map[string]any{})) - if err != nil { - t.Fatalf("request: %v", err) - } - defer resp.Body.Close() - - // The server should reject the request (400 or 404). - if resp.StatusCode == http.StatusOK { - t.Error("expected non-200 status for missing session ID") - } -} - -func TestErrorHandling_InvalidSessionID(t *testing.T) { - ts, _ := newTestHTTPServer(t, inspect.Quick, inspect.WithAllowPrivateIPs()) - - resp, err := postSessionJSON(ts.URL, "fake-session-id-12345", jsonRPCRequest(5, "tools/list", map[string]any{})) - if err != nil { - t.Fatalf("request: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode == http.StatusOK { - t.Error("expected non-200 status for invalid session ID") - } -} - -func TestErrorHandling_ToolsCallMissingToolName(t *testing.T) { - ts, _ := newTestHTTPServer(t, inspect.Quick, inspect.WithAllowPrivateIPs()) - sid := initAndSession(t, ts) - - // tools/call without "name" field - resp, err := postSessionJSON(ts.URL, sid, jsonRPCRequest(6, "tools/call", map[string]any{ - "arguments": map[string]any{"url": "http://example.com"}, - })) - if err != nil { - t.Fatalf("request: %v", err) - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - // Should get an error response (either JSON-RPC error or tool error). - if resp.StatusCode == http.StatusOK { - var rpcResp struct { - Error *struct { - Code int `json:"code"` - } `json:"error"` - Result *struct { - IsError bool `json:"isError"` - } `json:"result"` - } - if err := json.Unmarshal(body, &rpcResp); err == nil { - if rpcResp.Error == nil && rpcResp.Result != nil && !rpcResp.Result.IsError { - t.Error("expected error for missing tool name") - } - } - } -} - -func TestErrorHandling_ToolsCallUnknownTool(t *testing.T) { - ts, _ := newTestHTTPServer(t, inspect.Quick, inspect.WithAllowPrivateIPs()) - sid := initAndSession(t, ts) - - resp, err := postSessionJSON(ts.URL, sid, jsonRPCRequest(7, "tools/call", map[string]any{ - "name": "nonexistent_tool", - "arguments": map[string]any{}, - })) - if err != nil { - t.Fatalf("request: %v", err) - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - var rpcResp struct { - Error *struct { - Code int `json:"code"` - Message string `json:"message"` - } `json:"error"` - } - if err := json.Unmarshal(body, &rpcResp); err == nil && rpcResp.Error != nil { - // Got a JSON-RPC level error for unknown tool -- expected. - if rpcResp.Error.Code == 0 { - t.Error("expected non-zero error code for unknown tool") - } - } else { - // If there was no JSON-RPC error, the HTTP status should indicate failure. - if resp.StatusCode == http.StatusOK { - t.Error("expected error response for unknown tool") - } - } -} - -func TestErrorHandling_WrongArgumentType(t *testing.T) { - s := New(inspect.Quick, inspect.WithAllowPrivateIPs()) - req := mcplib.CallToolRequest{} - // url should be a string, not an int - req.Params.Arguments = map[string]interface{}{"url": 12345} - - result, err := s.handleScan(context.Background(), req) - if err != nil { - t.Fatal(err) - } - if result == nil { - t.Fatal("expected result") - } - // The int won't match the string type assertion, so strArg returns "". - if !result.IsError { - t.Fatal("expected error result when url is not a string") - } -} - -// --------------------------------------------------------------------------- -// 4. Concurrent requests -// --------------------------------------------------------------------------- - -func TestConcurrent_Scans(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html") - fmt.Fprint(w, `Concurrent

OK

`) - })) - defer ts.Close() - - s := New(inspect.Quick, inspect.WithAllowPrivateIPs()) - - const n = 5 - var wg sync.WaitGroup - errs := make([]error, n) - results := make([]*mcplib.CallToolResult, n) - - for i := range n { - wg.Add(1) - go func(idx int) { - defer wg.Done() - req := mcplib.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{"url": ts.URL} - results[idx], errs[idx] = s.handleScan(context.Background(), req) - }(i) - } - wg.Wait() - - for i := range n { - if errs[i] != nil { - t.Errorf("goroutine %d: unexpected Go error: %v", i, errs[i]) - continue - } - if results[i] == nil { - t.Errorf("goroutine %d: nil result", i) - continue - } - if results[i].IsError { - t.Errorf("goroutine %d: unexpected error result", i) - } - } -} - -func TestConcurrent_ScanAndScanDir(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html") - fmt.Fprint(w, `Mixed

OK

`) - })) - defer ts.Close() - - dir := t.TempDir() - indexPath := filepath.Join(dir, "index.html") - if err := os.WriteFile(indexPath, []byte(`Dir

Dir

`), 0o644); err != nil { - t.Fatal(err) - } - - s := New(inspect.Quick, inspect.WithAllowPrivateIPs()) - - const n = 4 - var wg sync.WaitGroup - errCh := make(chan error, n*2) - - // Launch concurrent scan and scan_dir calls. - for range n { - wg.Add(2) - go func() { - defer wg.Done() - req := mcplib.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{"url": ts.URL} - _, err := s.handleScan(context.Background(), req) - if err != nil { - errCh <- fmt.Errorf("handleScan: %w", err) - } - }() - go func() { - defer wg.Done() - req := mcplib.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{"path": dir} - _, err := s.handleScanDir(context.Background(), req) - if err != nil { - errCh <- fmt.Errorf("handleScanDir: %w", err) - } - }() - } - wg.Wait() - close(errCh) - - for err := range errCh { - t.Error(err) - } -} - -func TestConcurrent_DirectHandlerCalls(t *testing.T) { - // Ensure handleScan is safe when called from multiple goroutines with - // the same Server instance. This exercises internal scanner/crawler - // concurrency safety. - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html") - fmt.Fprint(w, `Race

OK

`) - })) - defer ts.Close() - - s := New(inspect.Quick, inspect.WithAllowPrivateIPs()) - - var wg sync.WaitGroup - var errCount atomic.Int64 - - for i := 0; i < 10; i++ { - wg.Add(1) - go func() { - defer wg.Done() - req := mcplib.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{"url": ts.URL} - result, err := s.handleScan(context.Background(), req) - if err != nil { - errCount.Add(1) - return - } - if result != nil && result.IsError { - errCount.Add(1) - } - }() - } - wg.Wait() - if n := errCount.Load(); n > 0 { - t.Errorf("%d out of 10 concurrent scans failed", n) - } -} - -// --------------------------------------------------------------------------- -// 5. Authentication -- verify auth options flow through to scanner -// --------------------------------------------------------------------------- - -func TestWithAuth_OptionPassthrough(t *testing.T) { - // Verify that auth header/value configured via options are actually used - // when the scanner makes HTTP requests. We start a test server that - // checks for the Authorization header. - var gotAuth string - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotAuth = r.Header.Get("Authorization") - w.Header().Set("Content-Type", "text/html") - fmt.Fprint(w, `Auth

OK

`) - })) - defer ts.Close() - - s := New( - inspect.Quick, - inspect.WithAllowPrivateIPs(), - inspect.WithAuth("Authorization", "Bearer test-secret-token"), - ) - req := mcplib.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{"url": ts.URL} - - result, err := s.handleScan(context.Background(), req) - if err != nil { - t.Fatalf("handleScan: %v", err) - } - if result == nil { - t.Fatal("expected result") - } - if result.IsError { - t.Fatalf("unexpected error: %v", result.Content) - } - - if gotAuth != "Bearer test-secret-token" { - t.Errorf("expected auth header 'Bearer test-secret-token', got %q", gotAuth) - } -} - -func TestWithAuth_NoAuth(t *testing.T) { - // Verify that when no auth is configured, no Authorization header is sent. - var gotAuth string - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - gotAuth = r.Header.Get("Authorization") - w.Header().Set("Content-Type", "text/html") - fmt.Fprint(w, `NoAuth

OK

`) - })) - defer ts.Close() - - s := New(inspect.Quick, inspect.WithAllowPrivateIPs()) - req := mcplib.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{"url": ts.URL} - - result, err := s.handleScan(context.Background(), req) - if err != nil { - t.Fatalf("handleScan: %v", err) - } - if result == nil || result.IsError { - t.Fatal("expected successful result") - } - - if gotAuth != "" { - t.Errorf("expected no auth header, got %q", gotAuth) - } -} - -// --------------------------------------------------------------------------- -// 6. Timeout handling -// --------------------------------------------------------------------------- - -func TestTimeout_ScanContextCancelled(t *testing.T) { - // Start a server that delays its response beyond our context deadline. - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - time.Sleep(5 * time.Second) - w.Header().Set("Content-Type", "text/html") - fmt.Fprint(w, ``) - })) - defer ts.Close() - - s := New(inspect.Quick, inspect.WithAllowPrivateIPs()) - - ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) - defer cancel() - - req := mcplib.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{"url": ts.URL} - - result, err := s.handleScan(ctx, req) - if err != nil { - t.Fatalf("handler should not return Go error on timeout: %v", err) - } - if result == nil { - t.Fatal("expected result even on timeout") - } - // The scan should fail gracefully (tool-level error, not a panic or Go error). - if !result.IsError { - // It is possible the scan completed fast enough if the crawler - // has its own shorter page timeout. In that case, just log it. - t.Log("scan completed before context deadline; timeout not tested effectively") - } -} - -func TestTimeout_ScanDirContextCancelled(t *testing.T) { - dir := t.TempDir() - indexPath := filepath.Join(dir, "index.html") - if err := os.WriteFile(indexPath, []byte(`Timeout Dir

OK

`), 0o644); err != nil { - t.Fatal(err) - } - - s := New(inspect.Quick, inspect.WithAllowPrivateIPs()) - - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) - defer cancel() - // Give the context a moment to actually expire. - time.Sleep(5 * time.Millisecond) - - req := mcplib.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{"path": dir} - - result, err := s.handleScanDir(ctx, req) - if err != nil { - t.Fatalf("handler should not return Go error on timeout: %v", err) - } - if result == nil { - t.Fatal("expected result even on timeout") - } -} - -func TestTimeout_WithTimeoutOption(t *testing.T) { - // Verify the WithTimeout option is accepted and the scanner is created - // without errors. We don't test an actual timeout here since the - // scanner's own timeout interacts with the context. - s := New( - inspect.Quick, - inspect.WithAllowPrivateIPs(), - inspect.WithTimeout(5*time.Second), - ) - - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html") - fmt.Fprint(w, `TimeoutOpt

OK

`) - })) - defer ts.Close() - - req := mcplib.CallToolRequest{} - req.Params.Arguments = map[string]interface{}{"url": ts.URL} - - result, err := s.handleScan(context.Background(), req) - if err != nil { - t.Fatalf("handleScan: %v", err) - } - if result == nil || result.IsError { - t.Fatal("expected successful result with custom timeout") - } -} - -// --------------------------------------------------------------------------- -// 7. strArg helper -// --------------------------------------------------------------------------- - -func TestStrArg(t *testing.T) { - tests := []struct { - name string - args map[string]any - key string - want string - }{ - { - name: "present string", - args: map[string]any{"url": "http://example.com"}, - key: "url", - want: "http://example.com", - }, - { - name: "missing key", - args: map[string]any{}, - key: "url", - want: "", - }, - { - name: "nil arguments", - args: nil, - key: "url", - want: "", - }, - { - name: "wrong type", - args: map[string]any{"url": 42}, - key: "url", - want: "", - }, - { - name: "empty string", - args: map[string]any{"url": ""}, - key: "url", - want: "", - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - req := mcplib.CallToolRequest{} - req.Params.Arguments = tc.args - got := strArg(req, tc.key) - if got != tc.want { - t.Errorf("strArg(%q) = %q, want %q", tc.key, got, tc.want) - } - }) - } -} - -// --------------------------------------------------------------------------- -// 8. Server construction -// --------------------------------------------------------------------------- - -func TestNew_ReturnsServer(t *testing.T) { - s := New() - if s == nil { - t.Fatal("New() returned nil") - } - if s.server == nil { - t.Fatal("internal MCPServer is nil") - } - if s.scanner == nil { - t.Fatal("internal scanner is nil") - } -} - -func TestNew_WithOptions(t *testing.T) { - s := New( - inspect.Deep, - inspect.WithConcurrency(5), - inspect.WithTimeout(30*time.Second), - inspect.WithAllowPrivateIPs(), - ) - if s == nil { - t.Fatal("New() returned nil") - } -} - -func TestNew_Presets(t *testing.T) { - presets := []struct { - name string - opt inspect.Option - }{ - {"Quick", inspect.Quick}, - {"Standard", inspect.Standard}, - {"Deep", inspect.Deep}, - {"SecurityOnly", inspect.SecurityOnly}, - {"CI", inspect.CI}, - } - - for _, p := range presets { - t.Run(p.name, func(t *testing.T) { - s := New(p.opt, inspect.WithAllowPrivateIPs()) - if s == nil { - t.Fatalf("New(%s) returned nil", p.name) - } - }) - } -} - -// --------------------------------------------------------------------------- -// 9. Full HTTP round-trip: initialize -> tools/list -> tools/call -// --------------------------------------------------------------------------- - -func TestFullRoundTrip(t *testing.T) { - // End-to-end test: initialize, list tools, call scan, verify result. - target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/html") - fmt.Fprint(w, `E2E

E2E test

`) - })) - defer target.Close() - - ts, _ := newTestHTTPServer(t, inspect.Quick, inspect.WithAllowPrivateIPs()) - - // Step 1: initialize - sid := initAndSession(t, ts) - - // Step 2: tools/list - listResp, err := postSessionJSON(ts.URL, sid, jsonRPCRequest(10, "tools/list", map[string]any{})) - if err != nil { - t.Fatalf("tools/list: %v", err) - } - listBody := readBody(t, listResp) - if listResp.StatusCode != http.StatusOK { - t.Fatalf("tools/list: status %d", listResp.StatusCode) - } - var listResult struct { - Result struct { - Tools []struct { - Name string `json:"name"` - } `json:"tools"` - } `json:"result"` - } - if err := json.Unmarshal(listBody, &listResult); err != nil { - t.Fatalf("unmarshal tools/list: %v", err) - } - if len(listResult.Result.Tools) < 1 { - t.Fatal("expected at least 1 tool") - } - - // Step 3: tools/call on inspect_scan - callResp, err := postSessionJSON(ts.URL, sid, jsonRPCRequest(11, "tools/call", map[string]any{ - "name": "inspect_scan", - "arguments": map[string]any{"url": target.URL}, - })) - if err != nil { - t.Fatalf("tools/call: %v", err) - } - callBody := readBody(t, callResp) - if callResp.StatusCode != http.StatusOK { - t.Fatalf("tools/call: status %d, body: %s", callResp.StatusCode, callBody) - } - - var callResult struct { - Result struct { - Content []struct { - Type string `json:"type"` - Text string `json:"text"` - } `json:"content"` - IsError bool `json:"isError"` - } `json:"result"` - } - if err := json.Unmarshal(callBody, &callResult); err != nil { - t.Fatalf("unmarshal tools/call: %v", err) - } - if callResult.Result.IsError { - t.Fatalf("tools/call returned error: %s", callResult.Result.Content[0].Text) - } - if len(callResult.Result.Content) == 0 { - t.Fatal("expected content in tools/call result") - } - - var report inspect.Report - if err := json.Unmarshal([]byte(callResult.Result.Content[0].Text), &report); err != nil { - t.Fatalf("invalid report JSON: %v", err) - } - if report.Target != target.URL { - t.Errorf("report target: want %s, got %s", target.URL, report.Target) - } -} - -// --------------------------------------------------------------------------- -// 10. Register tools verification -// --------------------------------------------------------------------------- - -func TestRegisterTools_InspectScanSchema(t *testing.T) { - s := New(inspect.Quick, inspect.WithAllowPrivateIPs()) - - // Access the underlying MCPServer to verify tool registration. - mcpSrv := s.server - if mcpSrv == nil { - t.Fatal("internal MCPServer is nil") - } - - // We can verify tools are registered by calling tools/list through - // the server's handler. Use the HTTP test server for this. - ts := mcpserver.NewTestStreamableHTTPServer(mcpSrv, mcpserver.WithStateful(true)) - defer ts.Close() - - sid := initAndSession(t, ts) - - resp, err := postSessionJSON(ts.URL, sid, jsonRPCRequest(2, "tools/list", map[string]any{})) - if err != nil { - t.Fatalf("tools/list: %v", err) - } - body := readBody(t, resp) - - var result struct { - Result struct { - Tools []struct { - Name string `json:"name"` - Description string `json:"description"` - InputSchema struct { - Type string `json:"type"` - Properties map[string]any `json:"properties"` - Required []string `json:"required"` - } `json:"inputSchema"` - } `json:"tools"` - } `json:"result"` - } - if err := json.Unmarshal(body, &result); err != nil { - t.Fatalf("unmarshal: %v", err) - } - - for _, tool := range result.Result.Tools { - if tool.Name == "inspect_scan" { - if tool.InputSchema.Type != "object" { - t.Errorf("inspect_scan input schema type: want object, got %s", tool.InputSchema.Type) - } - if _, ok := tool.InputSchema.Properties["url"]; !ok { - t.Error("inspect_scan missing 'url' property") - } - found := false - for _, r := range tool.InputSchema.Required { - if r == "url" { - found = true - } - } - if !found { - t.Error("inspect_scan 'url' property should be required") - } - } - if tool.Name == "inspect_scan_dir" { - if _, ok := tool.InputSchema.Properties["path"]; !ok { - t.Error("inspect_scan_dir missing 'path' property") - } - found := false - for _, r := range tool.InputSchema.Required { - if r == "path" { - found = true - } - } - if !found { - t.Error("inspect_scan_dir 'path' property should be required") - } - } - } -} +// Note: additional tests for the mcp package live in server_more_test.go +// (split out for file size/clarity). Shared test helpers remain in this file. From f5959190d211dd080831ce5c847b3cd9120e89cd Mon Sep 17 00:00:00 2001 From: Lakshman Patel Date: Fri, 19 Jun 2026 21:02:37 +0530 Subject: [PATCH 4/4] test(inspect): extract accessibility tests into check_a11y_test.go Mechanical follow-up to keep every check test file under 1000 LOC. Moves the ARIA / landmarks / advanced-a11y tests verbatim out of check_test.go into a focused same-package file. No behavior, API, or exported-symbol changes: - check_test.go (700 LOC) - check_a11y_test.go (441 LOC) --- internal/check/check_a11y_test.go | 441 ++++++++++++++++++++++++++++++ internal/check/check_test.go | 434 +---------------------------- 2 files changed, 443 insertions(+), 432 deletions(-) create mode 100644 internal/check/check_a11y_test.go diff --git a/internal/check/check_a11y_test.go b/internal/check/check_a11y_test.go new file mode 100644 index 0000000..f8aca28 --- /dev/null +++ b/internal/check/check_a11y_test.go @@ -0,0 +1,441 @@ +package check + +import ( + "context" + "strings" + "testing" + + "github.com/GrayCodeAI/inspect/internal/crawler" +) + +// This file was split out of check_test.go for readability (mechanical move; no behavior change). + +// --- Advanced A11y Tests --- + +func TestCheckARIA_InvalidRole(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + `
Content
`) + + findings := checkARIA(page) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "Invalid ARIA role") && strings.Contains(f.Message, "banana") { + found = true + if f.Severity != SeverityMedium { + t.Errorf("expected medium severity for invalid role, got %v", f.Severity) + } + } + } + if !found { + t.Error("expected finding for invalid ARIA role 'banana'") + } +} + +func TestCheckARIA_ValidRole(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + `
Nav content
`) + + findings := checkARIA(page) + + for _, f := range findings { + if strings.Contains(f.Message, "Invalid ARIA role") { + t.Errorf("should not flag valid ARIA role: %s", f.Message) + } + } +} + +func TestCheckARIA_PositiveTabindex(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + `
Content
`) + + findings := checkARIA(page) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "Positive tabindex") { + found = true + if f.Severity != SeverityMedium { + t.Errorf("expected medium severity, got %v", f.Severity) + } + } + } + if !found { + t.Error("expected finding for positive tabindex") + } +} + +func TestCheckARIA_ZeroTabindex(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + `
Content
`) + + findings := checkARIA(page) + + for _, f := range findings { + if strings.Contains(f.Message, "Positive tabindex") { + t.Error("tabindex=0 should not be flagged") + } + } +} + +func TestCheckARIA_NegativeTabindex(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + `
Content
`) + + findings := checkARIA(page) + + for _, f := range findings { + if strings.Contains(f.Message, "Positive tabindex") { + t.Error("tabindex=-1 should not be flagged as positive tabindex") + } + } +} + +func TestCheckARIA_AriaHiddenOnFocusable(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + ``) + + findings := checkARIA(page) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "Focusable element is aria-hidden") { + found = true + if f.Severity != SeverityHigh { + t.Errorf("expected high severity, got %v", f.Severity) + } + } + } + if !found { + t.Error("expected finding for aria-hidden on focusable element") + } +} + +func TestCheckARIA_AriaHiddenOnNonFocusable(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + ``) + + findings := checkARIA(page) + + for _, f := range findings { + if strings.Contains(f.Message, "Focusable element is aria-hidden") { + t.Error("should not flag aria-hidden on non-focusable element") + } + } +} + +func TestCheckARIA_InteractiveElementRemovedFromTabOrder(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + ``) + + findings := checkARIA(page) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "Interactive element removed from tab order") { + found = true + } + } + if !found { + t.Error("expected finding for interactive element with tabindex=-1") + } +} + +func TestCheckARIA_RoleRequiringNameWithoutName(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + `
`) + + findings := checkARIA(page) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "has no accessible name") { + found = true + if f.Severity != SeverityHigh { + t.Errorf("expected high severity, got %v", f.Severity) + } + } + } + if !found { + t.Error("expected finding for role=button without accessible name") + } +} + +func TestCheckARIA_RoleRequiringNameWithAriaLabel(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + `
`) + + findings := checkARIA(page) + + for _, f := range findings { + if strings.Contains(f.Message, "has no accessible name") { + t.Error("should not flag element with aria-label") + } + } +} + +func TestCheckARIA_EmptyBody(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, "") + findings := checkARIA(page) + if len(findings) != 0 { + t.Error("should not produce findings for empty body") + } +} + +func TestCheckLandmarks_AllPresent(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + ` +
Header
+ +
Content
+ +
Footer
+ `) + + findings := checkLandmarks(page) + + for _, f := range findings { + if strings.Contains(f.Message, "missing") { + t.Errorf("should not flag missing landmark when all present: %s", f.Message) + } + } +} + +func TestCheckLandmarks_MissingNav(t *testing.T) { + page := makePage("https://example.com", 200, map[string]string{"Content-Type": "text/html"}, + `
H
Content
F
`) + + findings := checkLandmarks(page) + + found := false + for _, f := range findings { + if strings.Contains(f.Message, "missing