Skip to content

Commit e981b07

Browse files
committed
tun, device: allocate buffers in the edge devices
Signed-off-by: Alex Valiushko <alexvaliushko@tailscale.com> Change-Id: Ic309618efe9f78c82fbc071ab8111fa46a6a6964
1 parent b648d36 commit e981b07

26 files changed

Lines changed: 473 additions & 192 deletions

buffer/buffer.go

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
/* SPDX-License-Identifier: MIT
2+
*
3+
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4+
*
5+
* Package buffer implements a reusable buffer abstraction.
6+
*
7+
* Wireguard-go's data processing is constrained by both the host's API,
8+
* and the transformations performed during encapsulation:
9+
*
10+
* 1. Encryption requires tail- and headroom for extra headers and padding.
11+
* Available via winrio, and pread(2).
12+
* 2. Systems are moving towards coalesced reads for both TCP and UDP.
13+
* The read data has no gaps for individual slices.
14+
* 3. crypto.AEAD interface requires a contiguous dst []byte for Sealing.
15+
* So we can't use scatter-gather to inject the gaps.
16+
*
17+
* Until one of these three conditions is changed, the encryption strategy
18+
* is to copy on read into buffers with the required gaps.
19+
* The buffers are right-sized for the packet to avoid memory inflation.
20+
* To recycle said allocations, each buffer carries a recycle function
21+
* that routes it back to its originating pool.
22+
*
23+
* Decryption shrinks each fragment instead of growing, so buffers can pass
24+
* through the pipeline without copying until the egress coalescing.
25+
* Depending on the chosen head of the coalescing, there may or may not be room
26+
* and reallocation is a necessary fallback until we start passing
27+
* buffers in batches.
28+
*/
29+
package buffer
30+
31+
// MaxMessageSize is the upper bound, in bytes, on a single message:
// (1 << 16) - 1 = 65535.
const (
32+
MaxMessageSize = (1 << 16) - 1 // largest possible UDP datagram
33+
)
34+
35+
// Source produces new Buffers.
// FragmentPool is the implementation provided by this package.
36+
type Source interface {
37+
// Get returns a Buffer for a payload of size bytes.
// NOTE(review): FragmentPool.Get returns a Buffer whose Len equals size;
// presumably every Source is expected to do the same — confirm.
Get(size int) *Buffer
38+
}
39+
40+
// Buffer is a reusable slice of bytes of fixed length.
//
// A Buffer may carry a recycle callback that hands it back to wherever it
// came from when Release is called; standalone Buffers created with New have
// none. The slice returned by Data must not be retained past Release.
type Buffer struct {
	data    []byte
	recycle func(*Buffer) // nil for standalone buffers; invoked by Release otherwise
}

// New creates a standalone Buffer wrapping b. Intended for use in tests.
// Releasing a standalone Buffer is a no-op.
func New(b []byte) *Buffer {
	return &Buffer{data: b}
}

// Data returns the bytes held by the buffer. The slice aliases the buffer's
// storage and must not be used after Release.
func (b *Buffer) Data() []byte {
	return b.data
}

// Len returns the number of bytes in the buffer.
func (b *Buffer) Len() int {
	return len(b.data)
}

// Release zeroes the buffer's bytes and passes the buffer to its recycle
// callback. It is a no-op for a nil Buffer and for a Buffer without a
// recycle callback (in which case the contents are left untouched).
func (b *Buffer) Release() {
	// Slices of buffers routinely contain nil entries (see ReleaseAll), so a
	// nil receiver is tolerated rather than dereferenced.
	if b == nil || b.recycle == nil {
		return
	}
	memclr(b.data)
	b.recycle(b)
}

// ReleaseAll releases every non-nil Buffer in bs and sets its slot to nil so
// the slice no longer references released buffers.
func ReleaseAll(bs []*Buffer) {
	for i, b := range bs {
		if b == nil {
			continue
		}
		b.Release()
		bs[i] = nil
	}
}

// Arena hands out consecutive, non-overlapping sub-slices of a single Buffer.
// watermark is the offset of the first byte that has not been handed out.
type Arena struct {
	*Buffer
	watermark int
}

// Get returns the next size bytes of the arena and advances the watermark.
// The returned slice has its capacity limited to size, so appending to it
// reallocates instead of overwriting the allocation that follows it.
//
// Get panics if size is negative or if fewer than size bytes remain.
func (a *Arena) Get(size int) []byte {
	if size < 0 {
		panic("arena: negative size")
	}
	data := a.Buffer.Data()
	// Compare against the remaining room instead of computing
	// watermark+size first, which could overflow for very large sizes.
	if size > len(data)-a.watermark {
		panic("arena overflow") // or return a heap-allocated fallback
	}
	end := a.watermark + size
	b := data[a.watermark:end:end]
	a.watermark = end
	return b
}

// Flush zeroes every byte handed out since the previous Flush and makes the
// whole arena available again. Slices previously returned by Get alias the
// arena's storage and must not be used afterwards.
func (a *Arena) Flush() {
	memclr(a.Buffer.Data()[:a.watermark])
	a.watermark = 0
}

