Skip to content

Commit 98f727d

Browse files
authored
fix(storage): enforce request-time endpoint guard (#477)
1 parent d799277 commit 98f727d

8 files changed

Lines changed: 369 additions & 34 deletions

File tree

cmd/bao-backup/storage_config.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@ func buildStorageConfig(cfg *backupconfig.ExecutorConfig) (storage.Config, error
2222
}
2323

2424
storageConfig := storage.Config{
25-
Provider: storage.ProviderType(provider),
26-
Bucket: cfg.BackupBucket,
27-
Endpoint: cfg.BackupEndpoint,
25+
Provider: storage.ProviderType(provider),
26+
Bucket: cfg.BackupBucket,
27+
Endpoint: cfg.BackupEndpoint,
28+
ValidateEndpointRequests: true,
2829
}
2930

3031
switch provider {

internal/adapter/storage/azure.go

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ type AzureClientConfig struct {
4444
CACert []byte
4545
// EnsureExists optionally creates the container if it does not exist.
4646
EnsureExists bool
47+
// ValidateEndpointRequests rejects request-time redirects or DNS results to local or metadata-adjacent destinations.
48+
ValidateEndpointRequests bool
4749
}
4850

4951
// OpenAzureContainer opens an Azure blob container using Go CDK.
@@ -118,14 +120,24 @@ func OpenAzureContainer(ctx context.Context, cfg AzureClientConfig) (blobstore.B
118120
}
119121

120122
var cred azcore.TokenCredential
123+
credentialHTTPClient, err := buildAzureHTTPClient(AzureClientConfig{
124+
InsecureSkipVerify: cfg.InsecureSkipVerify,
125+
CACert: cfg.CACert,
126+
})
127+
if err != nil {
128+
return nil, fmt.Errorf("failed to build Azure credential HTTP client: %w", err)
129+
}
130+
credentialClientOptions := azcore.ClientOptions{
131+
Transport: credentialHTTPClient,
132+
}
121133
if cfg.ManagedIdentityClientID != "" {
122134
cred, err = azidentity.NewManagedIdentityCredential(&azidentity.ManagedIdentityCredentialOptions{
123135
ID: azidentity.ClientID(cfg.ManagedIdentityClientID),
124-
ClientOptions: clientOpts.ClientOptions,
136+
ClientOptions: credentialClientOptions,
125137
})
126138
} else {
127139
cred, err = azidentity.NewDefaultAzureCredential(&azidentity.DefaultAzureCredentialOptions{
128-
ClientOptions: clientOpts.ClientOptions,
140+
ClientOptions: credentialClientOptions,
129141
})
130142
}
131143
if err != nil {
@@ -185,10 +197,12 @@ func buildAzureHTTPClient(cfg AzureClientConfig) (*http.Client, error) {
185197
MinVersion: tls.VersionTLS12,
186198
}
187199

188-
return &http.Client{
200+
client := &http.Client{
189201
Transport: transport,
190202
Timeout: DefaultUploadTimeout,
191-
}, nil
203+
}
204+
applyStorageEndpointRequestGuard(client, transport, cfg.ValidateEndpointRequests)
205+
return client, nil
192206
}
193207

194208
func ensureAzureContainer(ctx context.Context, c *container.Client) error {
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
package storage
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"net"
8+
"net/http"
9+
"net/netip"
10+
"net/url"
11+
"regexp"
12+
"strings"
13+
"time"
14+
)
15+
16+
type storageEndpointRequestResolver interface {
17+
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
18+
}
19+
20+
type storageEndpointRequestGuard struct {
21+
resolver storageEndpointRequestResolver
22+
dialContext func(ctx context.Context, network, address string) (net.Conn, error)
23+
}
24+
25+
var storageAmbiguousNumericHostPattern = regexp.MustCompile(`(?i)^(0x[0-9a-f]+|[0-9]+)(\.(0x[0-9a-f]+|[0-9]+)){0,3}$`)
26+
27+
func applyStorageEndpointRequestGuard(client *http.Client, transport *http.Transport, enabled bool) {
28+
if !enabled {
29+
return
30+
}
31+
guard := newStorageEndpointRequestGuard()
32+
transport.DialContext = guard.guardedDialContext
33+
client.CheckRedirect = guard.checkRedirect
34+
}
35+
36+
func newStorageEndpointRequestGuard() storageEndpointRequestGuard {
37+
dialer := &net.Dialer{
38+
Timeout: 30 * time.Second,
39+
KeepAlive: 30 * time.Second,
40+
}
41+
return storageEndpointRequestGuard{
42+
resolver: net.DefaultResolver,
43+
dialContext: dialer.DialContext,
44+
}
45+
}
46+
47+
func (g storageEndpointRequestGuard) guardedDialContext(ctx context.Context, network, address string) (net.Conn, error) {
48+
host, port, err := net.SplitHostPort(address)
49+
if err != nil {
50+
return nil, fmt.Errorf("storage request destination %q is invalid: %w", address, err)
51+
}
52+
53+
addrs, err := g.resolveAllowedAddrs(ctx, network, host)
54+
if err != nil {
55+
return nil, err
56+
}
57+
58+
dialErrs := make([]error, 0, len(addrs))
59+
for _, addr := range addrs {
60+
if !networkAllowsEndpointAddress(network, addr) {
61+
continue
62+
}
63+
conn, err := g.dialContext(ctx, network, net.JoinHostPort(addr.String(), port))
64+
if err == nil {
65+
return conn, nil
66+
}
67+
dialErrs = append(dialErrs, err)
68+
}
69+
70+
if len(dialErrs) == 0 {
71+
return nil, fmt.Errorf("storage request destination %q has no addresses compatible with network %q", address, network)
72+
}
73+
return nil, fmt.Errorf("storage request destination %q could not be reached: %w", address, errors.Join(dialErrs...))
74+
}
75+
76+
func (g storageEndpointRequestGuard) checkRedirect(req *http.Request, via []*http.Request) error {
77+
if len(via) >= 10 {
78+
return fmt.Errorf("stopped after 10 redirects")
79+
}
80+
return g.validateURL(req.Context(), req.URL)
81+
}
82+
83+
func (g storageEndpointRequestGuard) validateURL(ctx context.Context, u *url.URL) error {
84+
if u == nil {
85+
return fmt.Errorf("storage request redirect target is missing")
86+
}
87+
host := normalizeStorageEndpointRequestHost(u.Hostname())
88+
if host == "" {
89+
return fmt.Errorf("storage request redirect target must include a host")
90+
}
91+
_, err := g.resolveAllowedAddrs(ctx, "ip", host)
92+
return err
93+
}
94+
95+
func (g storageEndpointRequestGuard) resolveAllowedAddrs(ctx context.Context, network, host string) ([]netip.Addr, error) {
96+
host = normalizeStorageEndpointRequestHost(host)
97+
if host == "" {
98+
return nil, fmt.Errorf("storage request destination must include a host")
99+
}
100+
if host == "localhost" || strings.HasSuffix(host, ".localhost") {
101+
return nil, fmt.Errorf("storage request destination host %q is not allowed", host)
102+
}
103+
if addr, err := netip.ParseAddr(host); err == nil {
104+
addr = addr.Unmap()
105+
if isForbiddenStorageEndpointRequestAddress(addr) {
106+
return nil, fmt.Errorf("storage request destination host %q is not allowed", host)
107+
}
108+
return []netip.Addr{addr}, nil
109+
}
110+
if storageAmbiguousNumericHostPattern.MatchString(host) {
111+
return nil, fmt.Errorf("storage request destination host %q uses ambiguous numeric IP encoding", host)
112+
}
113+
if strings.Contains(host, ":") {
114+
return nil, fmt.Errorf("storage request destination host %q uses ambiguous IP encoding", host)
115+
}
116+
if g.resolver == nil {
117+
return nil, fmt.Errorf("storage request destination resolver is required")
118+
}
119+
120+
addrs, err := g.resolver.LookupNetIP(ctx, resolverNetworkForEndpointDial(network), host)
121+
if err != nil {
122+
return nil, fmt.Errorf("storage request destination host %q could not be resolved: %w", host, err)
123+
}
124+
if len(addrs) == 0 {
125+
return nil, fmt.Errorf("storage request destination host %q did not resolve to any IP addresses", host)
126+
}
127+
for _, addr := range addrs {
128+
if isForbiddenStorageEndpointRequestAddress(addr) {
129+
return nil, fmt.Errorf("storage request destination host %q resolves to forbidden address %s", host, addr)
130+
}
131+
}
132+
return addrs, nil
133+
}
134+
135+
func normalizeStorageEndpointRequestHost(host string) string {
136+
host = strings.ToLower(strings.TrimSpace(host))
137+
host = strings.TrimSuffix(host, ".")
138+
return host
139+
}
140+
141+
func isForbiddenStorageEndpointRequestAddress(addr netip.Addr) bool {
142+
addr = addr.Unmap()
143+
return !addr.IsValid() ||
144+
addr.IsLoopback() ||
145+
addr.IsLinkLocalUnicast() ||
146+
addr.IsUnspecified() ||
147+
addr.IsMulticast()
148+
}
149+
150+
func networkAllowsEndpointAddress(network string, addr netip.Addr) bool {
151+
addr = addr.Unmap()
152+
switch {
153+
case strings.HasSuffix(network, "4"):
154+
return addr.Is4()
155+
case strings.HasSuffix(network, "6"):
156+
return addr.Is6()
157+
default:
158+
return true
159+
}
160+
}
161+
162+
func resolverNetworkForEndpointDial(network string) string {
163+
switch {
164+
case strings.HasSuffix(network, "4"):
165+
return "ip4"
166+
case strings.HasSuffix(network, "6"):
167+
return "ip6"
168+
default:
169+
return "ip"
170+
}
171+
}
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
package storage
2+
3+
import (
4+
"context"
5+
"net"
6+
"net/http"
7+
"net/netip"
8+
"testing"
9+
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
type fakeStorageEndpointRequestResolver map[string][]netip.Addr
15+
16+
func (r fakeStorageEndpointRequestResolver) LookupNetIP(_ context.Context, _ string, host string) ([]netip.Addr, error) {
17+
if addrs, ok := r[host]; ok {
18+
return addrs, nil
19+
}
20+
return nil, &net.DNSError{
21+
Err: "no such host",
22+
Name: host,
23+
IsNotFound: true,
24+
}
25+
}
26+
27+
type recordingStorageEndpointRequestResolver struct {
28+
network string
29+
addrs []netip.Addr
30+
}
31+
32+
func (r *recordingStorageEndpointRequestResolver) LookupNetIP(_ context.Context, network, _ string) ([]netip.Addr, error) {
33+
r.network = network
34+
return r.addrs, nil
35+
}
36+
37+
func TestStorageEndpointRequestGuardDialsResolvedAddress(t *testing.T) {
38+
t.Parallel()
39+
40+
var dialedAddress string
41+
resolver := &recordingStorageEndpointRequestResolver{
42+
addrs: []netip.Addr{netip.MustParseAddr("203.0.113.10")},
43+
}
44+
guard := storageEndpointRequestGuard{
45+
resolver: resolver,
46+
dialContext: func(_ context.Context, _ string, address string) (net.Conn, error) {
47+
dialedAddress = address
48+
clientConn, serverConn := net.Pipe()
49+
t.Cleanup(func() {
50+
_ = clientConn.Close()
51+
_ = serverConn.Close()
52+
})
53+
return clientConn, nil
54+
},
55+
}
56+
57+
conn, err := guard.guardedDialContext(context.Background(), "tcp4", "storage.example.test:443")
58+
require.NoError(t, err)
59+
require.NotNil(t, conn)
60+
assert.Equal(t, "ip4", resolver.network)
61+
assert.Equal(t, "203.0.113.10:443", dialedAddress)
62+
}
63+
64+
func TestStorageEndpointRequestGuardAllowsPrivateClusterAddress(t *testing.T) {
65+
t.Parallel()
66+
67+
var dialedAddress string
68+
guard := storageEndpointRequestGuard{
69+
resolver: fakeStorageEndpointRequestResolver{},
70+
dialContext: func(_ context.Context, _ string, address string) (net.Conn, error) {
71+
dialedAddress = address
72+
clientConn, serverConn := net.Pipe()
73+
t.Cleanup(func() {
74+
_ = clientConn.Close()
75+
_ = serverConn.Close()
76+
})
77+
return clientConn, nil
78+
},
79+
}
80+
81+
conn, err := guard.guardedDialContext(context.Background(), "tcp", "10.96.12.34:9000")
82+
require.NoError(t, err)
83+
require.NotNil(t, conn)
84+
assert.Equal(t, "10.96.12.34:9000", dialedAddress)
85+
}
86+
87+
func TestStorageEndpointRequestGuardRejectsForbiddenDNSAtDialTime(t *testing.T) {
88+
t.Parallel()
89+
90+
dialed := false
91+
guard := storageEndpointRequestGuard{
92+
resolver: fakeStorageEndpointRequestResolver{
93+
"storage.example.test": []netip.Addr{netip.MustParseAddr("169.254.169.254")},
94+
},
95+
dialContext: func(context.Context, string, string) (net.Conn, error) {
96+
dialed = true
97+
return nil, nil
98+
},
99+
}
100+
101+
conn, err := guard.guardedDialContext(context.Background(), "tcp", "storage.example.test:443")
102+
require.Error(t, err)
103+
assert.Nil(t, conn)
104+
assert.False(t, dialed)
105+
assert.Contains(t, err.Error(), "resolves to forbidden address")
106+
}
107+
108+
func TestStorageEndpointRequestGuardRejectsRedirectToForbiddenHost(t *testing.T) {
109+
t.Parallel()
110+
111+
client, err := buildHTTPClient(nil, false, true)
112+
require.NoError(t, err)
113+
require.NotNil(t, client.CheckRedirect)
114+
115+
req, err := http.NewRequest(http.MethodGet, "http://169.254.169.254/latest/meta-data", nil)
116+
require.NoError(t, err)
117+
118+
err = client.CheckRedirect(req, nil)
119+
require.Error(t, err)
120+
assert.Contains(t, err.Error(), "not allowed")
121+
}

internal/adapter/storage/gcs.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ type GCSClientConfig struct {
4343
CACert []byte
4444
// EnsureExists optionally validates/creates the bucket (requires Project for create).
4545
EnsureExists bool
46+
// ValidateEndpointRequests rejects request-time redirects or DNS results to local or metadata-adjacent destinations.
47+
ValidateEndpointRequests bool
4648
}
4749

4850
// OpenGCSBucket opens a GCS bucket using Go CDK.
@@ -172,10 +174,12 @@ func buildGCSHTTPClient(cfg GCSClientConfig) (*http.Client, error) {
172174
}
173175
transport.TLSClientConfig = tlsConfig
174176

175-
return &http.Client{
177+
client := &http.Client{
176178
Transport: transport,
177179
Timeout: DefaultUploadTimeout,
178-
}, nil
180+
}
181+
applyStorageEndpointRequestGuard(client, transport, cfg.ValidateEndpointRequests)
182+
return client, nil
179183
}
180184

181185
func ensureGCSBucket(ctx context.Context, cfg GCSClientConfig, b *Bucket) error {

0 commit comments

Comments
 (0)