Skip to content

Commit 4cffe15

Browse files
committed
tcp,udp: set rx, tx summary for dns overriden sockets
1 parent 3e82dca commit 4cffe15

2 files changed

Lines changed: 49 additions & 30 deletions

File tree

intra/common.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,12 @@ func (h *baseHandler) dnsOverride(conn net.Conn, uid string, smm *SocketSummary)
434434
// SocketSummary is not meant to be used by the listener; x.DNSSummary is
435435
// but call into PostFlow & OnSocketClosed anyway, to avoid ambiguities
436436
// on which sockets / sessions are still active.
437-
h.resolver.Serve(h.proto, conn, uid, func() { h.queueSummary(smm.done()) })
437+
rx, tx, errs := h.resolver.Serve(h.proto, conn, uid)
438+
smm.Rx = rx
439+
smm.Tx = tx
440+
// smm.Rtt
441+
// smm.Target = DNS resolver?
442+
h.listener.OnSocketClosed(smm.done(errs...))
438443
})
439444
return true
440445
}

intra/dnsx/transport.go

Lines changed: 43 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ type Resolver interface {
176176
// IsDnsAddr returns true if the ip:port is resolver's fake endpoint
177177
IsDnsAddr(ipport netip.AddrPort) bool
178178
// Serve reads DNS query from conn and writes DNS answer to conn
179-
Serve(proto string, conn protect.Conn, uid string, cb core.Finally)
179+
Serve(proto string, conn protect.Conn, uid string) (rx, tx int64, errs []error)
180180

181181
// StopAll stops all transports.
182182
StopAll()
@@ -734,11 +734,10 @@ runagain:
734734
}
735735

