Skip to content

Commit c91f439

Browse files
committed
WireGuard outbound: Enhance DNS resolution with family-specific lookups and caching
1 parent 3d53988 commit c91f439

2 files changed

Lines changed: 117 additions & 11 deletions

File tree

proxy/wireguard/client.go

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -143,23 +143,39 @@ func (h *Handler) processWireGuard(ctx context.Context, dialer internet.Dialer)
143143
return nil
144144
}
145145

146+
func (h *Handler) lookupIPWithOption(domain string, option dns.IPOption) (net.Address, error) {
147+
ips, _, err := h.dns.LookupIP(domain, option)
148+
if err != nil {
149+
return nil, err
150+
} else if len(ips) == 0 {
151+
return nil, dns.ErrEmptyResponse
152+
}
153+
return net.IPAddress(ips[dice.Roll(len(ips))]), nil
154+
}
155+
146156
func (h *Handler) lookupIP(domain string) (net.Address, error) {
147-
ips, _, err := h.dns.LookupIP(domain, dns.IPOption{
157+
addr, err := h.lookupIPWithOption(domain, dns.IPOption{
148158
IPv4Enable: h.hasIPv4 && h.conf.preferIP4(),
149159
IPv6Enable: h.hasIPv6 && h.conf.preferIP6(),
150160
})
151-
if (len(ips) == 0 || err != nil) && h.conf.hasFallback() {
152-
ips, _, err = h.dns.LookupIP(domain, dns.IPOption{
161+
if err != nil && h.conf.hasFallback() {
162+
return h.lookupIPWithOption(domain, dns.IPOption{
153163
IPv4Enable: h.hasIPv4 && h.conf.fallbackIP4(),
154164
IPv6Enable: h.hasIPv6 && h.conf.fallbackIP6(),
155165
})
156166
}
157-
if err != nil {
158-
return nil, err
159-
} else if len(ips) == 0 {
160-
return nil, dns.ErrEmptyResponse
167+
return addr, err
168+
}
169+
170+
func (h *Handler) lookupIPForFamily(domain string, family net.AddressFamily) (net.Address, error) {
171+
switch family {
172+
case net.AddressFamilyIPv4:
173+
return h.lookupIPWithOption(domain, dns.IPOption{IPv4Enable: true})
174+
case net.AddressFamilyIPv6:
175+
return h.lookupIPWithOption(domain, dns.IPOption{IPv6Enable: true})
176+
default:
177+
return h.lookupIP(domain)
161178
}
162-
return net.IPAddress(ips[dice.Roll(len(ips))]), nil
163179
}
164180

165181
// Process implements OutboundHandler.Dispatch().
@@ -185,8 +201,10 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
185201

186202
// resolve dns
187203
addr := destination.Address
204+
destinationDomain := ""
188205
if addr.Family().IsDomain() {
189-
resolved, err := h.lookupIP(addr.Domain())
206+
destinationDomain = addr.Domain()
207+
resolved, err := h.lookupIP(destinationDomain)
190208
if err != nil {
191209
return errors.New("failed to lookup DNS").Base(err)
192210
}
@@ -236,11 +254,18 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
236254
}
237255
defer conn.Close()
238256

257+
resolvedUDPAddr := utils.NewTypedSyncMap[string, net.Address]()
258+
if destinationDomain != "" {
259+
resolvedUDPAddr.Store(strings.ToLower(destinationDomain), addr)
260+
}
261+
udpLookupIP := func(domain string) (net.Address, error) {
262+
return h.lookupIPForFamily(domain, addr.Family())
263+
}
239264
conn = &udpConnClient{
240265
Conn: conn,
241266
dest: destination,
242-
lookupIP: h.lookupIP,
243-
resolvedUDPAddr: utils.NewTypedSyncMap[string, net.Address](),
267+
lookupIP: udpLookupIP,
268+
resolvedUDPAddr: resolvedUDPAddr,
244269
}
245270

246271
requestFunc = func() error {

proxy/wireguard/client_test.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"github.com/xtls/xray-core/common/buf"
1010
"github.com/xtls/xray-core/common/net"
1111
"github.com/xtls/xray-core/common/utils"
12+
feature_dns "github.com/xtls/xray-core/features/dns"
1213
"github.com/xtls/xray-core/transport/internet"
1314
)
1415

@@ -45,6 +46,33 @@ func (c *capturePacketConn) SetWriteDeadline(time.Time) error {
4546
return nil
4647
}
4748

49+
type recordingDNSClient struct {
50+
options []feature_dns.IPOption
51+
}
52+
53+
func (c *recordingDNSClient) Type() interface{} {
54+
return feature_dns.ClientType()
55+
}
56+
57+
func (c *recordingDNSClient) Start() error {
58+
return nil
59+
}
60+
61+
func (c *recordingDNSClient) Close() error {
62+
return nil
63+
}
64+
65+
func (c *recordingDNSClient) LookupIP(_ string, option feature_dns.IPOption) ([]net.IP, uint32, error) {
66+
c.options = append(c.options, option)
67+
if option.IPv4Enable {
68+
return []net.IP{gonet.IPv4(162, 159, 36, 1)}, 0, nil
69+
}
70+
if option.IPv6Enable {
71+
return []net.IP{gonet.ParseIP("2606:4700:4700::1111")}, 0, nil
72+
}
73+
return nil, 0, feature_dns.ErrEmptyResponse
74+
}
75+
4876
func TestUDPConnClientWriteMultiBufferResolvesPacketDestination(t *testing.T) {
4977
packetConn := &capturePacketConn{}
5078
conn := &udpConnClient{
@@ -143,3 +171,56 @@ func TestUDPConnClientResolveDestinationCachesDomainsCaseInsensitive(t *testing.
143171
t.Fatalf("unexpected cached addresses: %s != %s", first.Address, second.Address)
144172
}
145173
}
174+
175+
func TestUDPConnClientResolveDestinationUsesSeededDomainCache(t *testing.T) {
176+
lookupCalls := 0
177+
cache := utils.NewTypedSyncMap[string, net.Address]()
178+
cache.Store("dns.cloudflare.internal", net.IPAddress([]byte{162, 159, 36, 1}))
179+
conn := &udpConnClient{
180+
resolvedUDPAddr: cache,
181+
lookupIP: func(domain string) (net.Address, error) {
182+
lookupCalls++
183+
return net.IPAddress([]byte{162, 159, 46, 1}), nil
184+
},
185+
}
186+
187+
resolved, err := conn.resolveDestination(net.UDPDestination(net.DomainAddress("DNS.Cloudflare.Internal"), net.Port(53)))
188+
if err != nil {
189+
t.Fatal(err)
190+
}
191+
192+
if lookupCalls != 0 {
193+
t.Fatalf("unexpected lookup calls: %d", lookupCalls)
194+
}
195+
if resolved.Address.String() != "162.159.36.1" {
196+
t.Fatalf("unexpected seeded address: %s", resolved.Address)
197+
}
198+
}
199+
200+
func TestHandlerLookupIPForFamilyConstrainsLookupOption(t *testing.T) {
201+
client := &recordingDNSClient{}
202+
handler := &Handler{dns: client}
203+
204+
ipv4, err := handler.lookupIPForFamily("dns.cloudflare.internal", net.AddressFamilyIPv4)
205+
if err != nil {
206+
t.Fatal(err)
207+
}
208+
if !ipv4.Family().IsIPv4() {
209+
t.Fatalf("unexpected IPv4 lookup result: %s", ipv4)
210+
}
211+
if len(client.options) != 1 || !client.options[0].IPv4Enable || client.options[0].IPv6Enable {
212+
t.Fatalf("unexpected IPv4 lookup options: %+v", client.options)
213+
}
214+
215+
client.options = nil
216+
ipv6, err := handler.lookupIPForFamily("dns.cloudflare.internal", net.AddressFamilyIPv6)
217+
if err != nil {
218+
t.Fatal(err)
219+
}
220+
if !ipv6.Family().IsIPv6() {
221+
t.Fatalf("unexpected IPv6 lookup result: %s", ipv6)
222+
}
223+
if len(client.options) != 1 || client.options[0].IPv4Enable || !client.options[0].IPv6Enable {
224+
t.Fatalf("unexpected IPv6 lookup options: %+v", client.options)
225+
}
226+
}

0 commit comments

Comments
 (0)