Skip to content

Commit 06e3d67

Browse files
committed
feat: support dual-stack for interface binding
1 parent d27aef8 commit 06e3d67

File tree

1 file changed

+121
-56
lines changed

1 file changed

+121
-56
lines changed

doh-client/client.go

Lines changed: 121 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -92,24 +92,24 @@ func NewClient(conf *config.Config) (c *Client, err error) {
9292
}
9393

9494
if c.conf.Other.Interface != "" {
95-
// Setup UDP Dialer
96-
udpLocalAddr, err := c.bindToInterface("udp")
95+
localV4, localV6, err := c.getInterfaceIPs()
9796
if err != nil {
98-
return nil, fmt.Errorf("failed to bind passthrough UDP to interface %s: %v", c.conf.Other.Interface, err)
97+
return nil, fmt.Errorf("failed to get interface IPs for %s: %v", c.conf.Other.Interface, err)
9998
}
100-
c.udpClient.Dialer = &net.Dialer{
101-
Timeout: time.Duration(conf.Other.Timeout) * time.Second,
102-
LocalAddr: udpLocalAddr,
99+
var localAddr net.IP
100+
if localV4 != nil {
101+
localAddr = localV4
102+
} else {
103+
localAddr = localV6
103104
}
104105

105-
// Setup TCP Dialer
106-
tcpLocalAddr, err := c.bindToInterface("tcp")
107-
if err != nil {
108-
return nil, fmt.Errorf("failed to bind passthrough TCP to interface %s: %v", c.conf.Other.Interface, err)
106+
c.udpClient.Dialer = &net.Dialer{
107+
Timeout: time.Duration(conf.Other.Timeout) * time.Second,
108+
LocalAddr: &net.UDPAddr{IP: localAddr},
109109
}
110110
c.tcpClient.Dialer = &net.Dialer{
111111
Timeout: time.Duration(conf.Other.Timeout) * time.Second,
112-
LocalAddr: tcpLocalAddr,
112+
LocalAddr: &net.TCPAddr{IP: localAddr},
113113
}
114114
}
115115

@@ -144,11 +144,35 @@ func NewClient(conf *config.Config) (c *Client, err error) {
144144
Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
145145
var d net.Dialer
146146
if c.conf.Other.Interface != "" {
147-
localAddr, err := c.bindToInterface(network)
147+
localV4, localV6, err := c.getInterfaceIPs()
148148
if err != nil {
149149
log.Printf("Bootstrap dial warning: %v", err)
150150
} else {
151-
d.LocalAddr = localAddr
151+
numServers := len(c.bootstrap)
152+
bootstrap := c.bootstrap[rand.Intn(numServers)]
153+
host, _, _ := net.SplitHostPort(bootstrap)
154+
ip := net.ParseIP(host)
155+
if ip != nil {
156+
if ip.To4() != nil {
157+
if localV4 != nil {
158+
if strings.HasPrefix(network, "udp") {
159+
d.LocalAddr = &net.UDPAddr{IP: localV4}
160+
} else {
161+
d.LocalAddr = &net.TCPAddr{IP: localV4}
162+
}
163+
}
164+
} else {
165+
if localV6 != nil {
166+
if strings.HasPrefix(network, "udp") {
167+
d.LocalAddr = &net.UDPAddr{IP: localV6}
168+
} else {
169+
d.LocalAddr = &net.TCPAddr{IP: localV6}
170+
}
171+
}
172+
}
173+
}
174+
conn, err := d.DialContext(ctx, network, bootstrap)
175+
return conn, err
152176
}
153177
}
154178
numServers := len(c.bootstrap)
@@ -266,22 +290,72 @@ func (c *Client) newHTTPClient() error {
266290
if c.httpTransport != nil {
267291
c.httpTransport.CloseIdleConnections()
268292
}
269-
dialer := &net.Dialer{
293+
294+
localV4, localV6, err := c.getInterfaceIPs()
295+
if err != nil {
296+
log.Printf("Interface binding error: %v", err)
297+
return err
298+
}
299+
300+
baseDialer := &net.Dialer{
270301
Timeout: time.Duration(c.conf.Other.Timeout) * time.Second,
271302
KeepAlive: 30 * time.Second,
272-
// DualStack: true,
273-
Resolver: c.bootstrapResolver,
274-
}
275-
if c.conf.Other.Interface != "" {
276-
localAddr, err := c.bindToInterface("tcp")
277-
if err != nil {
278-
log.Printf("Failed to resolve interface %s: %v", c.conf.Other.Interface, err)
279-
return err
280-
}
281-
dialer.LocalAddr = localAddr
303+
Resolver: c.bootstrapResolver,
282304
}
305+
283306
c.httpTransport = &http.Transport{
284-
DialContext: dialer.DialContext,
307+
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
308+
if c.conf.Other.Interface == "" {
309+
return baseDialer.DialContext(ctx, network, addr)
310+
}
311+
312+
if network == "tcp4" && localV4 != nil {
313+
d := *baseDialer
314+
d.LocalAddr = &net.TCPAddr{IP: localV4}
315+
return d.DialContext(ctx, network, addr)
316+
}
317+
if network == "tcp6" && localV6 != nil {
318+
d := *baseDialer
319+
d.LocalAddr = &net.TCPAddr{IP: localV6}
320+
return d.DialContext(ctx, network, addr)
321+
}
322+
323+
// Manual Dual-Stack: Resolve host and try compatible families sequentially
324+
host, port, _ := net.SplitHostPort(addr)
325+
ips, err := c.bootstrapResolver.LookupIPAddr(ctx, host)
326+
if err != nil {
327+
return nil, err
328+
}
329+
330+
var lastErr error
331+
for _, ip := range ips {
332+
d := *baseDialer
333+
targetAddr := net.JoinHostPort(ip.String(), port)
334+
335+
if ip.IP.To4() != nil {
336+
if localV4 == nil {
337+
continue
338+
}
339+
d.LocalAddr = &net.TCPAddr{IP: localV4}
340+
} else {
341+
if localV6 == nil {
342+
continue
343+
}
344+
d.LocalAddr = &net.TCPAddr{IP: localV6}
345+
}
346+
347+
conn, err := d.DialContext(ctx, "tcp", targetAddr)
348+
if err == nil {
349+
return conn, nil
350+
}
351+
lastErr = err
352+
}
353+
354+
if lastErr != nil {
355+
return nil, lastErr
356+
}
357+
return nil, fmt.Errorf("connection to %s failed: no matching local/remote IP families on interface %s", addr, c.conf.Other.Interface)
358+
},
285359
ExpectContinueTimeout: 1 * time.Second,
286360
IdleConnTimeout: 90 * time.Second,
287361
MaxIdleConns: 100,
@@ -290,15 +364,18 @@ func (c *Client) newHTTPClient() error {
290364
TLSHandshakeTimeout: time.Duration(c.conf.Other.Timeout) * time.Second,
291365
TLSClientConfig: &tls.Config{InsecureSkipVerify: c.conf.Other.TLSInsecureSkipVerify},
292366
}
367+
293368
if c.conf.Other.NoIPv6 {
369+
originalDial := c.httpTransport.DialContext
294370
c.httpTransport.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) {
295371
if strings.HasPrefix(network, "tcp") {
296372
network = "tcp4"
297373
}
298-
return dialer.DialContext(ctx, network, address)
374+
return originalDial(ctx, network, address)
299375
}
300376
}
301-
err := http2.ConfigureTransport(c.httpTransport)
377+
378+
err = http2.ConfigureTransport(c.httpTransport)
302379
if err != nil {
303380
return err
304381
}
@@ -525,49 +602,37 @@ func (c *Client) findClientIP(w dns.ResponseWriter, r *dns.Msg) (ednsClientAddre
525602
return
526603
}
527604

528-
func (c *Client) bindToInterface(network string) (net.Addr, error) {
605+
// getInterfaceIPs returns the first valid IPv4 and IPv6 addresses found on the interface
606+
func (c *Client) getInterfaceIPs() (v4, v6 net.IP, err error) {
529607
if c.conf.Other.Interface == "" {
530-
return nil, nil
608+
return nil, nil, nil
531609
}
532610
ifi, err := net.InterfaceByName(c.conf.Other.Interface)
533611
if err != nil {
534-
return nil, err
612+
return nil, nil, err
535613
}
536614
addrs, err := ifi.Addrs()
537615
if err != nil {
538-
return nil, err
616+
return nil, nil, err
539617
}
540618

541-
// Determine if we need IPv4 or IPv6 based on the network string (e.g., "tcp4", "udp6")
542-
wantIPv6 := strings.Contains(network, "6")
543-
wantIPv4 := strings.Contains(network, "4") || !wantIPv6 // Default to 4 if not specified, or if generic "tcp"/"udp"
544-
545619
for _, addr := range addrs {
546620
ip, _, err := net.ParseCIDR(addr.String())
547621
if err != nil {
548622
continue
549623
}
550-
551-
// Skip if we want IPv4 but got IPv6
552-
if ip.To4() == nil && wantIPv4 && !wantIPv6 {
553-
continue
554-
}
555-
// Skip if we want IPv6 but got IPv4
556-
if ip.To4() != nil && wantIPv6 {
557-
continue
558-
}
559-
// Skip IPv6 if disabled in config
560-
if ip.To4() == nil && c.conf.Other.NoIPv6 {
561-
continue
562-
}
563-
564-
// Return the appropriate address type
565-
if strings.HasPrefix(network, "tcp") {
566-
return &net.TCPAddr{IP: ip}, nil
567-
}
568-
if strings.HasPrefix(network, "udp") {
569-
return &net.UDPAddr{IP: ip}, nil
624+
if ip4 := ip.To4(); ip4 != nil {
625+
if v4 == nil {
626+
v4 = ip4
627+
}
628+
} else {
629+
if v6 == nil && !c.conf.Other.NoIPv6 {
630+
v6 = ip
631+
}
570632
}
571633
}
572-
return nil, fmt.Errorf("no suitable address found on interface %s for network %s", c.conf.Other.Interface, network)
634+
if v4 == nil && v6 == nil {
635+
return nil, nil, fmt.Errorf("no valid IP addresses found on interface %s", c.conf.Other.Interface)
636+
}
637+
return v4, v6, nil
573638
}

0 commit comments

Comments
 (0)