Skip to content

Commit 605b152

Browse files
committed
modified the customhttp package to address an issue with:
DialContext field on ClientConfig Access to the underlying net.Conn after connection
1 parent c2e95ef commit 605b152

3 files changed

Lines changed: 99 additions & 16 deletions

File tree

pkg/customhttp/client.go

Lines changed: 69 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,69 @@ func (c *CustomClient) ConnectionCount() int {
106106
return len(c.h1Clients) + len(c.h2Clients) + len(c.h3Clients)
107107
}
108108

109+
// GetH1Connection returns the underlying net.Conn for an H1 connection to the specified host.
110+
// The host should be in "host:port" format (e.g., "example.com:443").
111+
// Returns nil if no H1 connection exists for the host or if the client is not connected.
112+
// This method is thread-safe.
113+
func (c *CustomClient) GetH1Connection(host string) net.Conn {
114+
c.mu.RLock()
115+
client, exists := c.h1Clients[host]
116+
c.mu.RUnlock()
117+
118+
if !exists || client == nil {
119+
return nil
120+
}
121+
return client.Conn
122+
}
123+
124+
// GetH2Connection returns the underlying net.Conn for an H2 connection to the specified host.
125+
// The host should be in "host:port" format (e.g., "example.com:443").
126+
// Returns nil if no H2 connection exists for the host or if the client is not connected.
127+
// This method is thread-safe.
128+
func (c *CustomClient) GetH2Connection(host string) net.Conn {
129+
c.mu.RLock()
130+
client, exists := c.h2Clients[host]
131+
c.mu.RUnlock()
132+
133+
if !exists || client == nil {
134+
return nil
135+
}
136+
return client.Conn
137+
}
138+
139+
// GetH3Connection returns the underlying net.Conn for an H3 connection to the specified host.
140+
// Note: HTTP/3 uses QUIC which manages connections differently. The underlying QUIC
141+
// connection is managed by the http3.Transport and is not directly exposed as a net.Conn.
142+
// This method always returns nil for H3 connections.
143+
// This method is thread-safe.
144+
func (c *CustomClient) GetH3Connection(host string) net.Conn {
145+
// H3 uses QUIC which doesn't expose a traditional net.Conn.
146+
// The connection is managed internally by quic-go's http3.Transport.
147+
return nil
148+
}
149+
150+
// GetConnection returns the underlying net.Conn for any active connection to the specified host.
151+
// It checks H2 and H1 connections in that order and returns the first one found.
152+
// The host should be in "host:port" format (e.g., "example.com:443").
153+
// Returns nil if no connection exists for the host.
154+
// Note: H3/QUIC connections are not included as they don't expose a traditional net.Conn.
155+
// This method is thread-safe.
156+
func (c *CustomClient) GetConnection(host string) net.Conn {
157+
c.mu.RLock()
158+
defer c.mu.RUnlock()
159+
160+
// Check H2 first (most common for modern HTTPS)
161+
if client, exists := c.h2Clients[host]; exists && client != nil && client.Conn != nil {
162+
return client.Conn
163+
}
164+
// Check H1
165+
if client, exists := c.h1Clients[host]; exists && client != nil && client.Conn != nil {
166+
return client.Conn
167+
}
168+
// H3/QUIC connections don't expose a traditional net.Conn
169+
return nil
170+
}
171+
109172
// Do executes a single HTTP request and returns the response. This is the main
110173
// entry point for the client. It orchestrates the entire request lifecycle,
111174
// including:
@@ -853,17 +916,19 @@ func isSameOrigin(u1, u2 *url.URL) bool {
853916
}
854917
p1 := u1.Port()
855918
if p1 == "" {
856-
if u1.Scheme == "https" {
919+
switch u1.Scheme {
920+
case "https":
857921
p1 = "443"
858-
} else if u1.Scheme == "http" {
922+
case "http":
859923
p1 = "80"
860924
}
861925
}
862926
p2 := u2.Port()
863927
if p2 == "" {
864-
if u2.Scheme == "https" {
928+
switch u2.Scheme {
929+
case "https":
865930
p2 = "443"
866-
} else if u2.Scheme == "http" {
931+
case "http":
867932
p2 = "80"
868933
}
869934
}

pkg/customhttp/h2client.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1175,7 +1175,7 @@ func (c *H2Client) processHeadersFrame(stream *h2StreamState, f *http2.MetaHeade
11751175
stream.StatusCode = status
11761176
}
11771177
} else {
1178-
stream.Headers.Add(http.CanonicalHeaderKey(hf.Name), hf.Value)
1178+
stream.Headers.Add(hf.Name, hf.Value)
11791179
}
11801180
}
11811181