736736
// Serve implements Resolver.
737-
func (r *resolver) Serve(proto string, c protect.Conn, uid string, cb core.Finally) {
738-
defer cb()
739-
737+
func (r *resolver) Serve(proto string, c protect.Conn, uid string) (rx, tx int64, errs []error) {
740738
if r.closed.Load() {
741-
log.W("dns: serve: closed for business")
739+
err := log.EE("dns: serve: closed for business")
740+
errs = append(errs, err)
742741
return
743742
}
744743

@@ -751,12 +750,14 @@ func (r *resolver) Serve(proto string, c protect.Conn, uid string, cb core.Final
751750

752751
switch proto {
753752
case NetTypeTCP:
754-
r.accept(c, uid)
753+
rx, tx, errs = r.accept(c, uid)
755754
case NetTypeUDP:
756-
r.reply(c, uid)
755+
rx, tx, errs = r.reply(c, uid)
757756
default:
758-
log.W("dns: unknown proto: %s", proto)
757+
err := log.EE("dns: unknown proto: %s", proto)
758+
errs = append(errs, err)
759759
}
760+
return
760761
}
761762

762763
func (r *resolver) determineTransport(id string) Transport {
@@ -822,50 +823,49 @@ func (r *resolver) determineTransport(id string) Transport {
822823
}
823824

824825
// dnstcp queries the transport and writes answers to w, prefixed by length.
825-
func (r *resolver) dnstcp(q []byte, w io.WriteCloser, uid string) error {
826+
func (r *resolver) dnstcp(q []byte, w io.WriteCloser, uid string) (written int, err error) {
826827
ans, _, err := r.forward(q, OriginTunnel, uid)
827828

828829
rlen := len(ans)
829830
if rlen <= 0 && err != nil {
830831
clos(w) // close on client err
831-
return err
832+
return
832833
}
833834

834-
if n, err := writePrefixed(w, ans, rlen); err != nil {
835+
if written, err = writePrefixed(w, ans, rlen); err != nil {
835836
clos(w) // close on write back err
836-
return err
837-
} else if n != rlen {
838-
// do not close on incomplete writes
839-
return fmt.Errorf("dns: tcp: for %s incomplete write: n(%d) != r(%d)", uid, n, rlen)
837+
} else if written != rlen { // do not close on incomplete writes
838+
err = fmt.Errorf("dns: tcp: for %s incomplete write: n(%d) != r(%d)", uid, written, rlen)
840839
}
841-
return nil // ok
840+
return
842841
}
843842

844843
// dnsudp queries the transport and writes answers to w.
845-
func (r *resolver) dnsudp(q []byte, w io.WriteCloser, uid string) error {
844+
func (r *resolver) dnsudp(q []byte, w io.WriteCloser, uid string) (written int, err error) {
846845
ans, _, err := r.forward(q, OriginTunnel, uid)
847846

848847
rlen := len(ans)
849848
if rlen <= 0 && err != nil {
850849
clos(w) // close on client err
851-
return err
850+
return
852851
}
853852

854-
if n, err := w.Write(ans); err != nil {
853+
if written, err = w.Write(ans); err != nil {
855854
clos(w) // close on write back err
856-
return err
857-
} else if n != rlen {
855+
} else if written != rlen {
858856
// do not close on incomplete writes
859-
return fmt.Errorf("dns: udp: for %s incomplete write: n(%d) != r(%d)", uid, n, rlen)
857+
err = fmt.Errorf("dns: udp: for %s incomplete write: n(%d) != r(%d)", uid, written, rlen)
860858
}
861859

862-
return nil // ok
860+
return
863861
}
864862

865863
// reply DNS-over-UDP from a stub resolver.
866-
func (r *resolver) reply(c protect.Conn, uid string) {
864+
func (r *resolver) reply(c protect.Conn, uid string) (rx, tx int64, errs []error) {
867865
defer clos(c)
868866

867+
var rxv, txv atomic.Int64
868+
869869
var wg sync.WaitGroup
870870
start := time.Now()
871871
cnt := 0
@@ -891,22 +891,30 @@ func (r *resolver) reply(c protect.Conn, uid string) {
891891
wg.Add(1)
892892
defer wg.Done()
893893
defer free()
894-
err = r.dnsudp(q[:n], c, uid)
894+
m, err := r.dnsudp(q[:n], c, uid)
895895
logeif(err != nil)("dns: udp: for %s err! tot: %d, t: %s, %v",
896896
uid, cnt, core.FmtTimeAsPeriod(start), err)
897+
rxv.Add(int64(m))
898+
txv.Add(int64(n))
899+
errs = append(errs, err)
897900
})
898901
}
899902
cnt++
900903
}
901904
wg.Wait()
902-
log.VV("dns: udp: for %s done; tot: %d, t: %s", uid, cnt, core.FmtTimeAsPeriod(start))
905+
rx = rxv.Load()
906+
tx = txv.Load()
907+
log.VV("dns: udp: for %s done; tot: %d (rx: %d, tx: %d), t: %s", uid, cnt, rx, tx, core.FmtTimeAsPeriod(start))
908+
return
903909
}
904910

905911
// Accept a DNS-over-TCP socket from a stub resolver, and connect the socket
906912
// to this DNSTransport.
907-
func (r *resolver) accept(c io.ReadWriteCloser, uid string) {
913+
func (r *resolver) accept(c io.ReadWriteCloser, uid string) (rx, tx int64, errs []error) {
908914
defer clos(c)
909915

916+
var rxv, txv atomic.Int64
917+
910918
var wg sync.WaitGroup
911919
start := time.Now()
912920
cnt := 0
@@ -952,15 +960,21 @@ func (r *resolver) accept(c io.ReadWriteCloser, uid string) {
952960
wg.Add(1)
953961
defer wg.Done()
954962
defer free()
955-
err = r.dnstcp(q[:n], c, uid)
963+
m, err := r.dnstcp(q[:n], c, uid)
956964
logeif(err != nil)("dns: tcp: for %s err! tot: %d, t: %s, %v",
957965
uid, cnt, core.FmtTimeAsPeriod(start), err)
966+
errs = append(errs, err)
967+
txv.Add(int64(n))
968+
rxv.Add(int64(m))
958969
})
959970
cnt++
960971
}
961972
wg.Wait()
962-
log.VV("dns: tcp: for %s done; tot: %d, t: %s", uid, cnt, core.FmtTimeAsPeriod(start))
973+
rx = rxv.Load()
974+
tx = txv.Load()
975+
log.VV("dns: tcp: for %s done; tot: %d (rx: %d, tx: %d), t: %s", uid, cnt, rx, tx, core.FmtTimeAsPeriod(start))
963976
// TODO: Cancel outstanding queries.
977+
return
964978
}
965979

966980
// StopAll implements TransportMult.

0 commit comments

Comments
 (0)