|
| 1 | +// Copyright (c) 2026 Bryan Frimin <bryan@frimin.fr>. |
| 2 | +// |
| 3 | +// Permission to use, copy, modify, and/or distribute this software |
| 4 | +// for any purpose with or without fee is hereby granted, provided |
| 5 | +// that the above copyright notice and this permission notice appear |
| 6 | +// in all copies. |
| 7 | +// |
| 8 | +// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL |
| 9 | +// WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED |
| 10 | +// WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE |
| 11 | +// AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR |
| 12 | +// CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS |
| 13 | +// OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, |
| 14 | +// NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN |
| 15 | +// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. |
| 16 | + |
| 17 | +package httpclient |
| 18 | + |
| 19 | +import ( |
| 20 | + "errors" |
| 21 | + "fmt" |
| 22 | + "net/http" |
| 23 | + "net/netip" |
| 24 | + "net/url" |
| 25 | + "strings" |
| 26 | + "syscall" |
| 27 | +) |
| 28 | + |
| 29 | +// ErrBlockedAddress is returned when a dial is rejected because the |
| 30 | +// resolved peer address belongs to a blocked range. |
| 31 | +var ErrBlockedAddress = errors.New("httpclient: address blocked by SSRF protection") |
| 32 | + |
| 33 | +// ErrCrossOriginRedirect is returned when a redirect target changes |
| 34 | +// scheme, host, or port and the client is configured to refuse such |
| 35 | +// redirects. |
| 36 | +var ErrCrossOriginRedirect = errors.New("httpclient: cross-origin redirect blocked by SSRF protection") |
| 37 | + |
| 38 | +// extraBlockedPrefixes are reserved or otherwise unsafe ranges that |
| 39 | +// netip.Addr's helper predicates do not cover. RFC 1918 (private), |
| 40 | +// loopback, link-local, multicast, and unspecified are handled by the |
| 41 | +// stdlib helpers in isBlockedAddr. |
| 42 | +var extraBlockedPrefixes = []netip.Prefix{ |
| 43 | + netip.MustParsePrefix("0.0.0.0/8"), // RFC 1122 "this network" |
| 44 | + netip.MustParsePrefix("100.64.0.0/10"), // RFC 6598 CGNAT |
| 45 | + netip.MustParsePrefix("192.0.0.0/24"), // RFC 6890 IETF protocol assignments |
| 46 | + netip.MustParsePrefix("192.0.2.0/24"), // RFC 5737 TEST-NET-1 |
| 47 | + netip.MustParsePrefix("198.18.0.0/15"), // RFC 2544 benchmarking |
| 48 | + netip.MustParsePrefix("198.51.100.0/24"), // RFC 5737 TEST-NET-2 |
| 49 | + netip.MustParsePrefix("203.0.113.0/24"), // RFC 5737 TEST-NET-3 |
| 50 | + netip.MustParsePrefix("240.0.0.0/4"), // RFC 1112 reserved (incl. 255.255.255.255) |
| 51 | + netip.MustParsePrefix("64:ff9b::/96"), // RFC 6052 IPv4/IPv6 translation |
| 52 | + netip.MustParsePrefix("64:ff9b:1::/48"), // RFC 8215 IPv4/IPv6 local translation |
| 53 | + netip.MustParsePrefix("100::/64"), // RFC 6666 discard prefix |
| 54 | + netip.MustParsePrefix("2001::/23"), // IETF protocol assignments |
| 55 | + netip.MustParsePrefix("2001:db8::/32"), // RFC 3849 documentation |
| 56 | +} |
| 57 | + |
| 58 | +// makeSSRFDialControl returns a net.Dialer.Control function that |
| 59 | +// rejects connections to IP addresses in private, loopback, |
| 60 | +// link-local, multicast, CGNAT, or other reserved ranges. The check |
| 61 | +// runs after DNS resolution on the actual peer address, which |
| 62 | +// defeats DNS rebinding between any prior URL validation and the |
| 63 | +// dial. |
| 64 | +// |
| 65 | +// When allowLoopback is true, 127.0.0.0/8 and ::1 are permitted. The |
| 66 | +// option is intended for tests that bind httptest servers to the |
| 67 | +// loopback interface and must not be used in production. |
| 68 | +// |
| 69 | +// Only TCP and UDP networks are permitted; unix sockets and other |
| 70 | +// network types are refused outright. |
| 71 | +func makeSSRFDialControl(allowLoopback bool) func(network, address string, _ syscall.RawConn) error { |
| 72 | + return func(network, address string, _ syscall.RawConn) error { |
| 73 | + switch network { |
| 74 | + case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6": |
| 75 | + default: |
| 76 | + return fmt.Errorf("%w: refusing non-IP network %q", ErrBlockedAddress, network) |
| 77 | + } |
| 78 | + |
| 79 | + addrPort, err := netip.ParseAddrPort(address) |
| 80 | + if err != nil { |
| 81 | + return fmt.Errorf("%w: cannot parse peer address %q: %v", ErrBlockedAddress, address, err) |
| 82 | + } |
| 83 | + |
| 84 | + ip := addrPort.Addr() |
| 85 | + if ip.Is4In6() { |
| 86 | + ip = ip.Unmap() |
| 87 | + } |
| 88 | + |
| 89 | + if allowLoopback && ip.IsLoopback() { |
| 90 | + return nil |
| 91 | + } |
| 92 | + |
| 93 | + if isBlockedAddr(ip) { |
| 94 | + return fmt.Errorf("%w: %s", ErrBlockedAddress, ip) |
| 95 | + } |
| 96 | + |
| 97 | + return nil |
| 98 | + } |
| 99 | +} |
| 100 | + |
| 101 | +// isBlockedAddr reports whether ip falls in a range that should be |
| 102 | +// refused for outbound HTTP from an SSRF perspective. |
| 103 | +func isBlockedAddr(ip netip.Addr) bool { |
| 104 | + if !ip.IsValid() { |
| 105 | + return true |
| 106 | + } |
| 107 | + |
| 108 | + // Unwrap IPv4-in-IPv6 (::ffff:a.b.c.d) so a single check covers |
| 109 | + // both representations. |
| 110 | + if ip.Is4In6() { |
| 111 | + ip = ip.Unmap() |
| 112 | + } |
| 113 | + |
| 114 | + if ip.IsUnspecified() || |
| 115 | + ip.IsLoopback() || |
| 116 | + ip.IsPrivate() || |
| 117 | + ip.IsLinkLocalUnicast() || |
| 118 | + ip.IsLinkLocalMulticast() || |
| 119 | + ip.IsMulticast() || |
| 120 | + ip.IsInterfaceLocalMulticast() { |
| 121 | + return true |
| 122 | + } |
| 123 | + |
| 124 | + for _, prefix := range extraBlockedPrefixes { |
| 125 | + if prefix.Contains(ip) { |
| 126 | + return true |
| 127 | + } |
| 128 | + } |
| 129 | + |
| 130 | + return false |
| 131 | +} |
| 132 | + |
| 133 | +// noCrossOriginRedirects is an http.Client.CheckRedirect function that |
| 134 | +// blocks redirects whose scheme, host, or port differs from the |
| 135 | +// original request. It is paired with WithSSRFProtection on the |
| 136 | +// stdlib clients returned by DefaultClient and DefaultPooledClient |
| 137 | +// to defeat redirect-based pivots into internal services. |
| 138 | +func noCrossOriginRedirects(req *http.Request, via []*http.Request) error { |
| 139 | + if len(via) == 0 { |
| 140 | + return nil |
| 141 | + } |
| 142 | + original := via[0].URL |
| 143 | + if !sameOrigin(original, req.URL) { |
| 144 | + return fmt.Errorf("%w: %s -> %s", ErrCrossOriginRedirect, originString(original), originString(req.URL)) |
| 145 | + } |
| 146 | + return nil |
| 147 | +} |
| 148 | + |
| 149 | +func sameOrigin(a, b *url.URL) bool { |
| 150 | + return strings.EqualFold(a.Scheme, b.Scheme) && |
| 151 | + strings.EqualFold(a.Hostname(), b.Hostname()) && |
| 152 | + defaultedPort(a) == defaultedPort(b) |
| 153 | +} |
| 154 | + |
| 155 | +func originString(u *url.URL) string { |
| 156 | + return u.Scheme + "://" + u.Host |
| 157 | +} |
| 158 | + |
| 159 | +func defaultedPort(u *url.URL) string { |
| 160 | + if p := u.Port(); p != "" { |
| 161 | + return p |
| 162 | + } |
| 163 | + switch strings.ToLower(u.Scheme) { |
| 164 | + case "https": |
| 165 | + return "443" |
| 166 | + case "http": |
| 167 | + return "80" |
| 168 | + } |
| 169 | + return "" |
| 170 | +} |
0 commit comments