Skip to content

Commit 10f2bcd

Browse files
author
null
committed
tnet
1 parent 0889025 commit 10f2bcd

2 files changed

Lines changed: 131 additions & 20 deletions

File tree

proxy/wireguard/client.go

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package wireguard
33
import (
44
"context"
55
"fmt"
6-
gonet "net"
76
"net/netip"
87
reflect "reflect"
98
"strings"
@@ -126,22 +125,6 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
126125
}
127126
}
128127

129-
var addr netip.Addr
130-
if ob.Target.Address.Family().IsDomain() {
131-
ip, err := h.resolveRemote(ob.Target.Address.String())
132-
if err != nil {
133-
return errors.New("failed to resolve domain").Base(err)
134-
}
135-
addr, _ = netip.AddrFromSlice(ip)
136-
} else {
137-
addr, _ = netip.AddrFromSlice(ob.Target.Address.IP())
138-
}
139-
140-
addrPort := netip.AddrPortFrom(addr, ob.Target.Port.Value())
141-
if !addrPort.IsValid() {
142-
return errors.New("invalid target ", ob.Target)
143-
}
144-
145128
var newCtx context.Context
146129
var newCancel context.CancelFunc
147130
if session.TimeoutOnlyFromContext(ctx) {
@@ -166,23 +149,33 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
166149

167150
switch ob.Target.Network {
168151
case net.Network_TCP:
169-
conn, err := h.tnet.DialContextTCPAddrPort(ctx, addrPort)
152+
conn, err := h.tnet.Dial("tcp", ob.Target.NetAddr())
170153
if err != nil {
171154
return errors.New("failed to create TCP connection").Base(err)
172155
}
173156
defer conn.Close()
174157
reader = buf.NewReader(conn)
175158
writer = buf.NewWriter(conn)
176159
case net.Network_UDP:
177-
conn, err := h.tnet.DialUDPAddrPort(netip.AddrPort{}, addrPort)
160+
var dest *net.UDPAddr
161+
if ob.Target.Address.Family().IsDomain() {
162+
ip, err := h.resolveRemote(ob.Target.Address.String())
163+
if err != nil {
164+
return errors.New("failed to resolve domain").Base(err)
165+
}
166+
dest = &net.UDPAddr{IP: ip, Port: int(ob.Target.Port)}
167+
} else {
168+
dest = &net.UDPAddr{IP: ob.Target.Address.IP(), Port: int(ob.Target.Port)}
169+
}
170+
conn, err := h.tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPort{})
178171
if err != nil {
179172
return errors.New("failed to create UDP connection").Base(err)
180173
}
181174
defer conn.Close()
182175
c := &udpConnClient{
183176
PacketConn: conn.(*internet.PacketConnWrapper).PacketConn,
184177
resolveFunc: h.resolveRemote,
185-
dest: gonet.UDPAddrFromAddrPort(addrPort),
178+
dest: dest,
186179
}
187180
reader = c
188181
writer = c

proxy/wireguard/netstack.go

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ import (
1515
"net"
1616
"net/netip"
1717
"os"
18+
"regexp"
19+
"strconv"
1820
"strings"
1921
"syscall"
2022
"time"
@@ -246,6 +248,9 @@ var (
246248
errServerTemporarilyMisbehaving = errors.New("server misbehaving")
247249
errCanceled = errors.New("operation was canceled")
248250
errTimeout = errors.New("i/o timeout")
251+
errNumericPort = errors.New("port must be numeric")
252+
errNoSuitableAddress = errors.New("no suitable address found")
253+
errMissingAddress = errors.New("missing address")
249254
)
250255

251256
func (net *Net) LookupHost(host string) (addrs []string, err error) {
@@ -688,3 +693,116 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
688693
}
689694
return saddrs, nil
690695
}
696+
697+
func partialDeadline(now, deadline time.Time, addrsRemaining int) (time.Time, error) {
698+
if deadline.IsZero() {
699+
return deadline, nil
700+
}
701+
timeRemaining := deadline.Sub(now)
702+
if timeRemaining <= 0 {
703+
return time.Time{}, errTimeout
704+
}
705+
timeout := timeRemaining / time.Duration(addrsRemaining)
706+
const saneMinimum = 2 * time.Second
707+
if timeout < saneMinimum {
708+
if timeRemaining < saneMinimum {
709+
timeout = timeRemaining
710+
} else {
711+
timeout = saneMinimum
712+
}
713+
}
714+
return now.Add(timeout), nil
715+
}
716+
717+
var protoSplitter = regexp.MustCompile(`^(tcp)(4|6)?$`)
718+
719+
func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
720+
if ctx == nil {
721+
panic("nil context")
722+
}
723+
var acceptV4, acceptV6 bool
724+
matches := protoSplitter.FindStringSubmatch(network)
725+
if matches == nil {
726+
return nil, &net.OpError{Op: "dial", Err: net.UnknownNetworkError(network)}
727+
} else if len(matches[2]) == 0 {
728+
acceptV4 = true
729+
acceptV6 = true
730+
} else {
731+
acceptV4 = matches[2][0] == '4'
732+
acceptV6 = !acceptV4
733+
}
734+
var host string
735+
var port int
736+
var sport string
737+
var err error
738+
host, sport, err = net.SplitHostPort(address)
739+
if err != nil {
740+
return nil, &net.OpError{Op: "dial", Err: err}
741+
}
742+
port, err = strconv.Atoi(sport)
743+
if err != nil || port < 0 || port > 65535 {
744+
return nil, &net.OpError{Op: "dial", Err: errNumericPort}
745+
}
746+
allAddr, err := tnet.LookupContextHost(ctx, host)
747+
if err != nil {
748+
return nil, &net.OpError{Op: "dial", Err: err}
749+
}
750+
var addrs []netip.AddrPort
751+
for _, addr := range allAddr {
752+
ip, err := netip.ParseAddr(addr)
753+
if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) {
754+
addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port)))
755+
}
756+
}
757+
if len(addrs) == 0 && len(allAddr) != 0 {
758+
return nil, &net.OpError{Op: "dial", Err: errNoSuitableAddress}
759+
}
760+
761+
var firstErr error
762+
for i, addr := range addrs {
763+
select {
764+
case <-ctx.Done():
765+
err := ctx.Err()
766+
if err == context.Canceled {
767+
err = errCanceled
768+
} else if err == context.DeadlineExceeded {
769+
err = errTimeout
770+
}
771+
return nil, &net.OpError{Op: "dial", Err: err}
772+
default:
773+
}
774+
775+
dialCtx := ctx
776+
if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
777+
partialDeadline, err := partialDeadline(time.Now(), deadline, len(addrs)-i)
778+
if err != nil {
779+
if firstErr == nil {
780+
firstErr = &net.OpError{Op: "dial", Err: err}
781+
}
782+
break
783+
}
784+
if partialDeadline.Before(deadline) {
785+
var cancel context.CancelFunc
786+
dialCtx, cancel = context.WithDeadline(ctx, partialDeadline)
787+
defer cancel()
788+
}
789+
}
790+
791+
var c net.Conn
792+
c, err = tnet.DialContextTCPAddrPort(dialCtx, addr)
793+
if err == nil {
794+
return c, nil
795+
}
796+
if firstErr == nil {
797+
firstErr = err
798+
}
799+
}
800+
if firstErr == nil {
801+
firstErr = &net.OpError{Op: "dial", Err: errMissingAddress}
802+
}
803+
return nil, firstErr
804+
}
805+
806+
func (tnet *Net) Dial(network, address string) (net.Conn, error) {
807+
return tnet.DialContext(context.Background(), network, address)
808+
}

0 commit comments

Comments
 (0)