Skip to content

Commit 5589498

Browse files
committed
tunnel: use context to cancel awaiter
1 parent a42d165 commit 5589498

2 files changed

Lines changed: 35 additions & 18 deletions

File tree

intra/core/async.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,16 @@ func Every(id string, pctx context.Context, d time.Duration, f func()) context.C
212212
return ctx
213213
}
214214

215+
// SigFin runs f in a goroutine and returns a channel that is closed when f returns.
216+
func SigFin(f func()) <-chan struct{} {
217+
done := make(chan struct{})
218+
Go("take", func() {
219+
defer close(done)
220+
f()
221+
})
222+
return done
223+
}
224+
215225
func Await(f func(), until time.Duration) (awaited bool) {
216226
done := make(chan struct{})
217227
Go("await", func() {

tunnel/tunnel.go

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ func (t *gtunnel) Mtu() int32 {
9999
return -1
100100
}
101101

102-
func (t *gtunnel) waitForEndpoint() {
102+
func (t *gtunnel) waitForEndpoint(ctx context.Context) {
103103
defer core.Recover(core.Exit11, "g.wait")
104104

105105
const maxchecks = 5
@@ -108,12 +108,21 @@ func (t *gtunnel) waitForEndpoint() {
108108

109109
waitStart := time.Now()
110110
i := 0
111+
112+
defer log.I("tun: waiter: done; #%d, %s", i, core.FmtTimeAsPeriod(waitStart))
113+
111114
for i < maxchecks && !t.closed.Load() {
112115
// wait a bit to let the endpoint settle
113116
time.Sleep(betweenChecks)
114117
start := time.Now()
115118

116-
t.ep.Wait() // wait until endpoint closes
119+
select {
120+
case <-ctx.Done():
121+
log.D("tun: waiter: ctx done; #%d", i)
122+
i = maxchecks // exit loop
123+
case <-core.SigFin(t.ep.Wait): // wait until endpoint closes
124+
log.D("tun: waiter: endpoint not running; #%d", i)
125+
}
117126

118127
// if the endpoint was up for more than uptimeThreshold,
119128
// reset the counter and do another set of maxchecks
@@ -128,18 +137,14 @@ func (t *gtunnel) waitForEndpoint() {
128137
i++
129138
}
130139
}
131-
waitDone := int64(time.Since(waitStart).Milliseconds() / 1000)
132-
140+
waitDone := core.FmtTimeAsPeriod(waitStart)
133141
if !t.closed.Load() {
134142
// the endpoint closed without a Disconnect, this may happen
135143
// in cases where a panic was recovered and endpoint was
136144
// closed without a t.ep.Swap or t.stack.Destroy
137-
log.E("tun: waiter: ep notified close; #%d, %dsecs", i, waitDone)
138-
log.U(fmt.Sprintf("Deactivated! Down after %dsecs", waitDone))
145+
log.U(fmt.Sprintf("Deactivated! Down after %s", waitDone))
139146
// todo: disconnect parent tunnel
140147
t.Disconnect() // may already be disconnected
141-
} else {
142-
log.D("tun: waiter: done; #%d, %dsecs", i, waitDone)
143148
}
144149
}
145150

@@ -179,11 +184,14 @@ func NewGTunnel(pctx context.Context, fd, mtu int, l3 string, hdl netstack.GConn
179184
return nil, nil, err
180185
}
181186

182-
sink := newSink(pctx)
187+
ctx, done := context.WithCancel(pctx)
188+
189+
sink := newSink(ctx)
183190
stack := netstack.NewNetstack() // always dual-stack
184191
// NewEndpoint takes ownership of dupfd; closes it on errors
185192
ep, eerr := netstack.NewEndpoint(dupfd, mtu, sink)
186193
if eerr != nil {
194+
done()
187195
return nil, nil, eerr
188196
}
189197

@@ -197,15 +205,16 @@ func NewGTunnel(pctx context.Context, fd, mtu int, l3 string, hdl netstack.GConn
197205
var nic tcpip.NICID
198206
// Enabled() may temporarily return false when Up() is in progress.
199207
if nic, err = netstack.Up(who, stack, ep, hdl); err != nil { // attach new endpoint
208+
done()
200209
return nil, nil, err
201210
}
202211

203-
rev = netstack.NewReverseGConnHandler(who, pctx, stack, nic, ep, hdl)
212+
rev = netstack.NewReverseGConnHandler(who, ctx, stack, nic, ep, hdl)
204213

205214
log.I("tun: new netstack(%d) up; fd(%d=>%d), mtu(%d)", nic, fd, dupfd, mtu)
206215

207216
t = &gtunnel{
208-
ctx: pctx,
217+
ctx: ctx,
209218
stack: stack,
210219
ep: ep,
211220
sid: core.NewVolatile(fd), // fd is the og tun device
@@ -214,13 +223,11 @@ func NewGTunnel(pctx context.Context, fd, mtu int, l3 string, hdl netstack.GConn
214223
closed: atomic.Bool{},
215224
once: sync.Once{},
216225
}
217-
go t.waitForEndpoint()
218-
context.AfterFunc(pctx, func() {
219-
log.I("tun: ctx done")
220-
if !t.closed.Load() {
221-
t.Disconnect()
222-
}
223-
})
226+
227+
go func() {
228+
defer done()
229+
t.waitForEndpoint(ctx)
230+
}()
224231
return
225232
}
226233

0 commit comments

Comments
 (0)