Skip to content

Commit 3baa787

Browse files
committed
ipn/wg: TNT status to override them all
1 parent 69b09c2 commit 3baa787

2 files changed

Lines changed: 43 additions & 36 deletions

File tree

intra/ipn/wg/wgconn.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ type PktDir string
102102
const (
103103
Rcv PktDir = "recv"
104104
Snd PktDir = "send"
105+
Con PktDir = "conn" // e.g. dial, announce, accept
106+
Opn PktDir = "open" // open conn to the wg endpoint
105107
)
106108

107109
type StdNetBind struct {

intra/ipn/wgproxy.go

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1110,9 +1110,8 @@ func (h *wgtun) Dial(network, address string) (c net.Conn, err error) {
11101110
log.D("wg: %s dial: start %s %s", h.tag(), network, address)
11111111

11121112
// DialContext resolves addr if needed; then dialing into all resolved ips.
1113-
if c, err = h.DialContext(context.TODO(), network, address); err != nil {
1114-
h.status.Store(TKO)
1115-
} // else: status updated by h.listener
1113+
c, err = h.DialContext(context.TODO(), network, address)
1114+
defer h.listener(wg.Con, err) // status updated by h.listener
11161115

11171116
log.I("wg: %s dial: end %s %s; err %v", h.tag(), network, address, err)
11181117
return
@@ -1128,11 +1127,10 @@ func (h *wgtun) DialBind(network, local, remote string) (c net.Conn, err error)
11281127
log.D("wg: %s dialbind: start %s %s=>%s", h.tag(), network, local, remote)
11291128

11301129
// DialContext resolves addr if needed; then dialing into all resolved ips.
1131-
if c, err = h.DialContext(context.TODO(), network, remote); err != nil {
1132-
h.status.Store(TKO)
1133-
} // else: status updated by h.listener
1130+
c, err = h.DialContext(context.TODO(), network, remote)
1131+
defer h.listener(wg.Con, err) // status updated by h.listener when creating conns
11341132

1135-
log.I("wg: %s dial: end %s %s; err %v", h.tag(), network, remote, err)
1133+
log.I("wg: %s dialbind: end %s %s=>%s; err %v", h.tag(), network, local, remote, err)
11361134
return
11371135
}
11381136

@@ -1147,9 +1145,8 @@ func (h *wgtun) Announce(network, local string) (pc net.PacketConn, err error) {
11471145

11481146
var addr netip.AddrPort
11491147
if addr, err = netip.ParseAddrPort(local); err == nil {
1150-
if pc, err = h.ListenUDPAddrPort(addr); err != nil {
1151-
h.status.Store(TKO)
1152-
} // else: status updated by h.listener
1148+
pc, err = h.ListenUDPAddrPort(addr)
1149+
defer h.listener(wg.Con, err)
11531150
} // else: expect local to always be ipaddr
11541151

11551152
log.I("wg: %s announce: end %s %s; err %v", h.tag(), network, local, err)
@@ -1167,9 +1164,8 @@ func (h *wgtun) Accept(network, local string) (ln net.Listener, err error) {
11671164

11681165
var addr netip.AddrPort
11691166
if addr, err = netip.ParseAddrPort(local); err == nil {
1170-
if ln, err = h.ListenTCPAddrPort(addr); err != nil {
1171-
h.status.Store(TKO)
1172-
} // else: status updated by h.listener
1167+
ln, err = h.ListenTCPAddrPort(addr)
1168+
defer h.listener(wg.Con, err)
11731169
} // else: expect local to always be ipaddr
11741170

11751171
log.I("wg: %s accept: end %s %s; err %v", h.tag(), network, local, err)
@@ -1187,9 +1183,8 @@ func (h *wgtun) Probe(network, local string) (pc net.PacketConn, err error) {
11871183

11881184
var addr netip.AddrPort
11891185
if addr, err = netip.ParseAddrPort(local); err == nil {
1190-
if pc, err = h.ListenUDPAddrPort(addr); err != nil {
1191-
h.status.Store(TKO)
1192-
} // else: status updated by h.listener
1186+
pc, err = h.ListenUDPAddrPort(addr)
1187+
defer h.listener(wg.Con, err)
11931188
} // else: expect local to always be ipaddr
11941189

11951190
log.I("wg: %s probe: end %s %s; err %v", h.tag(), network, local, err)
@@ -1373,50 +1368,60 @@ func (h *wgtun) serve(network, local string) (pc net.PacketConn, err error) {
13731368
pc, err = h.direct.Announce(network, local)
13741369
}
13751370
h.viaUp.Store(usingvia)
1376-
defer localDialStatus(h.status, err)
1371+
defer h.listener(wg.Opn, err)
13771372

13781373
logei(err)("wg: %s serve: %s (id? %s / via? %s / usingVia? %t); err? %v",
13791374
h.id, local, who, idstr(v), usingvia, err)
13801375
return
13811376
}
13821377

13831378
func (h *wgtun) listener(op wg.PktDir, err error) {
1384-
if h.status.Load() == END {
1379+
s := h.status.Load()
1380+
if s == END {
1381+
return
1382+
}
1383+
1384+
if s == TUP && op != wg.Opn { // ignore all else but open
13851385
return
13861386
}
13871387

1388-
s := TOK // assume err == nil
1389-
if op == wg.Rcv && timedout(err) {
1388+
if err != nil { // failing
1389+
s = TKO
1390+
if op == wg.Opn { // could not open conn to wg endpoint
1391+
s = TNT
1392+
}
1393+
if op == wg.Rcv && timedout(err) {
1394+
s = TZZ // wirtes and reads have succeeded in the recent past
1395+
}
1396+
} else { // ok
1397+
s = TOK
1398+
if op == wg.Rcv { // read ok
1399+
h.latestRx.Store(now())
1400+
} else if op == wg.Snd { // write ok
1401+
h.latestTx.Store(now())
1402+
}
1403+
}
1404+
1405+
if s != TNT {
13901406
lastSuccessfulRead := h.latestRx.Load()
13911407
writeElapsedMs := h.latestTx.Load() - lastSuccessfulRead // may be negative
1408+
// if no reads since last write, mark as unresponsive
13921409
// if status is "up" but writes (Snd) have not yet happened
13931410
// then reads (Rcv) are expected to timeout; so ignore them
13941411
if lastSuccessfulRead <= 0 || writeElapsedMs > markTNTAfterMillis {
13951412
s = TNT // writes succeeded; but reads have never or not in the past 20s
1396-
} else {
1397-
s = TZZ // wirtes and reads have succeeded in the recent past
13981413
}
1399-
} else if err != nil {
1400-
s = TKO // failing
14011414
}
14021415

1403-
if s == TOK {
1404-
if op == wg.Rcv { // read
1405-
h.latestRx.Store(now())
1406-
} else if op == wg.Snd { // write
1407-
h.latestTx.Store(now())
1408-
}
1409-
writeElapsedMs := h.latestTx.Load() - h.latestRx.Load() // may be negative
1410-
// if no reads since last write, mark as unresponsive
1411-
if writeElapsedMs > markTNTAfterMillis {
1412-
s = TNT
1413-
}
1414-
} else if s != TUP {
1416+
if s != TOK {
14151417
if op == wg.Rcv {
14161418
h.errRx.Add(1)
14171419
} else if op == wg.Snd {
14181420
h.errTx.Add(1)
14191421
}
1422+
}
1423+
1424+
if s == TNT {
14201425
if n := h.remote.Load().MaybeRefresh(); n > 0 {
14211426
log.I("wg: %s listener: %s, state: %s; refreshed n domains: %d",
14221427
h.tag(), op, pxstatus(s), n)

0 commit comments

Comments
 (0)