Skip to content

Commit 078cbb5

Browse files
committed
refactor(tun): use context for packet handling and cleanup
1 parent f333402 commit 078cbb5

1 file changed

Lines changed: 27 additions & 35 deletions

File tree

tun/netstack/tun.go

Lines changed: 27 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,14 @@ import (
4040
)
4141

4242
type netTun struct {
43-
ep *channel.Endpoint
44-
stack *stack.Stack
45-
events chan tun.Event
46-
notifyHandle *channel.NotificationHandle
47-
incomingPacket chan *buffer.View
48-
mtu int
49-
dnsServers []netip.Addr
50-
hasV4, hasV6 bool
43+
ep *channel.Endpoint
44+
stack *stack.Stack
45+
events chan tun.Event
46+
ctx context.Context
47+
cancel context.CancelFunc
48+
mtu int
49+
dnsServers []netip.Addr
50+
hasV4, hasV6 bool
5151
}
5252

5353
type Net netTun
@@ -58,20 +58,23 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device,
5858
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol6, icmp.NewProtocol4},
5959
HandleLocal: true,
6060
}
61+
62+
ctx, cancel := context.WithCancel(context.Background())
63+
6164
dev := &netTun{
62-
ep: channel.New(1024, uint32(mtu), ""),
63-
stack: stack.New(opts),
64-
events: make(chan tun.Event, 10),
65-
incomingPacket: make(chan *buffer.View),
66-
dnsServers: dnsServers,
67-
mtu: mtu,
65+
ep: channel.New(1024, uint32(mtu), ""),
66+
stack: stack.New(opts),
67+
events: make(chan tun.Event, 10),
68+
ctx: ctx,
69+
cancel: cancel,
70+
dnsServers: dnsServers,
71+
mtu: mtu,
6872
}
6973
sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
7074
tcpipErr := dev.stack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
7175
if tcpipErr != nil {
7276
return nil, nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
7377
}
74-
dev.notifyHandle = dev.ep.AddNotify(dev)
7578
tcpipErr = dev.stack.CreateNIC(1, dev.ep)
7679
if tcpipErr != nil {
7780
return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
@@ -121,11 +124,12 @@ func (tun *netTun) Events() <-chan tun.Event {
121124
}
122125

123126
func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
124-
view, ok := <-tun.incomingPacket
125-
if !ok {
127+
pkb := tun.ep.ReadContext(tun.ctx)
128+
if pkb == nil {
126129
return 0, os.ErrClosed
127130
}
128-
131+
view := pkb.ToView()
132+
pkb.DecRef()
129133
n, err := view.Read(buf[0][offset:])
130134
if err != nil {
131135
return 0, err
@@ -135,6 +139,10 @@ func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
135139
}
136140

137141
func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
142+
if tun.ctx.Err() != nil {
143+
return 0, os.ErrClosed
144+
}
145+
138146
for _, buf := range buf {
139147
packet := buf[offset:]
140148
if len(packet) == 0 {
@@ -154,32 +162,16 @@ func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
154162
return len(buf), nil
155163
}
156164

157-
func (tun *netTun) WriteNotify() {
158-
pkt := tun.ep.Read()
159-
if pkt == nil {
160-
return
161-
}
162-
163-
view := pkt.ToView()
164-
pkt.DecRef()
165-
166-
tun.incomingPacket <- view
167-
}
168-
169165
func (tun *netTun) Close() error {
166+
tun.cancel()
170167
tun.stack.RemoveNIC(1)
171168
tun.stack.Close()
172-
tun.ep.RemoveNotify(tun.notifyHandle)
173169
tun.ep.Close()
174170

175171
if tun.events != nil {
176172
close(tun.events)
177173
}
178174

179-
if tun.incomingPacket != nil {
180-
close(tun.incomingPacket)
181-
}
182-
183175
return nil
184176
}
185177

0 commit comments

Comments
 (0)