@@ -16,6 +16,7 @@ import (
1616 "sync"
1717 "syscall"
1818
19+ "github.com/tailscale/wireguard-go/buffer"
1920 "golang.org/x/net/ipv4"
2021 "golang.org/x/net/ipv6"
2122)
@@ -44,6 +45,7 @@ type StdNetBind struct {
4445 // these two fields are not guarded by mu
4546 udpAddrPool sync.Pool
4647 msgsPool sync.Pool
48+ bufPool * buffer.FragmentPool
4749
4850 blackhole4 bool
4951 blackhole6 bool
@@ -63,12 +65,14 @@ func NewStdNetBind() Bind {
6365 New : func () any {
6466 msgs := make ([]ipv6.Message , IdealBatchSize )
6567 for i := range msgs {
66- msgs [i ].Buffers = make (net.Buffers , 1 )
68+ msgs [i ].Buffers = make (net.Buffers , 1 , udpSegmentMaxDatagrams )
6769 msgs [i ].OOB = make ([]byte , controlSize )
6870 }
6971 return & msgs
7072 },
7173 },
74+
75+ bufPool : buffer .NewFragmentPool (),
7276 }
7377}
7478
@@ -204,7 +208,7 @@ again:
204208
205209func (s * StdNetBind ) putMessages (msgs * []ipv6.Message ) {
206210 for i := range * msgs {
207- (* msgs )[i ] = ipv6.Message {Buffers : (* msgs )[i ].Buffers , OOB : (* msgs )[i ].OOB }
211+ (* msgs )[i ] = ipv6.Message {Buffers : (* msgs )[i ].Buffers [: 1 ] , OOB : (* msgs )[i ].OOB }
208212 }
209213 s .msgsPool .Put (msgs )
210214}
@@ -230,36 +234,52 @@ func (s *StdNetBind) receiveIP(
230234 br batchReader ,
231235 conn * net.UDPConn ,
232236 rxOffload bool ,
233- bufs [][] byte ,
237+ bufs []* buffer. Buffer ,
234238 sizes []int ,
235239 eps []Endpoint ,
236240) (n int , err error ) {
237241 msgs := s .getMessages ()
238- for i := range bufs {
239- (* msgs )[i ].Buffers [0 ] = bufs [i ]
240- (* msgs )[i ].OOB = (* msgs )[i ].OOB [:cap ((* msgs )[i ].OOB )]
241- }
242242 defer s .putMessages (msgs )
243243 var numMsgs int
244244 if runtime .GOOS == "linux" {
245245 if rxOffload {
246- readAt := len (* msgs ) - 2
246+ const readBatch = 2
247+ readAt := len (* msgs ) - readBatch
248+ for i := readAt ; i < readAt + readBatch ; i ++ {
249+ if bufs [i ] == nil {
250+ bufs [i ] = s .bufPool .Get (buffer .MaxMessageSize )
251+ }
252+ (* msgs )[i ].Buffers [0 ] = bufs [i ].Data ()
253+ (* msgs )[i ].OOB = (* msgs )[i ].OOB [:cap ((* msgs )[i ].OOB )]
254+ }
247255 numMsgs , err = br .ReadBatch ((* msgs )[readAt :], 0 )
248256 if err != nil {
249257 return 0 , err
250258 }
251- numMsgs , err = splitCoalescedMessages (* msgs , readAt , getGSOSize )
259+ numMsgs , err = splitCoalescedMessages (* msgs , readAt , getGSOSize , bufs , s . bufPool )
252260 if err != nil {
253261 return 0 , err
254262 }
255263 } else {
264+ for i := range bufs {
265+ if bufs [i ] == nil {
266+ bufs [i ] = s .bufPool .Get (buffer .MaxMessageSize )
267+ }
268+ (* msgs )[i ].Buffers [0 ] = bufs [i ].Data ()
269+ (* msgs )[i ].OOB = (* msgs )[i ].OOB [:cap ((* msgs )[i ].OOB )]
270+ }
256271 numMsgs , err = br .ReadBatch (* msgs , 0 )
257272 if err != nil {
258273 return 0 , err
259274 }
260275 }
261276 } else {
277+ if bufs [0 ] == nil {
278+ bufs [0 ] = s .bufPool .Get (buffer .MaxMessageSize )
279+ }
262280 msg := & (* msgs )[0 ]
281+ msg .Buffers [0 ] = bufs [0 ].Data ()
282+ msg .OOB = msg .OOB [:cap (msg .OOB )]
263283 msg .N , msg .NN , _ , msg .Addr , err = conn .ReadMsgUDP (msg .Buffers [0 ], msg .OOB )
264284 if err != nil {
265285 return 0 , err
@@ -281,13 +301,13 @@ func (s *StdNetBind) receiveIP(
281301}
282302
283303func (s * StdNetBind ) makeReceiveIPv4 (pc * ipv4.PacketConn , conn * net.UDPConn , rxOffload bool ) ReceiveFunc {
284- return func (bufs [][] byte , sizes []int , eps []Endpoint ) (n int , err error ) {
304+ return func (bufs []* buffer. Buffer , sizes []int , eps []Endpoint ) (n int , err error ) {
285305 return s .receiveIP (pc , conn , rxOffload , bufs , sizes , eps )
286306 }
287307}
288308
289309func (s * StdNetBind ) makeReceiveIPv6 (pc * ipv6.PacketConn , conn * net.UDPConn , rxOffload bool ) ReceiveFunc {
290- return func (bufs [][] byte , sizes []int , eps []Endpoint ) (n int , err error ) {
310+ return func (bufs []* buffer. Buffer , sizes []int , eps []Endpoint ) (n int , err error ) {
291311 return s .receiveIP (pc , conn , rxOffload , bufs , sizes , eps )
292312 }
293313}
@@ -452,10 +472,11 @@ type setGSOFunc func(control *[]byte, gsoSize uint16)
452472
453473func coalesceMessages (addr * net.UDPAddr , ep * StdNetEndpoint , bufs [][]byte , offset int , msgs []ipv6.Message , setGSO setGSOFunc ) int {
454474 var (
455- base = - 1 // index of msg we are currently coalescing into
456- gsoSize int // segmentation size of msgs[base]
457- dgramCnt int // number of dgrams coalesced into msgs[base]
458- endBatch bool // tracking flag to start a new batch on next iteration of bufs
475+ base = - 1 // index of msg we are currently coalescing into
476+ gsoSize int // segmentation size of msgs[base]
477+ dgramCnt int // number of dgrams coalesced into msgs[base]
478+ endBatch bool // tracking flag to start a new batch on next iteration of bufs
479+ coalescedLen int // bytes coalesced into msgs[base]
459480 )
460481 maxPayloadLen := maxIPv4PayloadLen
461482 if ep .DstIP ().Is6 () {
@@ -465,18 +486,16 @@ func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, offs
465486 buf = buf [offset :]
466487 if i > 0 {
467488 msgLen := len (buf )
468- baseLenBefore := len (msgs [base ].Buffers [0 ])
469- freeBaseCap := cap (msgs [base ].Buffers [0 ]) - baseLenBefore
470- if msgLen + baseLenBefore <= maxPayloadLen &&
489+ if msgLen + coalescedLen <= maxPayloadLen &&
471490 msgLen <= gsoSize &&
472- msgLen <= freeBaseCap &&
473491 dgramCnt < udpSegmentMaxDatagrams &&
474492 ! endBatch {
475- msgs [base ].Buffers [ 0 ] = append (msgs [base ].Buffers [ 0 ] , buf ... )
493+ msgs [base ].Buffers = append (msgs [base ].Buffers , buf )
476494 if i == len (bufs )- 1 {
477495 setGSO (& msgs [base ].OOB , uint16 (gsoSize ))
478496 }
479497 dgramCnt ++
498+ coalescedLen += msgLen
480499 if msgLen < gsoSize {
481500 // A smaller than gsoSize packet on the tail is legal, but
482501 // it must end the batch.
@@ -497,13 +516,14 @@ func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, offs
497516 msgs [base ].Buffers [0 ] = buf
498517 msgs [base ].Addr = addr
499518 dgramCnt = 1
519+ coalescedLen = gsoSize
500520 }
501521 return base + 1
502522}
503523
504524type getGSOFunc func (control []byte ) (int , error )
505525
506- func splitCoalescedMessages (msgs []ipv6.Message , firstMsgAt int , getGSO getGSOFunc ) (n int , err error ) {
526+ func splitCoalescedMessages (msgs []ipv6.Message , firstMsgAt int , getGSO getGSOFunc , bufs [] * buffer. Buffer , pool buffer. Source ) (n int , err error ) {
507527 for i := firstMsgAt ; i < len (msgs ); i ++ {
508528 msg := & msgs [i ]
509529 if msg .N == 0 {
@@ -527,6 +547,12 @@ func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFu
527547 if n > i {
528548 return n , errors .New ("splitting coalesced packet resulted in overflow" )
529549 }
550+ segLen := end - start
551+ if bufs [n ] == nil {
552+ bufs [n ] = pool .Get (segLen )
553+ }
554+ msgs [n ].Buffers [0 ] = bufs [n ].Data ()
555+ msgs [n ].OOB = msgs [n ].OOB [:cap (msgs [n ].OOB )]
530556 copied := copy (msgs [n ].Buffers [0 ], msg .Buffers [0 ][start :end ])
531557 msgs [n ].N = copied
532558 msgs [n ].Addr = msg .Addr
0 commit comments