// memclr sets every byte of b to zero.
func memclr(b []byte) {
	for i := range b {
		b[i] = 0
	}
}

buffer/pool.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package buffer
2+
3+
import "sync"
4+
5+
// Capacities, in bytes, of the buffers held by the three pools of a
// FragmentPool. FragmentPool.Get also uses them as the upper size bound of
// each class.
// NOTE(review): min and max shadow the Go 1.21 builtin functions of the same
// name throughout this package.
const (
6+
min = 2 << 10 // 2 KiB
7+
mid = 10 << 10 // 10 KiB
8+
max = 65 << 10 // 65 KiB, slightly above MaxMessageSize
9+
)
10+
11+
// Compile-time check that *FragmentPool implements Source.
var _ Source = (*FragmentPool)(nil)
12+
13+
// FragmentPool is a Source backed by three sync.Pools, one per buffer size
// class. Use NewFragmentPool to create one: the pools' New functions are set
// there.
type FragmentPool struct {
14+
minPool sync.Pool // buffers of capacity min
15+
midPool sync.Pool // buffers of capacity mid
16+
maxPool sync.Pool // buffers of capacity max
17+
}
18+
19+
func NewFragmentPool() *FragmentPool {
20+
recycle := func(p *sync.Pool) func(*Buffer) {
21+
return func(b *Buffer) {
22+
b.data = b.data[:cap(b.data)]
23+
p.Put(b)
24+
}
25+
}
26+
p := new(FragmentPool)
27+
p.minPool.New = func() any {
28+
return &Buffer{data: make([]byte, min), recycle: recycle(&p.minPool)}
29+
}
30+
p.midPool.New = func() any {
31+
return &Buffer{data: make([]byte, mid), recycle: recycle(&p.midPool)}
32+
}
33+
p.maxPool.New = func() any {
34+
return &Buffer{data: make([]byte, max), recycle: recycle(&p.maxPool)}
35+
}
36+
return p
37+
}
38+
39+
func (p *FragmentPool) Get(size int) *Buffer {
40+
var buf *Buffer
41+
switch {
42+
case size <= min:
43+
buf = p.minPool.Get().(*Buffer)
44+
case size <= mid:
45+
buf = p.midPool.Get().(*Buffer)
46+
case size <= max:
47+
buf = p.maxPool.Get().(*Buffer)
48+
default:
49+
return &Buffer{data: make([]byte, size)}
50+
}
51+
buf.data = buf.data[:size]
52+
return buf
53+
}

conn/bind_std.go

Lines changed: 47 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"sync"
1717
"syscall"
1818

19+
"github.com/tailscale/wireguard-go/buffer"
1920
"golang.org/x/net/ipv4"
2021
"golang.org/x/net/ipv6"
2122
)
@@ -44,6 +45,7 @@ type StdNetBind struct {
4445
// these two fields are not guarded by mu
4546
udpAddrPool sync.Pool
4647
msgsPool sync.Pool
48+
bufPool *buffer.FragmentPool
4749

4850
blackhole4 bool
4951
blackhole6 bool
@@ -63,12 +65,14 @@ func NewStdNetBind() Bind {
6365
New: func() any {
6466
msgs := make([]ipv6.Message, IdealBatchSize)
6567
for i := range msgs {
66-
msgs[i].Buffers = make(net.Buffers, 1)
68+
msgs[i].Buffers = make(net.Buffers, 1, udpSegmentMaxDatagrams)
6769
msgs[i].OOB = make([]byte, controlSize)
6870
}
6971
return &msgs
7072
},
7173
},
74+
75+
bufPool: buffer.NewFragmentPool(),
7276
}
7377
}
7478

@@ -204,7 +208,7 @@ again:
204208

