Skip to content

Commit bcade97

Browse files
committed
fix(xdns): address code review findings
- Change resolverIdx from plain uint32 to atomic.Uint32 for safe round-robin counter access, matching the original design intent - Track direct-mode recvLoop goroutine with recvWg so Close() waits for it to finish before returning - Convert TestParseResolverAddr to use t.Run subtests and add error test cases for empty string and empty host inputs - Add TestMultiResolverDistribution to verify queries distribute across multiple resolvers via round-robin
1 parent b8a3771 commit bcade97

2 files changed

Lines changed: 102 additions & 19 deletions

File tree

transport/internet/finalmask/xdns/client.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ type xdnsConnClient struct {
5858
domain Name
5959

6060
resolverConns []*resolverConn
61-
resolverIdx uint32
61+
resolverIdx atomic.Uint32
6262
recvWg sync.WaitGroup
6363
sendWg sync.WaitGroup
6464

@@ -114,7 +114,11 @@ func NewConnClient(c *Config, raw net.PacketConn) (net.PacketConn, error) {
114114
}(rc.conn)
115115
}
116116
} else {
117-
go conn.recvLoop()
117+
conn.recvWg.Add(1)
118+
go func() {
119+
defer conn.recvWg.Done()
120+
conn.recvLoop()
121+
}()
118122
}
119123
conn.sendWg.Add(1)
120124
go func() {
@@ -251,8 +255,7 @@ func (c *xdnsConnClient) sendLoop() {
251255
if p != nil {
252256
var err error
253257
if len(c.resolverConns) > 0 {
254-
idx := c.resolverIdx
255-
c.resolverIdx++
258+
idx := c.resolverIdx.Add(1)
256259
rc := c.resolverConns[idx%uint32(len(c.resolverConns))]
257260
_, err = rc.conn.WriteTo(p.p, rc.addr)
258261
} else {
@@ -312,8 +315,8 @@ func (c *xdnsConnClient) Close() error {
312315
rc.conn.Close()
313316
}
314317
c.closeErr = c.PacketConn.Close()
318+
c.recvWg.Wait()
315319
if len(c.resolverConns) > 0 {
316-
c.recvWg.Wait()
317320
close(c.pollChan)
318321
close(c.readQueue)
319322
c.mutex.Lock()

transport/internet/finalmask/xdns/dns_test.go

Lines changed: 94 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"net"
88
"strconv"
99
"strings"
10+
"sync/atomic"
1011
"testing"
1112
"time"
1213
)
@@ -600,26 +601,38 @@ func TestRDataTXTRoundTrip(t *testing.T) {
600601

601602
func TestParseResolverAddr(t *testing.T) {
602603
tests := []struct {
604+
name string
603605
input string
604606
wantIP string
605607
wantPort int
608+
wantErr bool
606609
}{
607-
{"1.1.1.1", "1.1.1.1", 53},
608-
{"8.8.8.8:53", "8.8.8.8", 53},
609-
{"8.8.8.8:5353", "8.8.8.8", 5353},
610-
{"[2606:4700:4700::1111]:53", "2606:4700:4700::1111", 53},
610+
{"bare_ipv4", "1.1.1.1", "1.1.1.1", 53, false},
611+
{"ipv4_with_port", "8.8.8.8:53", "8.8.8.8", 53, false},
612+
{"ipv4_custom_port", "8.8.8.8:5353", "8.8.8.8", 5353, false},
613+
{"ipv6_with_port", "[2606:4700:4700::1111]:53", "2606:4700:4700::1111", 53, false},
614+
{"empty_string", "", "", 0, true},
615+
{"empty_host_with_port", ":53", "", 0, true},
611616
}
612617
for _, tt := range tests {
613-
addr, err := parseResolverAddr(tt.input)
614-
if err != nil {
615-
t.Fatalf("parseResolverAddr(%q): %v", tt.input, err)
616-
}
617-
if addr.IP.String() != tt.wantIP {
618-
t.Errorf("parseResolverAddr(%q).IP = %v, want %v", tt.input, addr.IP, tt.wantIP)
619-
}
620-
if addr.Port != tt.wantPort {
621-
t.Errorf("parseResolverAddr(%q).Port = %v, want %v", tt.input, addr.Port, tt.wantPort)
622-
}
618+
t.Run(tt.name, func(t *testing.T) {
619+
addr, err := parseResolverAddr(tt.input)
620+
if tt.wantErr {
621+
if err == nil {
622+
t.Fatalf("parseResolverAddr(%q): want error, got nil", tt.input)
623+
}
624+
return
625+
}
626+
if err != nil {
627+
t.Fatalf("parseResolverAddr(%q): %v", tt.input, err)
628+
}
629+
if addr.IP.String() != tt.wantIP {
630+
t.Errorf("parseResolverAddr(%q).IP = %v, want %v", tt.input, addr.IP, tt.wantIP)
631+
}
632+
if addr.Port != tt.wantPort {
633+
t.Errorf("parseResolverAddr(%q).Port = %v, want %v", tt.input, addr.Port, tt.wantPort)
634+
}
635+
})
623636
}
624637
}
625638

@@ -716,3 +729,70 @@ func TestResolverModeRoundTrip(t *testing.T) {
716729
t.Errorf("server received %q, want %q", serverBuf[:readN], testPayload)
717730
}
718731
}
732+
733+
func TestMultiResolverDistribution(t *testing.T) {
734+
const numResolvers = 3
735+
736+
// Create mock resolvers that count received packets
737+
resolvers := make([]net.PacketConn, numResolvers)
738+
var counts [numResolvers]atomic.Int32
739+
for i := range resolvers {
740+
r, err := net.ListenPacket("udp", "127.0.0.1:0")
741+
if err != nil {
742+
t.Fatal(err)
743+
}
744+
defer r.Close()
745+
resolvers[i] = r
746+
idx := i
747+
go func() {
748+
buf := make([]byte, 4096)
749+
for {
750+
_, _, err := resolvers[idx].ReadFrom(buf)
751+
if err != nil {
752+
return
753+
}
754+
counts[idx].Add(1)
755+
}
756+
}()
757+
}
758+
759+
resolverAddrs := make([]string, numResolvers)
760+
for i, r := range resolvers {
761+
resolverAddrs[i] = r.LocalAddr().String()
762+
}
763+
764+
config := &Config{
765+
Domain: "t.example.com",
766+
Resolvers: resolverAddrs,
767+
}
768+
rawConn, err := net.ListenPacket("udp", "127.0.0.1:0")
769+
if err != nil {
770+
t.Fatal(err)
771+
}
772+
defer rawConn.Close()
773+
774+
client, err := NewConnClient(config, rawConn)
775+
if err != nil {
776+
t.Fatal(err)
777+
}
778+
defer client.Close()
779+
780+
// Send enough messages to hit all resolvers via round-robin
781+
for i := 0; i < numResolvers*3; i++ {
782+
payload := []byte(fmt.Sprintf("msg-%d", i))
783+
_, err = client.WriteTo(payload, rawConn.LocalAddr())
784+
if err != nil {
785+
t.Fatal(err)
786+
}
787+
}
788+
789+
// Allow sendLoop to process the queue
790+
time.Sleep(500 * time.Millisecond)
791+
792+
for i := 0; i < numResolvers; i++ {
793+
c := counts[i].Load()
794+
if c == 0 {
795+
t.Errorf("resolver %d received no queries, want at least 1", i)
796+
}
797+
}
798+
}

0 commit comments

Comments
 (0)