@@ -796,3 +796,150 @@ func TestMultiResolverDistribution(t *testing.T) {
796796 }
797797 }
798798}
799+
800+ func TestResolverModeRejectsNestedLayers (t * testing.T ) {
801+ config := & Config {
802+ Domain : "t.example.com" ,
803+ Resolvers : []string {"1.1.1.1" },
804+ }
805+
806+ raw , err := net .ListenPacket ("udp" , "127.0.0.1:0" )
807+ if err != nil {
808+ t .Fatal (err )
809+ }
810+ defer raw .Close ()
811+
812+ // level > 0 means there are lower finalmask layers; resolver mode must reject this
813+ _ , err = config .WrapPacketConnClient (raw , 1 , 2 )
814+ if err == nil {
815+ t .Fatal ("expected error for resolver mode with level > 0, got nil" )
816+ }
817+ if ! strings .Contains (err .Error (), "resolver mode" ) {
818+ t .Errorf ("unexpected error message: %v" , err )
819+ }
820+
821+ // level == 0 should succeed
822+ conn , err := config .WrapPacketConnClient (raw , 0 , 1 )
823+ if err != nil {
824+ t .Fatalf ("unexpected error for level=0: %v" , err )
825+ }
826+ conn .Close ()
827+ }
828+
829+ func TestResolverModeServerToClient (t * testing.T ) {
830+ // Auth server: raw UDP socket that the XDNS server wraps
831+ authServer , err := net .ListenPacket ("udp" , "127.0.0.1:0" )
832+ if err != nil {
833+ t .Fatal (err )
834+ }
835+ defer authServer .Close ()
836+
837+ // Mock resolver: bidirectional UDP forwarder
838+ resolver , err := net .ListenPacket ("udp" , "127.0.0.1:0" )
839+ if err != nil {
840+ t .Fatal (err )
841+ }
842+ defer resolver .Close ()
843+
844+ go func () {
845+ buf := make ([]byte , 4096 )
846+ authAddr := authServer .LocalAddr ().String ()
847+ var clientAddr net.Addr
848+ for {
849+ n , addr , err := resolver .ReadFrom (buf )
850+ if err != nil {
851+ return
852+ }
853+ if addr .String () == authAddr {
854+ if clientAddr != nil {
855+ resolver .WriteTo (buf [:n ], clientAddr )
856+ }
857+ } else {
858+ clientAddr = addr
859+ resolver .WriteTo (buf [:n ], authServer .LocalAddr ())
860+ }
861+ }
862+ }()
863+
864+ // XDNS server
865+ serverConfig := & Config {Domain : "t.example.com" }
866+ server , err := NewConnServer (serverConfig , authServer )
867+ if err != nil {
868+ t .Fatal (err )
869+ }
870+ defer server .Close ()
871+
872+ // XDNS client with resolver
873+ config := & Config {
874+ Domain : "t.example.com" ,
875+ Resolvers : []string {resolver .LocalAddr ().String ()},
876+ }
877+ rawConn , err := net .ListenPacket ("udp" , "127.0.0.1:0" )
878+ if err != nil {
879+ t .Fatal (err )
880+ }
881+ defer rawConn .Close ()
882+
883+ client , err := NewConnClient (config , rawConn )
884+ if err != nil {
885+ t .Fatal (err )
886+ }
887+ defer client .Close ()
888+
889+ // Client sends a query to trigger the connection and set serverAddr
890+ _ , err = client .WriteTo ([]byte ("init" ), rawConn .LocalAddr ())
891+ if err != nil {
892+ t .Fatal (err )
893+ }
894+
895+ // Wait for server to receive the client query
896+ serverBuf := make ([]byte , 256 )
897+ done := make (chan struct {})
898+ var serverReadAddr net.Addr
899+ go func () {
900+ defer close (done )
901+ _ , serverReadAddr , _ = server .ReadFrom (serverBuf )
902+ }()
903+ select {
904+ case <- done :
905+ case <- time .After (5 * time .Second ):
906+ t .Fatal ("timeout waiting for server ReadFrom" )
907+ }
908+
909+ // Server writes data back to the client
910+ responsePayload := []byte ("hello from server" )
911+ _ , err = server .WriteTo (responsePayload , serverReadAddr )
912+ if err != nil {
913+ t .Fatalf ("server WriteTo: %v" , err )
914+ }
915+
916+ // Client sends another query to trigger the server to send the response
917+ // (server data is delivered as DNS response payloads)
918+ _ , err = client .WriteTo ([]byte ("poll" ), rawConn .LocalAddr ())
919+ if err != nil {
920+ t .Fatal (err )
921+ }
922+
923+ // Read from client with timeout
924+ clientBuf := make ([]byte , 256 )
925+ done2 := make (chan struct {})
926+ var clientReadN int
927+ var clientReadErr error
928+ go func () {
929+ defer close (done2 )
930+ clientReadN , _ , clientReadErr = client .ReadFrom (clientBuf )
931+ }()
932+
933+ select {
934+ case <- done2 :
935+ case <- time .After (5 * time .Second ):
936+ t .Fatal ("timeout waiting for client ReadFrom" )
937+ }
938+
939+ if clientReadErr != nil {
940+ t .Fatalf ("client ReadFrom: %v" , clientReadErr )
941+ }
942+ if ! bytes .Equal (clientBuf [:clientReadN ], responsePayload ) {
943+ t .Errorf ("client received %q, want %q" , clientBuf [:clientReadN ], responsePayload )
944+ }
945+ }
0 commit comments