Skip to content

Commit e66c989

Browse files
author
michel-laterman
committed
Add support for ProxyConnectHeader in the dialer.
Add the ability to pass ProxyConnectHeader to the dialer. This set of headers will be used when a CONNECT request is made to an http(s) proxy.
1 parent e064f32 commit e66c989

3 files changed

Lines changed: 58 additions & 6 deletions

File tree

client.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ type Dialer struct {
8787
// If Proxy is nil or returns a nil *URL, no proxy is used.
8888
Proxy func(*http.Request) (*url.URL, error)
8989

90+
// ProxyConnectHeader specifies optional headers to use during proxy connect requests.
91+
ProxyConnectHeader http.Header
92+
9093
// TLSClientConfig specifies the TLS configuration to use with tls.Client.
9194
// If nil, the default configuration is used.
9295
// If NetDialTLSContext is set, Dial assumes the TLS handshake
@@ -416,7 +419,7 @@ func (d *Dialer) netDialFn(ctx context.Context, proxyURL *url.URL, backendURL *u
416419
}
417420
// Proxy dialing is wrapped to implement CONNECT method and possibly proxy auth.
418421
if proxyURL != nil {
419-
return proxyFromURL(proxyURL, netDial)
422+
return proxyFromURL(proxyURL, netDial, d.ProxyConnectHeader)
420423
}
421424
return netDial, nil
422425
}

client_server_test.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,46 @@ func TestProxyAuthorizationDial(t *testing.T) {
242242
sendRecv(t, ws)
243243
}
244244

245+
func TestProxyDialConnectHeaders(t *testing.T) {
246+
s := newServer(t)
247+
defer s.Close()
248+
249+
surl, _ := url.Parse(s.Server.URL)
250+
251+
cstDialer := cstDialer // make local copy for modification on next line.
252+
cstDialer.Proxy = http.ProxyURL(surl)
253+
cstDialer.ProxyConnectHeader = http.Header{"User-Agent": []string{"test-proxy-agent"}}
254+
255+
connect := false
256+
origHandler := s.Server.Config.Handler
257+
258+
// Capture the request Host header.
259+
s.Server.Config.Handler = http.HandlerFunc(
260+
func(w http.ResponseWriter, r *http.Request) {
261+
t.Logf("Request headers: %v", r.Header)
262+
userAgent := r.Header.Get("User-Agent")
263+
if r.Method == http.MethodConnect && userAgent == "test-proxy-agent" {
264+
connect = true
265+
w.WriteHeader(http.StatusOK)
266+
return
267+
}
268+
269+
if !connect {
270+
t.Log("connect with proxy connect headers not received")
271+
http.Error(w, "connect with proxy connect headers not received", http.StatusMethodNotAllowed)
272+
return
273+
}
274+
origHandler.ServeHTTP(w, r)
275+
})
276+
277+
ws, _, err := cstDialer.Dial(s.URL, nil)
278+
if err != nil {
279+
t.Fatalf("Dial: %v", err)
280+
}
281+
defer ws.Close()
282+
sendRecv(t, ws)
283+
}
284+
245285
func TestDial(t *testing.T) {
246286
s := newServer(t)
247287
defer s.Close()

proxy.go

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ func (fn netDialerFunc) DialContext(ctx context.Context, network, addr string) (
2828
return fn(ctx, network, addr)
2929
}
3030

31-
func proxyFromURL(proxyURL *url.URL, forwardDial netDialerFunc) (netDialerFunc, error) {
31+
func proxyFromURL(proxyURL *url.URL, forwardDial netDialerFunc, connectHeader http.Header) (netDialerFunc, error) {
3232
if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
33-
return (&httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDial}).DialContext, nil
33+
return (&httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDial, proxyConnectHeader: connectHeader}).DialContext, nil
3434
}
3535
dialer, err := proxy.FromURL(proxyURL, forwardDial)
3636
if err != nil {
@@ -45,8 +45,13 @@ func proxyFromURL(proxyURL *url.URL, forwardDial netDialerFunc) (netDialerFunc,
4545
}
4646

4747
type httpProxyDialer struct {
48-
proxyURL *url.URL
49-
forwardDial netDialerFunc
48+
proxyURL *url.URL
49+
forwardDial netDialerFunc
50+
proxyConnectHeader http.Header
51+
}
52+
53+
func (hpd *httpProxyDialer) Dial(network, addr string) (net.Conn, error) {
54+
return hpd.DialContext(context.Background(), network, addr)
5055
}
5156

5257
func (hpd *httpProxyDialer) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) {
@@ -56,7 +61,11 @@ func (hpd *httpProxyDialer) DialContext(ctx context.Context, network string, add
5661
return nil, err
5762
}
5863

59-
connectHeader := make(http.Header)
64+
connectHeader := hpd.proxyConnectHeader
65+
if hpd.proxyConnectHeader == nil {
66+
connectHeader = make(http.Header)
67+
}
68+
6069
if user := hpd.proxyURL.User; user != nil {
6170
proxyUser := user.Username()
6271
if proxyPassword, passwordSet := user.Password(); passwordSet {

0 commit comments

Comments
 (0)