diff --git a/infra/conf/transport_internet.go b/infra/conf/transport_internet.go index 849f0e38ee1d..44eb448df631 100644 --- a/infra/conf/transport_internet.go +++ b/infra/conf/transport_internet.go @@ -1660,16 +1660,30 @@ func (c *Sudoku) Build() (proto.Message, error) { } type Xdns struct { - Domain string `json:"domain"` + Domain json.RawMessage `json:"domain"` + + Domains []string `json:"domains"` + Resolvers []string `json:"resolvers"` } func (c *Xdns) Build() (proto.Message, error) { - if c.Domain == "" { - return nil, errors.New("empty domain") + if c.Domain != nil { + return nil, errors.PrintRemovedFeatureError("domain", "domains(server) & resolvers(client)") + } + + if len(c.Domains) == 0 && len(c.Resolvers) == 0 { + return nil, errors.New("empty domains & empty resolvers") + } + + for _, r := range c.Resolvers { + if !strings.Contains(r, "+udp://") { + return nil, errors.New("invalid resolver ", r) + } } return &xdns.Config{ - Domain: c.Domain, + Domains: c.Domains, + Resolvers: c.Resolvers, }, nil } diff --git a/transport/internet/finalmask/xdns/client.go b/transport/internet/finalmask/xdns/client.go index d6867b0a7473..6f8d97371937 100644 --- a/transport/internet/finalmask/xdns/client.go +++ b/transport/internet/finalmask/xdns/client.go @@ -9,7 +9,10 @@ import ( go_errors "errors" "io" "net" + "strconv" + "strings" "sync" + "sync/atomic" "time" "github.com/xtls/xray-core/common" @@ -34,10 +37,14 @@ type packet struct { } type xdnsConnClient struct { - net.PacketConn + conn net.PacketConn + resolverConns []net.PacketConn + resolverAddrs []*net.UDPAddr + resolverIdx uint32 + resolverSend []atomic.Uint32 clientID []byte - domain Name + domains []Name pollChan chan struct{} readQueue chan *packet @@ -48,16 +55,66 @@ type xdnsConnClient struct { } func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) { - domain, err := ParseName(c.Domain) - if err != nil { - return nil, err + if len(c.Resolvers) == 0 { + return nil, errors.New("empty resolvers") + } + + var domains []Name + var servers []string + for _, rs := range c.Resolvers { + parts := strings.Split(rs, "+udp://") + if len(parts) != 2 { + return nil, errors.New("invalid resolvers") + } + domain, err := ParseName(parts[0]) + if err != nil { + return nil, err + } + domains = append(domains, domain) + servers = append(servers, parts[1]) + } + + var resolverConns []net.PacketConn + var resolverAddrs []*net.UDPAddr + var resolverSend []atomic.Uint32 + for _, rs := range servers { + h, p, err := net.SplitHostPort(rs) + if err != nil { + return nil, err + } + ip := net.ParseIP(h) + if ip == nil { + return nil, errors.New("invalid ip address") + } + port, _ := strconv.Atoi(p) + if port == 0 { + return nil, errors.New("invalid port") + } + var uc net.PacketConn + if ip.To4() != nil { + uc, err = net.ListenPacket("udp4", ":0") + } else { + uc, err = net.ListenPacket("udp6", ":0") + } + if err != nil { + for _, rc := range resolverConns { + rc.Close() + } + return nil, errors.New("failed to create resolver socket: ", err) + } + resolverConns = append(resolverConns, uc) + resolverAddrs = append(resolverAddrs, &net.UDPAddr{IP: ip, Port: port}) } + resolverSend = make([]atomic.Uint32, len(resolverConns)) conn := &xdnsConnClient{ - PacketConn: raw, + conn: raw, + resolverConns: resolverConns, + resolverAddrs: resolverAddrs, + resolverSend: resolverSend, clientID: make([]byte, 8), - domain: domain, + domains: domains, pollChan: make(chan struct{}, pollLimit), readQueue: make(chan *packet, 256), @@ -73,58 +130,70 @@ func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) { } func (c *xdnsConnClient) recvLoop() { - var buf [finalmask.UDPSize]byte - - for { - if c.closed { - break - } - - n, addr, err := c.PacketConn.ReadFrom(buf[:]) - if err != nil || n == 0 { - if go_errors.Is(err, net.ErrClosed) || go_errors.Is(err, io.EOF) { - break - } - continue - } - - resp, err := MessageFromWireFormat(buf[:n]) - if err != nil { - errors.LogDebug(context.Background(), addr, " xdns from wireformat err ", err) - continue - } - - payload := dnsResponsePayload(&resp, c.domain) - - r := bytes.NewReader(payload) - anyPacket := false - for { - p, err := nextPacket(r) - if err != nil { - break + var wg sync.WaitGroup + + for i, rc := range c.resolverConns { + wg.Add(1) + go func() { + defer wg.Done() + + var buf [finalmask.UDPSize]byte + + for { + if c.closed { + break + } + + n, addr, err := rc.ReadFrom(buf[:]) + if err != nil { + if go_errors.Is(err, net.ErrClosed) { + break + } + continue + } + + resp, err := MessageFromWireFormat(buf[:n]) + if err != nil { + errors.LogDebug(context.Background(), addr, " xdns from wireformat err ", err) + continue + } + + payload := dnsResponsePayload(&resp, c.domains) + + r := bytes.NewReader(payload) + anyPacket := false + for { + p, err := nextPacket(r) + if err != nil { + break + } + anyPacket = true + + buf := make([]byte, len(p)) + copy(buf, p) + select { + case c.readQueue <- &packet{ + p: buf, + addr: addr, + }: + default: + errors.LogDebug(context.Background(), addr, " mask read err queue full") + } + } + + if anyPacket { + c.resolverSend[i].Store(0) + select { + case c.pollChan <- struct{}{}: + default: + } + } } - anyPacket = true - - buf := make([]byte, len(p)) - copy(buf, p) - select { - case c.readQueue <- &packet{ - p: buf, - addr: addr, - }: - default: - errors.LogDebug(context.Background(), addr, " mask read err queue full") - } - } - - if anyPacket { - select { - case c.pollChan <- struct{}{}: - default: - } - } + }() } + wg.Wait() + errors.LogDebug(context.Background(), "xdns closed") close(c.pollChan) @@ -138,8 +207,6 @@ func (c *xdnsConnClient) recvLoop() { } func (c *xdnsConnClient) sendLoop() { - var addr net.Addr - pollDelay := initPollDelay pollTimer := time.NewTimer(pollDelay) for { @@ -158,17 +225,14 @@ func (c *xdnsConnClient) sendLoop() { } if p != nil { - addr = p.addr - select { case <-c.pollChan: default: } - } else if addr != nil { - encoded, _ := encode(nil, c.clientID, c.domain) + } else { + encoded, _ := encode(nil, c.clientID, c.domains[c.resolverIdx]) p = &packet{ - p: encoded, - addr: addr, + p: encoded, } } @@ -189,10 +253,16 @@ func (c *xdnsConnClient) sendLoop() { return } - if p != nil { - _, err := c.PacketConn.WriteTo(p.p, p.addr) - if go_errors.Is(err, net.ErrClosed) || go_errors.Is(err, io.ErrClosedPipe) { - c.closed = true + cur := c.resolverIdx + curSend := c.resolverSend[c.resolverIdx].Add(1) + _, _ = c.resolverConns[c.resolverIdx].WriteTo(p.p, c.resolverAddrs[c.resolverIdx]) + for { + c.resolverIdx += 1 + c.resolverIdx %= uint32(len(c.resolverConns)) + if c.resolverIdx == cur { + break + } + if c.resolverSend[c.resolverIdx].Load() < curSend { break } } @@ -220,7 +290,7 @@ func (c *xdnsConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) { return 0, io.ErrClosedPipe } - encoded, err := encode(p, c.clientID, c.domain) + encoded, err := encode(p, c.clientID, c.domains[c.resolverIdx%uint32(len(c.resolverConns))]) if err != nil { errors.LogDebug(context.Background(), addr, " xdns wireformat err ", err, " ", len(p)) return 0, nil @@ -240,7 +310,35 @@ func (c *xdnsConnClient) WriteTo(p []byte, addr net.Addr) (n int, err error) { func (c *xdnsConnClient) Close() error { c.closed = true - return c.PacketConn.Close() + for _, rc := range c.resolverConns { + rc.Close() + } + return c.conn.Close() +} + +func (c *xdnsConnClient) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *xdnsConnClient) SetDeadline(t time.Time) error { + for _, rc := range c.resolverConns { + rc.SetDeadline(t) + } + return c.conn.SetDeadline(t) +} + +func (c *xdnsConnClient) SetReadDeadline(t time.Time) error { + for _, rc := range c.resolverConns { + rc.SetReadDeadline(t) + } + return c.conn.SetReadDeadline(t) +} + +func (c *xdnsConnClient) SetWriteDeadline(t time.Time) error { + for _, rc := range c.resolverConns { + rc.SetWriteDeadline(t) + } + return c.conn.SetWriteDeadline(t) } func encode(p []byte, clientID []byte, domain Name) ([]byte, error) { @@ -332,7 +430,7 @@ func nextPacket(r *bytes.Reader) ([]byte, error) { return p, err } -func dnsResponsePayload(resp *Message, domain Name) []byte { +func dnsResponsePayload(resp *Message, domains []Name) []byte { if resp.Flags&0x8000 != 0x8000 { return nil } @@ -345,7 +443,13 @@ func dnsResponsePayload(resp *Message, domain Name) []byte { } answer := resp.Answer[0] - _, ok := answer.Name.TrimSuffix(domain) + var ok bool + for _, domain := range domains { + _, ok = answer.Name.TrimSuffix(domain) + if ok { + break + } + } if !ok { return nil } diff --git a/transport/internet/finalmask/xdns/config.go b/transport/internet/finalmask/xdns/config.go index 157102dafa2b..dbd78a28633d 100644 --- a/transport/internet/finalmask/xdns/config.go +++ b/transport/internet/finalmask/xdns/config.go @@ -2,15 +2,27 @@ package xdns import ( "net" + + "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/transport/internet" + "github.com/xtls/xray-core/transport/internet/hysteria/udphop" ) func (c *Config) UDP() { } func (c *Config) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + _, ok1 := raw.(*internet.FakePacketConn) + _, ok2 := raw.(*udphop.UdpHopPacketConn) + if level != 0 || ok1 || ok2 { + return nil, errors.New("xdns requires being at the outermost level") + } return NewConnClient(c, raw) } func (c *Config) WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { + if level != 0 { + return nil, errors.New("xdns requires being at the outermost level") + } return NewConnServer(c, raw) } diff --git a/transport/internet/finalmask/xdns/config.pb.go b/transport/internet/finalmask/xdns/config.pb.go index 279240510137..e1f06aa930d2 100644 --- a/transport/internet/finalmask/xdns/config.pb.go +++ b/transport/internet/finalmask/xdns/config.pb.go @@ -23,7 +23,8 @@ const ( type Config struct { state protoimpl.MessageState `protogen:"open.v1"` - Domain string `protobuf:"bytes,1,opt,name=domain,proto3" json:"domain,omitempty"` + Domains []string `protobuf:"bytes,1,rep,name=domains,proto3" json:"domains,omitempty"` + Resolvers []string `protobuf:"bytes,2,rep,name=resolvers,proto3" json:"resolvers,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -58,20 +59,28 @@ func (*Config) Descriptor() ([]byte, []int) { return file_transport_internet_finalmask_xdns_config_proto_rawDescGZIP(), []int{0} } -func (x *Config) GetDomain() string { +func (x *Config) GetDomains() []string { if x != nil { - return x.Domain + return x.Domains } - return "" + return nil +} + +func (x *Config) GetResolvers() []string { + if x != nil { + return x.Resolvers + } + return nil } var File_transport_internet_finalmask_xdns_config_proto protoreflect.FileDescriptor const file_transport_internet_finalmask_xdns_config_proto_rawDesc = "" + "\n" + - ".transport/internet/finalmask/xdns/config.proto\x12&xray.transport.internet.finalmask.xdns\" \n" + - "\x06Config\x12\x16\n" + - "\x06domain\x18\x01 \x01(\tR\x06domainB\x94\x01\n" + + ".transport/internet/finalmask/xdns/config.proto\x12&xray.transport.internet.finalmask.xdns\"@\n" + + "\x06Config\x12\x18\n" + + "\adomains\x18\x01 \x03(\tR\adomains\x12\x1c\n" + + "\tresolvers\x18\x02 \x03(\tR\tresolversB\x94\x01\n" + "*com.xray.transport.internet.finalmask.xdnsP\x01Z;github.com/xtls/xray-core/transport/internet/finalmask/xdns\xaa\x02&Xray.Transport.Internet.Finalmask.Xdnsb\x06proto3" var ( diff --git a/transport/internet/finalmask/xdns/config.proto b/transport/internet/finalmask/xdns/config.proto index e1c717709dea..b859b17aee11 100644 --- a/transport/internet/finalmask/xdns/config.proto +++ b/transport/internet/finalmask/xdns/config.proto @@ -7,6 +7,6 @@ option java_package = "com.xray.transport.internet.finalmask.xdns"; option java_multiple_files = true; message Config { - string domain = 1; -} - + repeated string domains = 1; + repeated string resolvers = 2; +} \ No newline at end of file diff --git a/transport/internet/finalmask/xdns/dns_test.go b/transport/internet/finalmask/xdns/dns_test.go index aa163476d9f1..2ddc9da50273 100644 --- a/transport/internet/finalmask/xdns/dns_test.go +++ b/transport/internet/finalmask/xdns/dns_test.go @@ -559,6 +559,7 @@ func TestEncodeRDataTXT(t *testing.T) { } fmt.Println(EncodeRDataTXT(nil)) + fmt.Println(computeMaxEncodedPayload(maxUDPPayload)) } func TestRDataTXTRoundTrip(t *testing.T) { diff --git a/transport/internet/finalmask/xdns/server.go b/transport/internet/finalmask/xdns/server.go index ec2f18f9a8ac..c96149ad5c01 100644 --- a/transport/internet/finalmask/xdns/server.go +++ b/transport/internet/finalmask/xdns/server.go @@ -52,7 +52,7 @@ type queue struct { type xdnsConnServer struct { net.PacketConn - domain Name + domains []Name ch chan *record readQueue chan *packet @@ -63,15 +63,22 @@ type xdnsConnServer struct { } func NewConnServer(c *Config, raw net.PacketConn) (net.PacketConn, error) { - domain, err := ParseName(c.Domain) - if err != nil { - return nil, err + if len(c.Domains) == 0 { + return nil, errors.New("empty domains") + } + domains := make([]Name, 0, len(c.Domains)) + for _, domain := range c.Domains { + domain, err := ParseName(domain) + if err != nil { + return nil, err + } + domains = append(domains, domain) } conn := &xdnsConnServer{ PacketConn: raw, - domain: domain, + domains: domains, ch: make(chan *record, 500), readQueue: make(chan *packet, 512), @@ -156,8 +163,8 @@ func (c *xdnsConnServer) recvLoop() { } n, addr, err := c.PacketConn.ReadFrom(buf[:]) - if err != nil || n == 0 { - if go_errors.Is(err, net.ErrClosed) || go_errors.Is(err, io.EOF) { + if err != nil { + if go_errors.Is(err, net.ErrClosed) { break } continue @@ -169,7 +176,7 @@ func (c *xdnsConnServer) recvLoop() { continue } - resp, payload := responseFor(&query, c.domain) + resp, payload := responseFor(&query, c.domains) var clientID [8]byte n = copy(clientID[:], payload) @@ -321,7 +328,7 @@ func (c *xdnsConnServer) sendLoop() { } _, err = c.PacketConn.WriteTo(buf, rec.Addr) - if go_errors.Is(err, net.ErrClosed) || go_errors.Is(err, io.ErrClosedPipe) { + if go_errors.Is(err, net.ErrClosed) { c.closed = true break } @@ -399,7 +406,7 @@ func nextPacketServer(r *bytes.Reader) ([]byte, error) { } } -func responseFor(query *Message, domain Name) (*Message, []byte) { +func responseFor(query *Message, domains []Name) (*Message, []byte) { resp := &Message{ ID: query.ID, Flags: 0x8000, @@ -447,7 +454,14 @@ func responseFor(query *Message, domain Name) (*Message, []byte) { } question := query.Question[0] - prefix, ok := question.Name.TrimSuffix(domain) + var prefix Name + var ok bool + for _, domain := range domains { + prefix, ok = question.Name.TrimSuffix(domain) + if ok { + break + } + } if !ok { resp.Flags |= RcodeNameError return resp, nil @@ -525,7 +539,7 @@ func computeMaxEncodedPayload(limit int) int { }, }, } - resp, _ := responseFor(query, [][]byte{}) + resp, _ := responseFor(query, []Name{[][]byte{}}) resp.Answer = []RR{ { diff --git a/transport/internet/finalmask/xicmp/client.go b/transport/internet/finalmask/xicmp/client.go index 6ceaf2671643..d738b125098d 100644 --- a/transport/internet/finalmask/xicmp/client.go +++ b/transport/internet/finalmask/xicmp/client.go @@ -10,9 +10,7 @@ import ( "github.com/xtls/xray-core/common/crypto" "github.com/xtls/xray-core/common/errors" - "github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet/finalmask" - "github.com/xtls/xray-core/transport/internet/hysteria/udphop" "golang.org/x/net/icmp" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" @@ -54,13 +52,7 @@ type xicmpConnClient struct { mutex sync.Mutex } -func NewConnClient(c *Config, raw net.PacketConn, level int) (net.PacketConn, error) { - _, ok1 := raw.(*internet.FakePacketConn) - _, ok2 := raw.(*udphop.UdpHopPacketConn) - if level != 0 || ok1 || ok2 { - return nil, errors.New("xicmp requires being at the outermost level") - } - +func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) { network := "ip4:icmp" typ := icmp.Type(ipv4.ICMPTypeEcho) proto := 1 diff --git a/transport/internet/finalmask/xicmp/config.go b/transport/internet/finalmask/xicmp/config.go index c570ce96817e..fdcb02ae701f 100644 --- a/transport/internet/finalmask/xicmp/config.go +++ b/transport/internet/finalmask/xicmp/config.go @@ -2,15 +2,27 @@ package xicmp import ( "net" + + "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/transport/internet" + "github.com/xtls/xray-core/transport/internet/hysteria/udphop" ) func (c *Config) UDP() { } func (c *Config) WrapPacketConnClient(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { - return NewConnClient(c, raw, level) + _, ok1 := raw.(*internet.FakePacketConn) + _, ok2 := raw.(*udphop.UdpHopPacketConn) + if level != 0 || ok1 || ok2 { + return nil, errors.New("xicmp requires being at the outermost level") + } + return NewConnClient(c, raw) } func (c *Config) WrapPacketConnServer(raw net.PacketConn, level int, levelCount int) (net.PacketConn, error) { - return NewConnServer(c, raw, level) + if level != 0 { + return nil, errors.New("xicmp requires being at the outermost level") + } + return NewConnServer(c, raw) } diff --git a/transport/internet/finalmask/xicmp/server.go b/transport/internet/finalmask/xicmp/server.go index 79a5d010dd84..94012f019802 100644 --- a/transport/internet/finalmask/xicmp/server.go +++ b/transport/internet/finalmask/xicmp/server.go @@ -50,11 +50,7 @@ type xicmpConnServer struct { mutex sync.Mutex } -func NewConnServer(c *Config, raw net.PacketConn, level int) (net.PacketConn, error) { - if level != 0 { - return nil, errors.New("xicmp requires being at the outermost level") - } - +func NewConnServer(c *Config, raw net.PacketConn) (net.PacketConn, error) { network := "ip4:icmp" typ := icmp.Type(ipv4.ICMPTypeEchoReply) proto := 1