diff --git a/runtime/app/app.go b/runtime/app/app.go index 31c2a721..6fe11f1b 100644 --- a/runtime/app/app.go +++ b/runtime/app/app.go @@ -8,6 +8,7 @@ import ( "encoding/json" "html" "io/fs" + "net" "net/http" "net/url" "os" @@ -499,6 +500,9 @@ func CSRFInjectHTML(response http.ResponseWriter, request *http.Request, payload if !formStartTagHasPostMethod(tag) { continue } + if !formStartTagHasSameOriginAction(request, tag) { + continue + } if token == "" { generated, err := source.Token(response, request) if err != nil { @@ -543,13 +547,85 @@ func formStartTagRanges(payload []byte) [][2]int { } func formStartTagHasPostMethod(tag []byte) bool { + value, ok := formStartTagAttrValue(tag, "method") + return ok && strings.EqualFold(html.UnescapeString(string(value)), http.MethodPost) +} + +func formStartTagHasSameOriginAction(request *http.Request, tag []byte) bool { + value, ok := formStartTagAttrValue(tag, "action") + if !ok { + return true + } + action := strings.TrimSpace(html.UnescapeString(string(value))) + if action == "" { + return true + } + if formActionHasBrowserNetworkPath(action) { + return false + } + if request == nil { + return false + } + scheme, host := requestOrigin(request) + if scheme == "" || host == "" { + return false + } + actionURL, err := url.Parse(action) + if err != nil { + return false + } + resolved := (&url.URL{Scheme: scheme, Host: host, Path: "/"}).ResolveReference(actionURL) + return strings.EqualFold(resolved.Scheme, scheme) && + strings.EqualFold(canonicalOriginHost(resolved.Scheme, resolved.Host), canonicalOriginHost(scheme, host)) +} + +func requestOrigin(request *http.Request) (string, string) { + scheme := "" + if requestIsHTTPS(request) { + scheme = "https" + } else if request.URL != nil { + scheme = request.URL.Scheme + } + if scheme == "" { + scheme = "http" + } + host := request.Host + if host == "" && request.URL != nil { + host = request.URL.Host + } + return scheme, host +} + +func formActionHasBrowserNetworkPath(action string) bool { + if len(action) < 2 { + return false + } + first := action[0] + second := action[1] + return (first == '/' || first == '\\') && (second == '/' || second == '\\') +} + +func canonicalOriginHost(scheme string, host string) string { + host = strings.ToLower(strings.TrimSpace(host)) + name, port, err := net.SplitHostPort(host) + if err == nil { + defaultPort := (scheme == "http" && port == "80") || (scheme == "https" && port == "443") + if defaultPort { + return strings.ToLower(strings.Trim(name, "[]")) + } + return strings.ToLower(strings.Trim(name, "[]") + ":" + port) + } + return strings.Trim(host, "[]") +} + +func formStartTagAttrValue(tag []byte, attrName string) ([]byte, bool) { cursor := len("