diff --git a/runtime/app/app.go b/runtime/app/app.go index 31c2a721..6fe11f1b 100644 --- a/runtime/app/app.go +++ b/runtime/app/app.go @@ -8,6 +8,7 @@ import ( "encoding/json" "html" "io/fs" + "net" "net/http" "net/url" "os" @@ -499,6 +500,9 @@ func CSRFInjectHTML(response http.ResponseWriter, request *http.Request, payload if !formStartTagHasPostMethod(tag) { continue } + if !formStartTagHasSameOriginAction(request, tag) { + continue + } if token == "" { generated, err := source.Token(response, request) if err != nil { @@ -543,13 +547,85 @@ func formStartTagRanges(payload []byte) [][2]int { } func formStartTagHasPostMethod(tag []byte) bool { + value, ok := formStartTagAttrValue(tag, "method") + return ok && strings.EqualFold(html.UnescapeString(string(value)), http.MethodPost) +} + +func formStartTagHasSameOriginAction(request *http.Request, tag []byte) bool { + value, ok := formStartTagAttrValue(tag, "action") + if !ok { + return true + } + action := strings.TrimSpace(html.UnescapeString(string(value))) + if action == "" { + return true + } + if formActionHasBrowserNetworkPath(action) { + return false + } + if request == nil { + return false + } + scheme, host := requestOrigin(request) + if scheme == "" || host == "" { + return false + } + actionURL, err := url.Parse(action) + if err != nil { + return false + } + resolved := (&url.URL{Scheme: scheme, Host: host, Path: "/"}).ResolveReference(actionURL) + return strings.EqualFold(resolved.Scheme, scheme) && + strings.EqualFold(canonicalOriginHost(resolved.Scheme, resolved.Host), canonicalOriginHost(scheme, host)) +} + +func requestOrigin(request *http.Request) (string, string) { + scheme := "" + if requestIsHTTPS(request) { + scheme = "https" + } else if request.URL != nil { + scheme = request.URL.Scheme + } + if scheme == "" { + scheme = "http" + } + host := request.Host + if host == "" && request.URL != nil { + host = request.URL.Host + } + return scheme, host +} + +func formActionHasBrowserNetworkPath(action string) bool { + if len(action) < 2 { + return false + } + first := action[0] + second := action[1] + return (first == '/' || first == '\\') && (second == '/' || second == '\\') +} + +func canonicalOriginHost(scheme string, host string) string { + host = strings.ToLower(strings.TrimSpace(host)) + name, port, err := net.SplitHostPort(host) + if err == nil { + defaultPort := (scheme == "http" && port == "80") || (scheme == "https" && port == "443") + if defaultPort { + return strings.ToLower(strings.Trim(name, "[]")) + } + return strings.ToLower(strings.Trim(name, "[]") + ":" + port) + } + return strings.Trim(host, "[]") +} + +func formStartTagAttrValue(tag []byte, attrName string) ([]byte, bool) { cursor := len("= len(tag) || tag[cursor] == '>' || tag[cursor] == '/' { - return false + return nil, false } nameStart := cursor for cursor < len(tag) && !isHTMLSpace(tag[cursor]) && tag[cursor] != '=' && tag[cursor] != '/' && tag[cursor] != '>' { @@ -568,11 +644,11 @@ func formStartTagHasPostMethod(tag []byte) bool { } value, next := htmlAttrValue(tag, cursor) cursor = next - if bytes.EqualFold(name, []byte("method")) && strings.EqualFold(string(value), http.MethodPost) { - return true + if bytes.EqualFold(name, []byte(attrName)) { + return value, true } } - return false + return nil, false } func htmlTagEnd(payload []byte, cursor int) int { diff --git a/runtime/app/app_test.go b/runtime/app/app_test.go index bae736e4..d2017784 100644 --- a/runtime/app/app_test.go +++ b/runtime/app/app_test.go @@ -1393,6 +1393,102 @@ func TestHandlerInjectsCSRFHiddenInputsIntoPOSTForms(t *testing.T) { } } +func TestCSRFInjectHTMLSkipsOffOriginPOSTFormActions(t *testing.T) { + csrf := &fakeCSRFTokenSource{field: "_csrf", token: "signed-token"} + payload := []byte(`
` + + `
` + + `
` + + `
` + + `
` + + `
` + + `
` + + `
` + + `
` + + `
`) + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, "http://example.com/page", nil) + + updated, ok := CSRFInjectHTML(recorder, request, payload, csrf) + + if !ok { + t.Fatal("expected csrf injection to succeed") + } + body := string(updated) + for _, external := range []string{ + `
`, + `
`, + `
`, + `
`, + `
`, + `
`, + } { + if !strings.Contains(body, external) { + t.Fatalf("expected off-origin form to remain without csrf input: %s", body) + } + } + for _, local := range []string{ + `
`, + ``, + } { + if !strings.Contains(body, local) { + t.Fatalf("expected same-origin form to receive csrf input %q in %s", local, body) + } + } + if count := strings.Count(body, `name="_csrf"`); count != 2 { + t.Fatalf("expected csrf input only for same-origin forms, got %d: %s", count, body) + } + if csrf.calls != 1 { + t.Fatalf("expected one token generation call, got %d", csrf.calls) + } + if cache := recorder.Header().Get("Cache-Control"); cache != "no-store" { + t.Fatalf("expected no-store for csrf-personalized HTML, got %q", cache) + } +} + +func TestCSRFInjectHTMLUsesForwardedHTTPSForSameOriginActions(t *testing.T) { + csrf := &fakeCSRFTokenSource{field: "_csrf", token: "signed-token"} + payload := []byte(`
`) + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, "http://example.com/page", nil) + request.Header.Set("X-Forwarded-Proto", "https") + + updated, ok := CSRFInjectHTML(recorder, request, payload, csrf) + + if !ok { + t.Fatal("expected csrf injection to succeed") + } + body := string(updated) + expected := `
` + if !strings.Contains(body, expected) { + t.Fatalf("expected forwarded HTTPS same-origin form to receive csrf input, got %s", body) + } + if csrf.calls != 1 { + t.Fatalf("expected one token generation call, got %d", csrf.calls) + } +} + +func TestCSRFInjectHTMLDoesNotGenerateTokenForOnlyOffOriginPOSTForms(t *testing.T) { + csrf := &fakeCSRFTokenSource{field: "_csrf", token: "signed-token"} + payload := []byte(`
`) + recorder := httptest.NewRecorder() + request := httptest.NewRequest(http.MethodGet, "http://example.com/page", nil) + + updated, ok := CSRFInjectHTML(recorder, request, payload, csrf) + + if !ok { + t.Fatal("expected csrf injection to succeed") + } + if string(updated) != string(payload) { + t.Fatalf("expected off-origin form payload to remain unchanged, got %s", updated) + } + if csrf.calls != 0 { + t.Fatalf("expected no token generation for off-origin form, got %d", csrf.calls) + } + if cache := recorder.Header().Get("Cache-Control"); cache != "" { + t.Fatalf("expected no cache-control mutation without injected token, got %q", cache) + } +} + func TestHandlerReturnsNoStoreErrorWhenCSRFTokenGenerationFails(t *testing.T) { handler := Handler{ Root: fstest.MapFS{