Skip to content

Commit a5e2b57

Browse files
committed
netstack: swap may create a new fdbased endpoint
1 parent 994f58c commit a5e2b57

2 files changed

Lines changed: 42 additions & 4 deletions

File tree

intra/netstack/fdbased.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
package netstack
3131

3232
import (
33+
"errors"
3334
"fmt"
3435
"sync/atomic"
3536
"time"
@@ -59,6 +60,8 @@ const errorOnInvalidFD = false
5960
// wrapttl is the time to wait for the dispatcher to wrap up (close a previous FD).
6061
const waitttl = wrapttl
6162

63+
var errNeedsNewEndpoint = errors.New("ns: needs new endpoint")
64+
6265
type FdSwapper interface {
6366
// Swap closes existing FDs; uses new fd.
6467
Swap(fd int) error
@@ -200,7 +203,7 @@ type Options struct {
200203
// Makes fd non-blocking, but does not take ownership of fd, which must remain
201204
// open for the lifetime of the returned endpoint (until after the endpoint has
202205
// stopped being using and Wait returns).
203-
func NewFdbasedInjectableEndpoint(opts *Options) (SeamlessEndpoint, error) {
206+
func newFdbasedInjectableEndpoint(opts *Options) (SeamlessEndpoint, error) {
204207
caps := stack.LinkEndpointCapabilities(0)
205208
if opts.RXChecksumOffload {
206209
caps |= stack.CapabilityRXChecksumOffload
@@ -322,13 +325,18 @@ func (e *endpoint) Swap(fd int) (err error) {
322325
e.Lock()
323326
defer e.Unlock()
324327

328+
prevfd := e.fds.Load()
329+
if !prevfd.ok() {
330+
return errNeedsNewEndpoint
331+
}
332+
325333
f, err := newTun(fd) // fd may be invalid (ex: -1)
326334
if err != nil || f == nil {
327335
f = invalidFds // nilaway
328336
err = log.EE("ns: tun: swap: (%d) err: %v / %v; using invalidfd", fd, err)
329337
}
330338

331-
prevfd := e.fds.Swap(f) // commence WritePackets() on fd
339+
e.fds.Store(f) // commence WritePackets() on fd
332340

333341
log.D("ns: tun: swap: fd %s => %d; err? %v", prevfd, fd, err)
334342

intra/netstack/netstack.go

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ import (
1111
"fmt"
1212
"io"
1313
"net/netip"
14+
"strconv"
1415
"strings"
16+
"sync"
1517
"syscall"
1618

1719
"github.com/celzero/firestack/intra/core"
@@ -35,11 +37,39 @@ const nicfwd = false
3537
// packets will be truncated to snapLen.
3638
const SnapLen uint32 = 2048 // in bytes; some sufficient value
3739

40+
var (
41+
errNoFdSwapper = errors.New("linkFdSwap: no FdSwapper")
42+
)
43+
3844
type linkFdSwap struct {
45+
sync.Mutex
3946
stack.LinkEndpoint
4047
FdSwapper
4148
}
4249

50+
// Swap implements FdSwapper.
51+
func (l *linkFdSwap) Swap(fd int) error {
52+
l.Lock()
53+
defer l.Unlock()
54+
55+
if l.FdSwapper == nil {
56+
return errNoFdSwapper
57+
}
58+
59+
err := l.FdSwapper.Swap(fd)
60+
if errors.Is(err, errNeedsNewEndpoint) {
61+
umtu := uint32(l.MTU())
62+
opt := Options{
63+
FDs: []int{fd},
64+
MTU: umtu,
65+
}
66+
core.Go("linkFdSwap."+strconv.Itoa(fd), l.LinkEndpoint.Close)
67+
l.LinkEndpoint, err = newFdbasedInjectableEndpoint(&opt)
68+
}
69+
70+
return err
71+
}
72+
4373
// ref: github.com/google/gvisor/blob/91f58d2cc/pkg/tcpip/sample/tun_tcp_echo/main.go#L102
4474
func NewEndpoint(dev, mtu int, sink io.WriteCloser) (ep SeamlessEndpoint, err error) {
4575
defer func() {
@@ -55,7 +85,7 @@ func NewEndpoint(dev, mtu int, sink io.WriteCloser) (ep SeamlessEndpoint, err er
5585
MTU: umtu,
5686
}
5787

58-
if ep, err = NewFdbasedInjectableEndpoint(&opt); err != nil {
88+
if ep, err = newFdbasedInjectableEndpoint(&opt); err != nil {
5989
return nil, err
6090
}
6191
// ref: github.com/google/gvisor/blob/aeabb785278/pkg/tcpip/link/sniffer/sniffer.go#L111-L131
@@ -70,7 +100,7 @@ func snoop(ep SeamlessEndpoint, sink io.WriteCloser) (SeamlessEndpoint, error) {
70100
if link, err := NewSnoopyEndpoint(ep, sink); err != nil {
71101
return nil, err
72102
} else {
73-
return linkFdSwap{link, ep}, nil
103+
return &linkFdSwap{sync.Mutex{}, link, ep}, nil
74104
}
75105
}
76106

0 commit comments

Comments
 (0)