Skip to content

Commit 28196b7

Browse files
committed
Fix UDP DNS over WireGuard domain destinations
1 parent 1bdb488 commit 28196b7

4 files changed

Lines changed: 268 additions & 4 deletions

File tree

app/dns/nameserver_udp.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,6 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
131131
newReq.msg = &newMsg
132132
s.addPendingRequest(&newReq)
133133
b, _ := dns.PackMessage(newReq.msg)
134-
copyDest := net.UDPDestination(s.address.Address, s.address.Port)
135-
b.UDP = &copyDest
136134
s.udpServer.Dispatch(toDnsContext(newReq.ctx, s.address.String()), *s.address, b)
137135
return
138136
}
@@ -179,8 +177,6 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, noResponseErrCh chan<
179177
}
180178
return
181179
}
182-
copyDest := net.UDPDestination(s.address.Address, s.address.Port)
183-
b.UDP = &copyDest
184180
s.udpServer.Dispatch(toDnsContext(ctx, s.address.String()), *s.address, b)
185181
}
186182
}

app/dns/nameserver_udp_test.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
package dns
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/xtls/xray-core/common/buf"
8+
"github.com/xtls/xray-core/common/net"
9+
"github.com/xtls/xray-core/core"
10+
dns_feature "github.com/xtls/xray-core/features/dns"
11+
"github.com/xtls/xray-core/features/routing"
12+
"github.com/xtls/xray-core/transport"
13+
"github.com/xtls/xray-core/transport/pipe"
14+
)
15+
16+
type captureDispatcher struct {
17+
link *transport.Link
18+
dest net.Destination
19+
}
20+
21+
func (d *captureDispatcher) Dispatch(ctx context.Context, dest net.Destination) (*transport.Link, error) {
22+
d.dest = dest
23+
return d.link, nil
24+
}
25+
26+
func (d *captureDispatcher) DispatchLink(ctx context.Context, dest net.Destination, link *transport.Link) error {
27+
return nil
28+
}
29+
30+
func (d *captureDispatcher) Start() error {
31+
return nil
32+
}
33+
34+
func (d *captureDispatcher) Close() error {
35+
return nil
36+
}
37+
38+
func (d *captureDispatcher) Type() interface{} {
39+
return routing.DispatcherType()
40+
}
41+
42+
func TestClassicNameServerSendQueryDoesNotOverridePacketDestination(t *testing.T) {
43+
uplinkReader, uplinkWriter := pipe.New(pipe.WithSizeLimit(1024))
44+
downlinkReader, downlinkWriter := pipe.New(pipe.WithSizeLimit(1024))
45+
defer uplinkReader.Interrupt()
46+
defer uplinkWriter.Close()
47+
defer downlinkReader.Interrupt()
48+
defer downlinkWriter.Close()
49+
50+
dispatcher := &captureDispatcher{
51+
link: &transport.Link{
52+
Reader: downlinkReader,
53+
Writer: uplinkWriter,
54+
},
55+
}
56+
server := NewClassicNameServer(net.UDPDestination(net.DomainAddress("resolver.test"), 53), dispatcher, false, false, 0, nil)
57+
defer server.requestsCleanup.Close()
58+
defer server.udpServer.RemoveRay()
59+
60+
instance, err := core.New(&core.Config{})
61+
if err != nil {
62+
t.Fatal(err)
63+
}
64+
ctx := context.WithValue(context.Background(), core.XrayKey(1), instance)
65+
66+
server.sendQuery(ctx, nil, "service.test.", dns_feature.IPOption{IPv4Enable: true})
67+
68+
mb, err := uplinkReader.ReadMultiBuffer()
69+
if err != nil {
70+
t.Fatal(err)
71+
}
72+
defer buf.ReleaseMulti(mb)
73+
74+
if len(mb) == 0 {
75+
t.Fatal("expected DNS query payload")
76+
}
77+
if mb[0].UDP != nil {
78+
t.Fatalf("expected DNS query payload without packet destination override, got %v", mb[0].UDP)
79+
}
80+
if dispatcher.dest.String() != "udp:resolver.test:53" {
81+
t.Fatalf("unexpected dispatch destination: %v", dispatcher.dest)
82+
}
83+
}

proxy/wireguard/client.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
185185
}
186186
addr = net.IPAddress(ips[dice.Roll(len(ips))])
187187
}
188+
destination.Address = addr
188189

189190
var newCtx context.Context
190191
var newCancel context.CancelFunc