pkg/network/dialer.go

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ import (
1414
"time"
1515
)
1616

17+
// DialContextFunc is a function type that matches the signature of net.Dialer.DialContext.
18+
// It can be used to provide custom connection establishment logic.
19+
type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error)
20+
1721
// DialerConfig holds the complete configuration for establishing network connections,
1822
// including timeouts, keep-alive settings, TLS parameters, TCP options, custom DNS
1923
// resolvers, and proxy settings.
@@ -34,6 +38,11 @@ type DialerConfig struct {
3438
// ProxyURL specifies the URL of an HTTP/HTTPS proxy server. If set, connections
3539
// will be established through the proxy using the CONNECT method.
3640
ProxyURL *url.URL
41+
// DialContext allows providing a custom dial function for establishing TCP connections.
42+
// If set, this function will be used instead of the default net.Dialer for direct connections.
43+
// Note: When ProxyURL is set, this function is used to connect to the proxy server.
44+
// The function should return a raw TCP connection; TLS wrapping is handled separately.
45+
DialContext DialContextFunc
3746
}
3847

3948
// Clone creates a deep copy of the DialerConfig, ensuring that mutable fields
@@ -54,6 +63,8 @@ func (c *DialerConfig) Clone() *DialerConfig {
5463
proxyURLCopy := *c.ProxyURL
5564
clone.ProxyURL = &proxyURLCopy
5665
}
66+
// DialContext is a function, safe to copy by value (shallow copy).
67+
// The caller is responsible for ensuring their custom DialContext is thread-safe.
5768
return &clone
5869
}
5970

@@ -127,16 +138,23 @@ func DialTCPContext(ctx context.Context, network, address string, config *Dialer
127138

128139
// dialDirect establishes a direct TCP connection to the target address.
129140
func dialDirect(ctx context.Context, network, address string, config *DialerConfig) (net.Conn, error) {
130-
dialer := &net.Dialer{
131-
Timeout: config.Timeout,
132-
KeepAlive: config.KeepAlive,
133-
// Enable Happy Eyeballs (RFC 8305) for faster IPv4/IPv6 fallback.
134-
FallbackDelay: 300 * time.Millisecond,
135-
Resolver: config.Resolver,
136-
}
141+
var rawConn net.Conn
142+
var err error
137143

138144
// Step 1: Establish the raw TCP connection.
139-
rawConn, err := dialer.DialContext(ctx, network, address)
145+
// Use custom DialContext if provided, otherwise use the default net.Dialer.
146+
if config.DialContext != nil {
147+
rawConn, err = config.DialContext(ctx, network, address)
148+
} else {
149+
dialer := &net.Dialer{
150+
Timeout: config.Timeout,
151+
KeepAlive: config.KeepAlive,
152+
// Enable Happy Eyeballs (RFC 8305) for faster IPv4/IPv6 fallback.
153+
FallbackDelay: 300 * time.Millisecond,
154+
Resolver: config.Resolver,
155+
}
156+
rawConn, err = dialer.DialContext(ctx, network, address)
157+
}
140158
if err != nil {
141159
return nil, fmt.Errorf("tcp dial failed: %w", err)
142160
}
@@ -276,9 +294,9 @@ func (c *prefixedConn) Read(p []byte) (int, error) {
276294
// both direct and proxied connections transparently.
277295
//
278296
// This function orchestrates the entire connection process:
279-
// 1. Establishes a raw TCP connection (or a proxy tunnel) via DialTCPContext.
280-
// 2. If `config.TLSConfig` is not nil, it performs a TLS handshake over the
281-
// established connection to upgrade it to a secure `tls.Conn`.
297+
// 1. Establishes a raw TCP connection (or a proxy tunnel) via DialTCPContext.
298+
// 2. If `config.TLSConfig` is not nil, it performs a TLS handshake over the
299+
// established connection to upgrade it to a secure `tls.Conn`.
282300
//
283301
// It is suitable for creating connections for protocols that handle their own
284302
// application layer logic, such as WebSockets.

0 commit comments

Comments
 (0)