Skip to content

Commit 1803471

Browse files
committed
endpoint: Fix UDP resolved destination
1 parent 3de56d3 commit 1803471

4 files changed

Lines changed: 82 additions & 27 deletions

File tree

common/dialer/dialer.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,7 @@ type ParallelNetworkDialer interface {
145145
DialParallelNetwork(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error)
146146
ListenSerialNetworkPacket(ctx context.Context, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, netip.Addr, error)
147147
}
148+
149+
type PacketDialerWithDestination interface {
150+
ListenPacketWithDestination(ctx context.Context, destination M.Socksaddr) (net.PacketConn, netip.Addr, error)
151+
}

protocol/tailscale/endpoint.go

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ import (
6363
var (
6464
_ adapter.OutboundWithPreferredRoutes = (*Endpoint)(nil)
6565
_ adapter.DirectRouteOutbound = (*Endpoint)(nil)
66+
_ dialer.PacketDialerWithDestination = (*Endpoint)(nil)
6667
)
6768

6869
func init() {
@@ -518,19 +519,7 @@ func (t *Endpoint) DialContext(ctx context.Context, network string, destination
518519
}
519520
}
520521

521-
func (t *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
522-
t.logger.InfoContext(ctx, "outbound packet connection to ", destination)
523-
if destination.IsFqdn() {
524-
destinationAddresses, err := t.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{})
525-
if err != nil {
526-
return nil, err
527-
}
528-
packetConn, _, err := N.ListenSerial(ctx, t, destination, destinationAddresses)
529-
if err != nil {
530-
return nil, err
531-
}
532-
return packetConn, err
533-
}
522+
func (t *Endpoint) listenPacketWithAddress(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
534523
addr4, addr6 := t.server.TailscaleIPs()
535524
bind := tcpip.FullAddress{
536525
NIC: 1,
@@ -556,6 +545,44 @@ func (t *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (n
556545
return udpConn, nil
557546
}
558547

548+
func (t *Endpoint) ListenPacketWithDestination(ctx context.Context, destination M.Socksaddr) (net.PacketConn, netip.Addr, error) {
549+
t.logger.InfoContext(ctx, "outbound packet connection to ", destination)
550+
if destination.IsFqdn() {
551+
destinationAddresses, err := t.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{})
552+
if err != nil {
553+
return nil, netip.Addr{}, err
554+
}
555+
var errors []error
556+
for _, address := range destinationAddresses {
557+
packetConn, packetErr := t.listenPacketWithAddress(ctx, M.SocksaddrFrom(address, destination.Port))
558+
if packetErr == nil {
559+
return packetConn, address, nil
560+
}
561+
errors = append(errors, packetErr)
562+
}
563+
return nil, netip.Addr{}, E.Errors(errors...)
564+
}
565+
packetConn, err := t.listenPacketWithAddress(ctx, destination)
566+
if err != nil {
567+
return nil, netip.Addr{}, err
568+
}
569+
if destination.IsIP() {
570+
return packetConn, destination.Addr, nil
571+
}
572+
return packetConn, netip.Addr{}, nil
573+
}
574+
575+
func (t *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
576+
packetConn, destinationAddress, err := t.ListenPacketWithDestination(ctx, destination)
577+
if err != nil {
578+
return nil, err
579+
}
580+
if destinationAddress.IsValid() && destination != M.SocksaddrFrom(destinationAddress, destination.Port) {
581+
return bufio.NewNATPacketConn(bufio.NewPacketConn(packetConn), M.SocksaddrFrom(destinationAddress, destination.Port), destination), nil
582+
}
583+
return packetConn, nil
584+
}
585+
559586
func (t *Endpoint) PrepareConnection(network string, source M.Socksaddr, destination M.Socksaddr, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) {
560587
tsFilter := t.filter.Load()
561588
if tsFilter != nil {

protocol/wireguard/endpoint.go

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@ import (
2424
"github.com/sagernet/sing/service"
2525
)
2626

27-
var _ adapter.OutboundWithPreferredRoutes = (*Endpoint)(nil)
27+
var (
28+
_ adapter.OutboundWithPreferredRoutes = (*Endpoint)(nil)
29+
_ dialer.PacketDialerWithDestination = (*Endpoint)(nil)
30+
)
2831

2932
func RegisterEndpoint(registry *endpoint.Registry) {
3033
endpoint.Register[option.WireGuardEndpointOptions](registry, C.TypeWireGuard, NewEndpoint)
@@ -219,20 +222,34 @@ func (w *Endpoint) DialContext(ctx context.Context, network string, destination
219222
return w.endpoint.DialContext(ctx, network, destination)
220223
}
221224

222-
func (w *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
225+
func (w *Endpoint) ListenPacketWithDestination(ctx context.Context, destination M.Socksaddr) (net.PacketConn, netip.Addr, error) {
223226
w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
224227
if destination.IsFqdn() {
225228
destinationAddresses, err := w.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{})
226229
if err != nil {
227-
return nil, err
228-
}
229-
packetConn, _, err := N.ListenSerial(ctx, w.endpoint, destination, destinationAddresses)
230-
if err != nil {
231-
return nil, err
230+
return nil, netip.Addr{}, err
232231
}
233-
return packetConn, err
232+
return N.ListenSerial(ctx, w.endpoint, destination, destinationAddresses)
233+
}
234+
packetConn, err := w.endpoint.ListenPacket(ctx, destination)
235+
if err != nil {
236+
return nil, netip.Addr{}, err
237+
}
238+
if destination.IsIP() {
239+
return packetConn, destination.Addr, nil
240+
}
241+
return packetConn, netip.Addr{}, nil
242+
}
243+
244+
func (w *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
245+
packetConn, destinationAddress, err := w.ListenPacketWithDestination(ctx, destination)
246+
if err != nil {
247+
return nil, err
248+
}
249+
if destinationAddress.IsValid() && destination != M.SocksaddrFrom(destinationAddress, destination.Port) {
250+
return bufio.NewNATPacketConn(bufio.NewPacketConn(packetConn), M.SocksaddrFrom(destinationAddress, destination.Port), destination), nil
234251
}
235-
return w.endpoint.ListenPacket(ctx, destination)
252+
return packetConn, nil
236253
}
237254

238255
func (w *Endpoint) PreferredDomain(domain string) bool {

route/conn.go

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,8 @@ func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dial
188188
} else {
189189
if len(metadata.DestinationAddresses) > 0 {
190190
remotePacketConn, destinationAddress, err = dialer.ListenSerialNetworkPacket(ctx, this, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
191+
} else if packetDialer, withDestination := this.(dialer.PacketDialerWithDestination); withDestination {
192+
remotePacketConn, destinationAddress, err = packetDialer.ListenPacketWithDestination(ctx, metadata.Destination)
191193
} else {
192194
remotePacketConn, err = this.ListenPacket(ctx, metadata.Destination)
193195
}
@@ -218,11 +220,16 @@ func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dial
218220
}
219221
if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded {
220222
natConn.UpdateDestination(destinationAddress)
221-
} else if metadata.Destination != M.SocksaddrFrom(destinationAddress, metadata.Destination.Port) {
222-
if metadata.UDPDisableDomainUnmapping {
223-
remotePacketConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(remotePacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination)
224-
} else {
225-
remotePacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(remotePacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), originDestination)
223+
} else {
224+
destination := M.SocksaddrFrom(destinationAddress, metadata.Destination.Port)
225+
if metadata.Destination != destination {
226+
if metadata.UDPDisableDomainUnmapping {
227+
remotePacketConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(remotePacketConn), destination, originDestination)
228+
} else {
229+
remotePacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(remotePacketConn), destination, originDestination)
230+
}
231+
} else if metadata.RouteOriginalDestination.IsValid() && metadata.RouteOriginalDestination != metadata.Destination {
232+
remotePacketConn = bufio.NewDestinationNATPacketConn(bufio.NewPacketConn(remotePacketConn), metadata.Destination, metadata.RouteOriginalDestination)
226233
}
227234
}
228235
} else if metadata.RouteOriginalDestination.IsValid() && metadata.RouteOriginalDestination != metadata.Destination {

0 commit comments

Comments
 (0)