proxy/wireguard/client_test.go

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
package wireguard
2+
3+
import (
4+
"context"
5+
"io"
6+
stdnet "net"
7+
"net/netip"
8+
"testing"
9+
"time"
10+
11+
"github.com/xtls/xray-core/common/buf"
12+
xnet "github.com/xtls/xray-core/common/net"
13+
"github.com/xtls/xray-core/common/session"
14+
feature_dns "github.com/xtls/xray-core/features/dns"
15+
"github.com/xtls/xray-core/features/policy"
16+
"github.com/xtls/xray-core/transport"
17+
"github.com/xtls/xray-core/transport/internet"
18+
"github.com/xtls/xray-core/transport/internet/stat"
19+
"github.com/xtls/xray-core/transport/pipe"
20+
wgconn "golang.zx2c4.com/wireguard/conn"
21+
)
22+
23+
type capturePacketConn struct {
24+
writtenAddr stdnet.Addr
25+
}
26+
27+
func (c *capturePacketConn) Read(p []byte) (int, error) {
28+
return 0, io.EOF
29+
}
30+
31+
func (c *capturePacketConn) Write(p []byte) (int, error) {
32+
return c.WriteTo(p, nil)
33+
}
34+
35+
func (c *capturePacketConn) Close() error {
36+
return nil
37+
}
38+
39+
func (c *capturePacketConn) LocalAddr() stdnet.Addr {
40+
return &stdnet.UDPAddr{IP: stdnet.IPv4zero, Port: 0}
41+
}
42+
43+
func (c *capturePacketConn) RemoteAddr() stdnet.Addr {
44+
return nil
45+
}
46+
47+
func (c *capturePacketConn) SetDeadline(t time.Time) error {
48+
return nil
49+
}
50+
51+
func (c *capturePacketConn) SetReadDeadline(t time.Time) error {
52+
return nil
53+
}
54+
55+
func (c *capturePacketConn) SetWriteDeadline(t time.Time) error {
56+
return nil
57+
}
58+
59+
func (c *capturePacketConn) ReadFrom(p []byte) (int, stdnet.Addr, error) {
60+
return 0, nil, io.EOF
61+
}
62+
63+
func (c *capturePacketConn) WriteTo(p []byte, addr stdnet.Addr) (int, error) {
64+
c.writtenAddr = addr
65+
return len(p), nil
66+
}
67+
68+
type captureTunnel struct {
69+
udpConn *capturePacketConn
70+
}
71+
72+
func (t *captureTunnel) BuildDevice(ipc string, bind wgconn.Bind) error {
73+
return nil
74+
}
75+
76+
func (t *captureTunnel) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (stdnet.Conn, error) {
77+
return nil, nil
78+
}
79+
80+
func (t *captureTunnel) DialUDPAddrPort(laddr, raddr netip.AddrPort) (stdnet.Conn, error) {
81+
return t.udpConn, nil
82+
}
83+
84+
func (t *captureTunnel) Close() error {
85+
return nil
86+
}
87+
88+
type staticDNSClient struct{}
89+
90+
func (c *staticDNSClient) Type() interface{} {
91+
return feature_dns.ClientType()
92+
}
93+
94+
func (c *staticDNSClient) Start() error {
95+
return nil
96+
}
97+
98+
func (c *staticDNSClient) Close() error {
99+
return nil
100+
}
101+
102+
func (c *staticDNSClient) LookupIP(domain string, option feature_dns.IPOption) ([]stdnet.IP, uint32, error) {
103+
return []stdnet.IP{stdnet.IPv4(192, 0, 2, 1)}, 0, nil
104+
}
105+
106+
type staticPolicyManager struct{}
107+
108+
func (m staticPolicyManager) Type() interface{} {
109+
return policy.ManagerType()
110+
}
111+
112+
func (m staticPolicyManager) Start() error {
113+
return nil
114+
}
115+
116+
func (m staticPolicyManager) Close() error {
117+
return nil
118+
}
119+
120+
func (m staticPolicyManager) ForLevel(level uint32) policy.Session {
121+
return policy.SessionDefault()
122+
}
123+
124+
func (m staticPolicyManager) ForSystem() policy.System {
125+
return policy.System{}
126+
}
127+
128+
type noopDialer struct{}
129+
130+
func (d *noopDialer) Dial(ctx context.Context, destination xnet.Destination) (stat.Connection, error) {
131+
return nil, nil
132+
}
133+
134+
func (d *noopDialer) DestIpAddress() stdnet.IP {
135+
return nil
136+
}
137+
138+
func (d *noopDialer) SetOutboundGateway(ctx context.Context, ob *session.Outbound) {}
139+
140+
func TestProcessStoresResolvedDomainDestinationForUDP(t *testing.T) {
141+
packetConn := &capturePacketConn{}
142+
dialer := &noopDialer{}
143+
handler := &Handler{
144+
conf: &DeviceConfig{DomainStrategy: DeviceConfig_FORCE_IP4},
145+
net: &captureTunnel{udpConn: packetConn},
146+
bind: &netBindClient{dialer: dialer},
147+
policyManager: staticPolicyManager{},
148+
dns: &staticDNSClient{},
149+
hasIPv4: true,
150+
}
151+
152+
uplinkReader, uplinkWriter := pipe.New(pipe.WithoutSizeLimit())
153+
downlinkReader, downlinkWriter := pipe.New(pipe.WithoutSizeLimit())
154+
defer uplinkReader.Interrupt()
155+
defer uplinkWriter.Close()
156+
defer downlinkReader.Interrupt()
157+
defer downlinkWriter.Close()
158+
159+
payload := buf.FromBytes([]byte("dns query"))
160+
if err := uplinkWriter.WriteMultiBuffer(buf.MultiBuffer{payload}); err != nil {
161+
t.Fatal(err)
162+
}
163+
if err := uplinkWriter.Close(); err != nil {
164+
t.Fatal(err)
165+
}
166+
167+
ctx := session.ContextWithOutbounds(context.Background(), []*session.Outbound{
168+
{Target: xnet.UDPDestination(xnet.DomainAddress("resolver.test"), 53)},
169+
})
170+
err := handler.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, dialer)
171+
if err != nil {
172+
t.Fatal(err)
173+
}
174+
175+
addr, ok := packetConn.writtenAddr.(*stdnet.UDPAddr)
176+
if !ok {
177+
t.Fatalf("expected UDPAddr write target, got %T", packetConn.writtenAddr)
178+
}
179+
if !addr.IP.Equal(stdnet.IPv4(192, 0, 2, 1)) || addr.Port != 53 {
180+
t.Fatalf("unexpected write target: %v", addr)
181+
}
182+
183+
var _ internet.Dialer = dialer
184+
}

0 commit comments

Comments
 (0)