Skip to content

Commit d63abbf

Browse files
committed
Add custom dial for ssrf
Signed-off-by: Bryan Frimin <bryan@getprobo.com>
1 parent b0533a1 commit d63abbf

File tree

3 files changed

+408
-8
lines changed

3 files changed

+408
-8
lines changed

httpclient/httpclient.go

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ type (
4444
tracerProvider trace.TracerProvider
4545
logger *log.Logger
4646
registerer prometheus.Registerer
47+
48+
ssrfProtection bool
49+
ssrfAllowLoopback bool
4750
}
4851
)
4952

@@ -82,13 +85,47 @@ func WithRegisterer(r prometheus.Registerer) Option {
8285
}
8386
}
8487

88+
// WithSSRFProtection enables server-side request forgery protection
89+
// on the returned transport. When enabled, the underlying dialer
90+
// rejects connections to addresses that resolve to loopback,
91+
// private (RFC 1918), CGNAT (RFC 6598), link-local, multicast,
92+
// unspecified, IPv4-mapped IPv6 of the same, ULA (RFC 4193), and
93+
// other reserved ranges. The check runs on the actual peer
94+
// address at connect time, which defeats DNS rebinding between
95+
// any prior URL validation and the dial.
96+
//
97+
// When the resulting transport is built into an http.Client via
98+
// DefaultClient or DefaultPooledClient, the client is also
99+
// configured to refuse redirects whose scheme, host, or port
100+
// differs from the original request, blocking redirect-based
101+
// pivots.
102+
//
103+
// Use this option whenever the destination URL is influenced by
104+
// untrusted input (for example, customer-supplied webhook
105+
// endpoints or OAuth2 token URLs).
106+
func WithSSRFProtection() Option {
107+
return func(o *Options) {
108+
o.ssrfProtection = true
109+
}
110+
}
111+
112+
// WithSSRFAllowLoopback weakens WithSSRFProtection to permit dials to
113+
// loopback addresses (127.0.0.0/8 and ::1). It exists for tests that
114+
// stand up an httptest server on the loopback interface; production
115+
// callers must not use it.
116+
func WithSSRFAllowLoopback() Option {
117+
return func(o *Options) {
118+
o.ssrfAllowLoopback = true
119+
}
120+
}
121+
85122
// DefaultTransport returns a new http.Transport with similar default
86123
// values to http.DefaultTransport, but with idle connections and
87124
// keepalives disabled.
88125
func DefaultTransport(options ...Option) http.RoundTripper {
89126
opts := configureOptions(options)
90127

91-
transport := createBaseTransport()
128+
transport := createBaseTransport(opts)
92129
transport.DisableKeepAlives = true
93130
transport.MaxIdleConnsPerHost = -1
94131
transport.TLSClientConfig = opts.tlsConfig
@@ -104,7 +141,7 @@ func DefaultTransport(options ...Option) http.RoundTripper {
104141
func DefaultPooledTransport(options ...Option) http.RoundTripper {
105142
opts := configureOptions(options)
106143

107-
transport := createBaseTransport()
144+
transport := createBaseTransport(opts)
108145
transport.MaxIdleConnsPerHost = runtime.GOMAXPROCS(0) + 1
109146
transport.TLSClientConfig = opts.tlsConfig
110147

@@ -115,9 +152,8 @@ func DefaultPooledTransport(options ...Option) http.RoundTripper {
115152
// to http.Client, but with a non-shared Transport, idle connections
116153
// disabled, and keepalives disabled.
117154
func DefaultClient(options ...Option) *http.Client {
118-
return &http.Client{
119-
Transport: DefaultTransport(options...),
120-
}
155+
opts := configureOptions(options)
156+
return buildClient(DefaultTransport(options...), opts)
121157
}
122158

123159
// DefaultPooledClient returns a new http.Client with similar default
@@ -126,17 +162,27 @@ func DefaultClient(options ...Option) *http.Client {
126162
// time. Only use this for clients that will be re-used for the same
127163
// host(s).
128164
func DefaultPooledClient(options ...Option) *http.Client {
129-
return &http.Client{
130-
Transport: DefaultPooledTransport(options...),
165+
opts := configureOptions(options)
166+
return buildClient(DefaultPooledTransport(options...), opts)
167+
}
168+
169+
func buildClient(rt http.RoundTripper, opts *Options) *http.Client {
170+
c := &http.Client{Transport: rt}
171+
if opts.ssrfProtection {
172+
c.CheckRedirect = noCrossOriginRedirects
131173
}
174+
return c
132175
}
133176

134-
func createBaseTransport() *http.Transport {
177+
func createBaseTransport(opts *Options) *http.Transport {
135178
dial := &net.Dialer{
136179
Timeout: 30 * time.Second,
137180
KeepAlive: 30 * time.Second,
138181
DualStack: true,
139182
}
183+
if opts.ssrfProtection {
184+
dial.Control = makeSSRFDialControl(opts.ssrfAllowLoopback)
185+
}
140186

141187
return &http.Transport{
142188
Proxy: http.ProxyFromEnvironment,

httpclient/ssrf.go

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
// Copyright (c) 2026 Bryan Frimin <bryan@frimin.fr>.
2+
//
3+
// Permission to use, copy, modify, and/or distribute this software
4+
// for any purpose with or without fee is hereby granted, provided
5+
// that the above copyright notice and this permission notice appear
6+
// in all copies.
7+
//
8+
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL
9+
// WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED
10+
// WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE
11+
// AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, INDIRECT, OR
12+
// CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
13+
// OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
14+
// NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
15+
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16+
17+
package httpclient
18+
19+
import (
20+
"errors"
21+
"fmt"
22+
"net/http"
23+
"net/netip"
24+
"net/url"
25+
"strings"
26+
"syscall"
27+
)
28+
29+
// ErrBlockedAddress is returned when a dial is rejected because the
30+
// resolved peer address belongs to a blocked range.
31+
var ErrBlockedAddress = errors.New("httpclient: address blocked by SSRF protection")
32+
33+
// ErrCrossOriginRedirect is returned when a redirect target changes
34+
// scheme, host, or port and the client is configured to refuse such
35+
// redirects.
36+
var ErrCrossOriginRedirect = errors.New("httpclient: cross-origin redirect blocked by SSRF protection")
37+
38+
// extraBlockedPrefixes are reserved or otherwise unsafe ranges that
39+
// netip.Addr's helper predicates do not cover. RFC 1918 (private),
40+
// loopback, link-local, multicast, and unspecified are handled by the
41+
// stdlib helpers in isBlockedAddr.
42+
var extraBlockedPrefixes = []netip.Prefix{
43+
netip.MustParsePrefix("0.0.0.0/8"), // RFC 1122 "this network"
44+
netip.MustParsePrefix("100.64.0.0/10"), // RFC 6598 CGNAT
45+
netip.MustParsePrefix("192.0.0.0/24"), // RFC 6890 IETF protocol assignments
46+
netip.MustParsePrefix("192.0.2.0/24"), // RFC 5737 TEST-NET-1
47+
netip.MustParsePrefix("198.18.0.0/15"), // RFC 2544 benchmarking
48+
netip.MustParsePrefix("198.51.100.0/24"), // RFC 5737 TEST-NET-2
49+
netip.MustParsePrefix("203.0.113.0/24"), // RFC 5737 TEST-NET-3
50+
netip.MustParsePrefix("240.0.0.0/4"), // RFC 1112 reserved (incl. 255.255.255.255)
51+
netip.MustParsePrefix("64:ff9b::/96"), // RFC 6052 IPv4/IPv6 translation
52+
netip.MustParsePrefix("64:ff9b:1::/48"), // RFC 8215 IPv4/IPv6 local translation
53+
netip.MustParsePrefix("100::/64"), // RFC 6666 discard prefix
54+
netip.MustParsePrefix("2001::/23"), // IETF protocol assignments
55+
netip.MustParsePrefix("2001:db8::/32"), // RFC 3849 documentation
56+
}
57+
58+
// makeSSRFDialControl returns a net.Dialer.Control function that
59+
// rejects connections to IP addresses in private, loopback,
60+
// link-local, multicast, CGNAT, or other reserved ranges. The check
61+
// runs after DNS resolution on the actual peer address, which
62+
// defeats DNS rebinding between any prior URL validation and the
63+
// dial.
64+
//
65+
// When allowLoopback is true, 127.0.0.0/8 and ::1 are permitted. The
66+
// option is intended for tests that bind httptest servers to the
67+
// loopback interface and must not be used in production.
68+
//
69+
// Only TCP and UDP networks are permitted; unix sockets and other
70+
// network types are refused outright.
71+
func makeSSRFDialControl(allowLoopback bool) func(network, address string, _ syscall.RawConn) error {
72+
return func(network, address string, _ syscall.RawConn) error {
73+
switch network {
74+
case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
75+
default:
76+
return fmt.Errorf("%w: refusing non-IP network %q", ErrBlockedAddress, network)
77+
}
78+
79+
addrPort, err := netip.ParseAddrPort(address)
80+
if err != nil {
81+
return fmt.Errorf("%w: cannot parse peer address %q: %v", ErrBlockedAddress, address, err)
82+
}
83+
84+
ip := addrPort.Addr()
85+
if ip.Is4In6() {
86+
ip = ip.Unmap()
87+
}
88+
89+
if allowLoopback && ip.IsLoopback() {
90+
return nil
91+
}
92+
93+
if isBlockedAddr(ip) {
94+
return fmt.Errorf("%w: %s", ErrBlockedAddress, ip)
95+
}
96+
97+
return nil
98+
}
99+
}
100+
101+
// isBlockedAddr reports whether ip falls in a range that should be
102+
// refused for outbound HTTP from an SSRF perspective.
103+
func isBlockedAddr(ip netip.Addr) bool {
104+
if !ip.IsValid() {
105+
return true
106+
}
107+
108+
// Unwrap IPv4-in-IPv6 (::ffff:a.b.c.d) so a single check covers
109+
// both representations.
110+
if ip.Is4In6() {
111+
ip = ip.Unmap()
112+
}
113+
114+
if ip.IsUnspecified() ||
115+
ip.IsLoopback() ||
116+
ip.IsPrivate() ||
117+
ip.IsLinkLocalUnicast() ||
118+
ip.IsLinkLocalMulticast() ||
119+
ip.IsMulticast() ||
120+
ip.IsInterfaceLocalMulticast() {
121+
return true
122+
}
123+
124+
for _, prefix := range extraBlockedPrefixes {
125+
if prefix.Contains(ip) {
126+
return true
127+
}
128+
}
129+
130+
return false
131+
}
132+
133+
// noCrossOriginRedirects is an http.Client.CheckRedirect function that
134+
// blocks redirects whose scheme, host, or port differs from the
135+
// original request. It is paired with WithSSRFProtection on the
136+
// stdlib clients returned by DefaultClient and DefaultPooledClient
137+
// to defeat redirect-based pivots into internal services.
138+
func noCrossOriginRedirects(req *http.Request, via []*http.Request) error {
139+
if len(via) == 0 {
140+
return nil
141+
}
142+
original := via[0].URL
143+
if !sameOrigin(original, req.URL) {
144+
return fmt.Errorf("%w: %s -> %s", ErrCrossOriginRedirect, originString(original), originString(req.URL))
145+
}
146+
return nil
147+
}
148+
149+
func sameOrigin(a, b *url.URL) bool {
150+
return strings.EqualFold(a.Scheme, b.Scheme) &&
151+
strings.EqualFold(a.Hostname(), b.Hostname()) &&
152+
defaultedPort(a) == defaultedPort(b)
153+
}
154+
155+
func originString(u *url.URL) string {
156+
return u.Scheme + "://" + u.Host
157+
}
158+
159+
func defaultedPort(u *url.URL) string {
160+
if p := u.Port(); p != "" {
161+
return p
162+
}
163+
switch strings.ToLower(u.Scheme) {
164+
case "https":
165+
return "443"
166+
case "http":
167+
return "80"
168+
}
169+
return ""
170+
}

0 commit comments

Comments
 (0)