|
7 | 7 | "net" |
8 | 8 | "strconv" |
9 | 9 | "strings" |
| 10 | + "sync/atomic" |
10 | 11 | "testing" |
11 | 12 | "time" |
12 | 13 | ) |
@@ -600,26 +601,38 @@ func TestRDataTXTRoundTrip(t *testing.T) { |
600 | 601 |
|
601 | 602 | func TestParseResolverAddr(t *testing.T) { |
602 | 603 | tests := []struct { |
| 604 | + name string |
603 | 605 | input string |
604 | 606 | wantIP string |
605 | 607 | wantPort int |
| 608 | + wantErr bool |
606 | 609 | }{ |
607 | | - {"1.1.1.1", "1.1.1.1", 53}, |
608 | | - {"8.8.8.8:53", "8.8.8.8", 53}, |
609 | | - {"8.8.8.8:5353", "8.8.8.8", 5353}, |
610 | | - {"[2606:4700:4700::1111]:53", "2606:4700:4700::1111", 53}, |
| 610 | + {"bare_ipv4", "1.1.1.1", "1.1.1.1", 53, false}, |
| 611 | + {"ipv4_with_port", "8.8.8.8:53", "8.8.8.8", 53, false}, |
| 612 | + {"ipv4_custom_port", "8.8.8.8:5353", "8.8.8.8", 5353, false}, |
| 613 | + {"ipv6_with_port", "[2606:4700:4700::1111]:53", "2606:4700:4700::1111", 53, false}, |
| 614 | + {"empty_string", "", "", 0, true}, |
| 615 | + {"empty_host_with_port", ":53", "", 0, true}, |
611 | 616 | } |
612 | 617 | for _, tt := range tests { |
613 | | - addr, err := parseResolverAddr(tt.input) |
614 | | - if err != nil { |
615 | | - t.Fatalf("parseResolverAddr(%q): %v", tt.input, err) |
616 | | - } |
617 | | - if addr.IP.String() != tt.wantIP { |
618 | | - t.Errorf("parseResolverAddr(%q).IP = %v, want %v", tt.input, addr.IP, tt.wantIP) |
619 | | - } |
620 | | - if addr.Port != tt.wantPort { |
621 | | - t.Errorf("parseResolverAddr(%q).Port = %v, want %v", tt.input, addr.Port, tt.wantPort) |
622 | | - } |
| 618 | + t.Run(tt.name, func(t *testing.T) { |
| 619 | + addr, err := parseResolverAddr(tt.input) |
| 620 | + if tt.wantErr { |
| 621 | + if err == nil { |
| 622 | + t.Fatalf("parseResolverAddr(%q): want error, got nil", tt.input) |
| 623 | + } |
| 624 | + return |
| 625 | + } |
| 626 | + if err != nil { |
| 627 | + t.Fatalf("parseResolverAddr(%q): %v", tt.input, err) |
| 628 | + } |
| 629 | + if addr.IP.String() != tt.wantIP { |
| 630 | + t.Errorf("parseResolverAddr(%q).IP = %v, want %v", tt.input, addr.IP, tt.wantIP) |
| 631 | + } |
| 632 | + if addr.Port != tt.wantPort { |
| 633 | + t.Errorf("parseResolverAddr(%q).Port = %v, want %v", tt.input, addr.Port, tt.wantPort) |
| 634 | + } |
| 635 | + }) |
623 | 636 | } |
624 | 637 | } |
625 | 638 |
|
@@ -716,3 +729,70 @@ func TestResolverModeRoundTrip(t *testing.T) { |
716 | 729 | t.Errorf("server received %q, want %q", serverBuf[:readN], testPayload) |
717 | 730 | } |
718 | 731 | } |
| 732 | + |
| 733 | +func TestMultiResolverDistribution(t *testing.T) { |
| 734 | + const numResolvers = 3 |
| 735 | + |
| 736 | + // Create mock resolvers that count received packets |
| 737 | + resolvers := make([]net.PacketConn, numResolvers) |
| 738 | + var counts [numResolvers]atomic.Int32 |
| 739 | + for i := range resolvers { |
| 740 | + r, err := net.ListenPacket("udp", "127.0.0.1:0") |
| 741 | + if err != nil { |
| 742 | + t.Fatal(err) |
| 743 | + } |
| 744 | + defer r.Close() |
| 745 | + resolvers[i] = r |
| 746 | + idx := i |
| 747 | + go func() { |
| 748 | + buf := make([]byte, 4096) |
| 749 | + for { |
| 750 | + _, _, err := resolvers[idx].ReadFrom(buf) |
| 751 | + if err != nil { |
| 752 | + return |
| 753 | + } |
| 754 | + counts[idx].Add(1) |
| 755 | + } |
| 756 | + }() |
| 757 | + } |
| 758 | + |
| 759 | + resolverAddrs := make([]string, numResolvers) |
| 760 | + for i, r := range resolvers { |
| 761 | + resolverAddrs[i] = r.LocalAddr().String() |
| 762 | + } |
| 763 | + |
| 764 | + config := &Config{ |
| 765 | + Domain: "t.example.com", |
| 766 | + Resolvers: resolverAddrs, |
| 767 | + } |
| 768 | + rawConn, err := net.ListenPacket("udp", "127.0.0.1:0") |
| 769 | + if err != nil { |
| 770 | + t.Fatal(err) |
| 771 | + } |
| 772 | + defer rawConn.Close() |
| 773 | + |
| 774 | + client, err := NewConnClient(config, rawConn) |
| 775 | + if err != nil { |
| 776 | + t.Fatal(err) |
| 777 | + } |
| 778 | + defer client.Close() |
| 779 | + |
| 780 | + // Send enough messages to hit all resolvers via round-robin |
| 781 | + for i := 0; i < numResolvers*3; i++ { |
| 782 | + payload := []byte(fmt.Sprintf("msg-%d", i)) |
| 783 | + _, err = client.WriteTo(payload, rawConn.LocalAddr()) |
| 784 | + if err != nil { |
| 785 | + t.Fatal(err) |
| 786 | + } |
| 787 | + } |
| 788 | + |
| 789 | + // Allow sendLoop to process the queue |
| 790 | + time.Sleep(500 * time.Millisecond) |
| 791 | + |
| 792 | + for i := 0; i < numResolvers; i++ { |
| 793 | + c := counts[i].Load() |
| 794 | + if c == 0 { |
| 795 | + t.Errorf("resolver %d received no queries, want at least 1", i) |
| 796 | + } |
| 797 | + } |
| 798 | +} |
0 commit comments