@@ -176,7 +176,7 @@ type Resolver interface {
176176 // IsDnsAddr returns true if the ip:port is resolver's fake endpoint
177177 IsDnsAddr (ipport netip.AddrPort ) bool
178178 // Serve reads DNS query from conn and writes DNS answer to conn
179- Serve (proto string , conn protect.Conn , uid string , cb core. Finally )
179+ Serve (proto string , conn protect.Conn , uid string ) ( rx , tx int64 , errs [] error )
180180
181181 // StopAll stops all transports.
182182 StopAll ()
@@ -734,11 +734,10 @@ runagain:
734734}
735735
736736// Serve implements Resolver.
737- func (r * resolver ) Serve (proto string , c protect.Conn , uid string , cb core.Finally ) {
738- defer cb ()
739-
737+ func (r * resolver ) Serve (proto string , c protect.Conn , uid string ) (rx , tx int64 , errs []error ) {
740738 if r .closed .Load () {
741- log .W ("dns: serve: closed for business" )
739+ err := log .EE ("dns: serve: closed for business" )
740+ errs = append (errs , err )
742741 return
743742 }
744743
@@ -751,12 +750,14 @@ func (r *resolver) Serve(proto string, c protect.Conn, uid string, cb core.Final
751750
752751 switch proto {
753752 case NetTypeTCP :
754- r .accept (c , uid )
753+ rx , tx , errs = r .accept (c , uid )
755754 case NetTypeUDP :
756- r .reply (c , uid )
755+ rx , tx , errs = r .reply (c , uid )
757756 default :
758- log .W ("dns: unknown proto: %s" , proto )
757+ err := log .EE ("dns: unknown proto: %s" , proto )
758+ errs = append (errs , err )
759759 }
760+ return
760761}
761762
762763func (r * resolver ) determineTransport (id string ) Transport {
@@ -822,50 +823,49 @@ func (r *resolver) determineTransport(id string) Transport {
822823}
823824
824825// dnstcp queries the transport and writes answers to w, prefixed by length.
825- func (r * resolver ) dnstcp (q []byte , w io.WriteCloser , uid string ) error {
826+ func (r * resolver ) dnstcp (q []byte , w io.WriteCloser , uid string ) ( written int , err error ) {
826827 ans , _ , err := r .forward (q , OriginTunnel , uid )
827828
828829 rlen := len (ans )
829830 if rlen <= 0 && err != nil {
830831 clos (w ) // close on client err
831- return err
832+ return
832833 }
833834
834- if n , err : = writePrefixed (w , ans , rlen ); err != nil {
835+ if written , err = writePrefixed (w , ans , rlen ); err != nil {
835836 clos (w ) // close on write back err
836- return err
837- } else if n != rlen {
838- // do not close on incomplete writes
839- return fmt .Errorf ("dns: tcp: for %s incomplete write: n(%d) != r(%d)" , uid , n , rlen )
837+ } else if written != rlen { // do not close on incomplete writes
838+ err = fmt .Errorf ("dns: tcp: for %s incomplete write: n(%d) != r(%d)" , uid , written , rlen )
840839 }
841- return nil // ok
840+ return
842841}
843842
844843// dnsudp queries the transport and writes answers to w.
845- func (r * resolver ) dnsudp (q []byte , w io.WriteCloser , uid string ) error {
844+ func (r * resolver ) dnsudp (q []byte , w io.WriteCloser , uid string ) ( written int , err error ) {
846845 ans , _ , err := r .forward (q , OriginTunnel , uid )
847846
848847 rlen := len (ans )
849848 if rlen <= 0 && err != nil {
850849 clos (w ) // close on client err
851- return err
850+ return
852851 }
853852
854- if n , err : = w .Write (ans ); err != nil {
853+ if written , err = w .Write (ans ); err != nil {
855854 clos (w ) // close on write back err
856- return err
857- } else if n != rlen {
855+ } else if written != rlen {
858856 // do not close on incomplete writes
859- return fmt .Errorf ("dns: udp: for %s incomplete write: n(%d) != r(%d)" , uid , n , rlen )
857+ err = fmt .Errorf ("dns: udp: for %s incomplete write: n(%d) != r(%d)" , uid , written , rlen )
860858 }
861859
862- return nil // ok
860+ return
863861}
864862
865863// reply DNS-over-UDP from a stub resolver.
866- func (r * resolver ) reply (c protect.Conn , uid string ) {
864+ func (r * resolver ) reply (c protect.Conn , uid string ) ( rx , tx int64 , errs [] error ) {
867865 defer clos (c )
868866
867+ var rxv , txv atomic.Int64
868+
869869 var wg sync.WaitGroup
870870 start := time .Now ()
871871 cnt := 0
@@ -891,22 +891,30 @@ func (r *resolver) reply(c protect.Conn, uid string) {
891891 wg .Add (1 )
892892 defer wg .Done ()
893893 defer free ()
894- err = r .dnsudp (q [:n ], c , uid )
894+ m , err : = r .dnsudp (q [:n ], c , uid )
895895 logeif (err != nil )("dns: udp: for %s err! tot: %d, t: %s, %v" ,
896896 uid , cnt , core .FmtTimeAsPeriod (start ), err )
897+ rxv .Add (int64 (m ))
898+ txv .Add (int64 (n ))
899+ errs = append (errs , err )
897900 })
898901 }
899902 cnt ++
900903 }
901904 wg .Wait ()
902- log .VV ("dns: udp: for %s done; tot: %d, t: %s" , uid , cnt , core .FmtTimeAsPeriod (start ))
905+ rx = rxv .Load ()
906+ tx = txv .Load ()
907+ log .VV ("dns: udp: for %s done; tot: %d (rx: %d, tx: %d), t: %s" , uid , cnt , rx , tx , core .FmtTimeAsPeriod (start ))
908+ return
903909}
904910
905911// Accept a DNS-over-TCP socket from a stub resolver, and connect the socket
906912// to this DNSTransport.
907- func (r * resolver ) accept (c io.ReadWriteCloser , uid string ) {
913+ func (r * resolver ) accept (c io.ReadWriteCloser , uid string ) ( rx , tx int64 , errs [] error ) {
908914 defer clos (c )
909915
916+ var rxv , txv atomic.Int64
917+
910918 var wg sync.WaitGroup
911919 start := time .Now ()
912920 cnt := 0
@@ -952,15 +960,21 @@ func (r *resolver) accept(c io.ReadWriteCloser, uid string) {
952960 wg .Add (1 )
953961 defer wg .Done ()
954962 defer free ()
955- err = r .dnstcp (q [:n ], c , uid )
963+ m , err : = r .dnstcp (q [:n ], c , uid )
956964 logeif (err != nil )("dns: tcp: for %s err! tot: %d, t: %s, %v" ,
957965 uid , cnt , core .FmtTimeAsPeriod (start ), err )
966+ errs = append (errs , err )
967+ txv .Add (int64 (n ))
968+ rxv .Add (int64 (m ))
958969 })
959970 cnt ++
960971 }
961972 wg .Wait ()
962- log .VV ("dns: tcp: for %s done; tot: %d, t: %s" , uid , cnt , core .FmtTimeAsPeriod (start ))
973+ rx = rxv .Load ()
974+ tx = txv .Load ()
975+ log .VV ("dns: tcp: for %s done; tot: %d (rx: %d, tx: %d), t: %s" , uid , cnt , rx , tx , core .FmtTimeAsPeriod (start ))
963976 // TODO: Cancel outstanding queries.
977+ return
964978}
965979
966980// StopAll implements TransportMult.
0 commit comments