205209
// putMessages resets every element of *msgs to a fresh ipv6.Message that
// keeps only the element's existing Buffers and OOB slices, then puts msgs
// back into s.msgsPool. In the added (+) line Buffers is additionally
// re-sliced to length one.
func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
206210
for i := range *msgs {
207-
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
211+
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers[:1], OOB: (*msgs)[i].OOB}
208212
}
209213
s.msgsPool.Put(msgs)
210214
}
@@ -230,36 +234,52 @@ func (s *StdNetBind) receiveIP(
230234
br batchReader,
231235
conn *net.UDPConn,
232236
rxOffload bool,
233-
bufs [][]byte,
237+
bufs []*buffer.Buffer,
234238
sizes []int,
235239
eps []Endpoint,
236240
) (n int, err error) {
237241
msgs := s.getMessages()
238-
for i := range bufs {
239-
(*msgs)[i].Buffers[0] = bufs[i]
240-
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
241-
}
242242
defer s.putMessages(msgs)
243243
var numMsgs int
244244
if runtime.GOOS == "linux" {
245245
if rxOffload {
246-
readAt := len(*msgs) - 2
246+
const readBatch = 2
247+
readAt := len(*msgs) - readBatch
248+
for i := readAt; i < readAt+readBatch; i++ {
249+
if bufs[i] == nil {
250+
bufs[i] = s.bufPool.Get(buffer.MaxMessageSize)
251+
}
252+
(*msgs)[i].Buffers[0] = bufs[i].Data()
253+
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
254+
}
247255
numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
248256
if err != nil {
249257
return 0, err
250258
}
251-
numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
259+
numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize, bufs, s.bufPool)
252260
if err != nil {
253261
return 0, err
254262
}
255263
} else {
264+
for i := range bufs {
265+
if bufs[i] == nil {
266+
bufs[i] = s.bufPool.Get(buffer.MaxMessageSize)
267+
}
268+
(*msgs)[i].Buffers[0] = bufs[i].Data()
269+
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
270+
}
256271
numMsgs, err = br.ReadBatch(*msgs, 0)
257272
if err != nil {
258273
return 0, err
259274
}
260275
}
261276
} else {
277+
if bufs[0] == nil {
278+
bufs[0] = s.bufPool.Get(buffer.MaxMessageSize)
279+
}
262280
msg := &(*msgs)[0]
281+
msg.Buffers[0] = bufs[0].Data()
282+
msg.OOB = msg.OOB[:cap(msg.OOB)]
263283
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
264284
if err != nil {
265285
return 0, err
@@ -281,13 +301,13 @@ func (s *StdNetBind) receiveIP(
281301
}
282302

283303
// makeReceiveIPv4 returns a ReceiveFunc that forwards its arguments to
// s.receiveIP, passing pc as the batch reader together with conn and
// rxOffload.
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
284-
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
304+
return func(bufs []*buffer.Buffer, sizes []int, eps []Endpoint) (n int, err error) {
285305
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
286306
}
287307
}
288308

289309
// makeReceiveIPv6 returns a ReceiveFunc that forwards its arguments to
// s.receiveIP, passing pc as the batch reader together with conn and
// rxOffload.
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
290-
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
310+
return func(bufs []*buffer.Buffer, sizes []int, eps []Endpoint) (n int, err error) {
291311
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
292312
}
293313
}
@@ -452,10 +472,11 @@ type setGSOFunc func(control *[]byte, gsoSize uint16)
452472

453473
func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, offset int, msgs []ipv6.Message, setGSO setGSOFunc) int {
454474
var (
455-
base = -1 // index of msg we are currently coalescing into
456-
gsoSize int // segmentation size of msgs[base]
457-
dgramCnt int // number of dgrams coalesced into msgs[base]
458-
endBatch bool // tracking flag to start a new batch on next iteration of bufs
475+
base = -1 // index of msg we are currently coalescing into
476+
gsoSize int // segmentation size of msgs[base]
477+
dgramCnt int // number of dgrams coalesced into msgs[base]
478+
endBatch bool // tracking flag to start a new batch on next iteration of bufs
479+
coalescedLen int // bytes coalesced into msgs[base]
459480
)
460481
maxPayloadLen := maxIPv4PayloadLen
461482
if ep.DstIP().Is6() {
@@ -465,18 +486,16 @@ func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, offs
465486
buf = buf[offset:]
466487
if i > 0 {
467488
msgLen := len(buf)
468-
baseLenBefore := len(msgs[base].Buffers[0])
469-
freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
470-
if msgLen+baseLenBefore <= maxPayloadLen &&
489+
if msgLen+coalescedLen <= maxPayloadLen &&
471490
msgLen <= gsoSize &&
472-
msgLen <= freeBaseCap &&
473491
dgramCnt < udpSegmentMaxDatagrams &&
474492
!endBatch {
475-
msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
493+
msgs[base].Buffers = append(msgs[base].Buffers, buf)
476494
if i == len(bufs)-1 {
477495
setGSO(&msgs[base].OOB, uint16(gsoSize))
478496
}
479497
dgramCnt++
498+
coalescedLen += msgLen
480499
if msgLen < gsoSize {
481500
// A smaller than gsoSize packet on the tail is legal, but
482501
// it must end the batch.
@@ -497,13 +516,14 @@ func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, offs
497516
msgs[base].Buffers[0] = buf
498517
msgs[base].Addr = addr
499518
dgramCnt = 1
519+
coalescedLen = gsoSize
500520
}
501521
return base + 1
502522
}
503523

504524
type getGSOFunc func(control []byte) (int, error)
505525

506-
func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
526+
func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc, bufs []*buffer.Buffer, pool buffer.Source) (n int, err error) {
507527
for i := firstMsgAt; i < len(msgs); i++ {
508528
msg := &msgs[i]
509529
if msg.N == 0 {
@@ -527,6 +547,12 @@ func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFu
527547
if n > i {
528548
return n, errors.New("splitting coalesced packet resulted in overflow")
529549
}
550+
segLen := end - start
551+
if bufs[n] == nil {
552+
bufs[n] = pool.Get(segLen)
553+
}
554+
msgs[n].Buffers[0] = bufs[n].Data()
555+
msgs[n].OOB = msgs[n].OOB[:cap(msgs[n].OOB)]
530556
copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
531557
msgs[n].N = copied
532558
msgs[n].Addr = msg.Addr

0 commit comments

Comments
 (0)