@@ -40,14 +40,14 @@ import (
4040)
4141
4242type netTun struct {
43- ep * channel.Endpoint
44- stack * stack.Stack
45- events chan tun.Event
46- notifyHandle * channel. NotificationHandle
47- incomingPacket chan * buffer. View
48- mtu int
49- dnsServers []netip.Addr
50- hasV4 , hasV6 bool
43+ ep * channel.Endpoint
44+ stack * stack.Stack
45+ events chan tun.Event
46+ ctx context. Context
47+ cancel context. CancelFunc
48+ mtu int
49+ dnsServers []netip.Addr
50+ hasV4 , hasV6 bool
5151}
5252
5353type Net netTun
@@ -58,20 +58,23 @@ func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device,
5858 TransportProtocols : []stack.TransportProtocolFactory {tcp .NewProtocol , udp .NewProtocol , icmp .NewProtocol6 , icmp .NewProtocol4 },
5959 HandleLocal : true ,
6060 }
61+
62+ ctx , cancel := context .WithCancel (context .Background ())
63+
6164 dev := & netTun {
62- ep : channel .New (1024 , uint32 (mtu ), "" ),
63- stack : stack .New (opts ),
64- events : make (chan tun.Event , 10 ),
65- incomingPacket : make (chan * buffer.View ),
66- dnsServers : dnsServers ,
67- mtu : mtu ,
65+ ep : channel .New (1024 , uint32 (mtu ), "" ),
66+ stack : stack .New (opts ),
67+ events : make (chan tun.Event , 10 ),
68+ ctx : ctx ,
69+ cancel : cancel ,
70+ dnsServers : dnsServers ,
71+ mtu : mtu ,
6872 }
6973 sackEnabledOpt := tcpip .TCPSACKEnabled (true ) // TCP SACK is disabled by default
7074 tcpipErr := dev .stack .SetTransportProtocolOption (tcp .ProtocolNumber , & sackEnabledOpt )
7175 if tcpipErr != nil {
7276 return nil , nil , fmt .Errorf ("could not enable TCP SACK: %v" , tcpipErr )
7377 }
74- dev .notifyHandle = dev .ep .AddNotify (dev )
7578 tcpipErr = dev .stack .CreateNIC (1 , dev .ep )
7679 if tcpipErr != nil {
7780 return nil , nil , fmt .Errorf ("CreateNIC: %v" , tcpipErr )
@@ -121,11 +124,12 @@ func (tun *netTun) Events() <-chan tun.Event {
121124}
122125
123126func (tun * netTun ) Read (buf [][]byte , sizes []int , offset int ) (int , error ) {
124- view , ok := <- tun .incomingPacket
125- if ! ok {
127+ pkb := tun .ep . ReadContext ( tun . ctx )
128+ if pkb == nil {
126129 return 0 , os .ErrClosed
127130 }
128-
131+ view := pkb .ToView ()
132+ pkb .DecRef ()
129133 n , err := view .Read (buf [0 ][offset :])
130134 if err != nil {
131135 return 0 , err
@@ -135,6 +139,10 @@ func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
135139}
136140
137141func (tun * netTun ) Write (buf [][]byte , offset int ) (int , error ) {
142+ if tun .ctx .Err () != nil {
143+ return 0 , os .ErrClosed
144+ }
145+
138146 for _ , buf := range buf {
139147 packet := buf [offset :]
140148 if len (packet ) == 0 {
@@ -154,32 +162,16 @@ func (tun *netTun) Write(buf [][]byte, offset int) (int, error) {
154162 return len (buf ), nil
155163}
156164
157- func (tun * netTun ) WriteNotify () {
158- pkt := tun .ep .Read ()
159- if pkt == nil {
160- return
161- }
162-
163- view := pkt .ToView ()
164- pkt .DecRef ()
165-
166- tun .incomingPacket <- view
167- }
168-
169165func (tun * netTun ) Close () error {
166+ tun .cancel ()
170167 tun .stack .RemoveNIC (1 )
171168 tun .stack .Close ()
172- tun .ep .RemoveNotify (tun .notifyHandle )
173169 tun .ep .Close ()
174170
175171 if tun .events != nil {
176172 close (tun .events )
177173 }
178174
179- if tun .incomingPacket != nil {
180- close (tun .incomingPacket )
181- }
182-
183175 return nil
184176}
185177
0 commit comments