Skip to content

Commit 61b6c72

Browse files
committed
!fixup use fat pointers
Signed-off-by: Alex Valiushko <alexvaliushko@tailscale.com> Change-Id: Ibf50de1c5068b02c17ab37fac8bb03616a6a6964
1 parent 7800bcf commit 61b6c72

16 files changed

Lines changed: 555 additions & 490 deletions

buffer/buffer.go

Lines changed: 57 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,27 @@ import (
66

77
const (
88
MaxMessageSize = (1 << 16) - 1 // largest possible UDP datagram
9+
MaxBufsPerRead = 128
910
)
1011

11-
type Pool interface {
12-
Get(minSize int) []byte
13-
Put([]byte)
12+
type Buffer struct {
13+
Bs []byte
14+
Owner Pool
1415
}
1516

16-
type TieredPool struct {
17-
minPool sync.Pool
18-
midPool sync.Pool
19-
maxPool sync.Pool
17+
// Release returns b to its Owner pool. The caller must not use b after
18+
// calling Release. It is a no-op if Owner is nil (e.g. GC-managed buffers).
19+
func (b *Buffer) Release() {
20+
if b.Owner != nil {
21+
b.Owner.Put(b)
22+
}
23+
}
24+
25+
// Pool manages *Buffer allocations. Get returns a *Buffer with Bs of at
26+
// least minSize length. Put returns a *Buffer for reuse.
27+
type Pool interface {
28+
Get(minSize int) *Buffer
29+
Put(*Buffer)
2030
}
2131

2232
const (
@@ -25,65 +35,67 @@ const (
2535
max = 1 << 16
2636
)
2737

38+
type TieredPool struct {
39+
minPool sync.Pool
40+
midPool sync.Pool
41+
maxPool sync.Pool
42+
}
43+
2844
func NewTieredPool() Pool {
29-
return &TieredPool{
30-
minPool: sync.Pool{
31-
New: func() any {
32-
return make([]byte, min)
33-
},
34-
},
35-
midPool: sync.Pool{
36-
New: func() any {
37-
return make([]byte, mid)
38-
},
39-
},
40-
maxPool: sync.Pool{
41-
New: func() any {
42-
return make([]byte, max)
43-
},
44-
},
45+
p := new(TieredPool)
46+
p.minPool.New = func() any {
47+
return &Buffer{Bs: make([]byte, min), Owner: p}
4548
}
49+
p.midPool.New = func() any {
50+
return &Buffer{Bs: make([]byte, mid), Owner: p}
51+
}
52+
p.maxPool.New = func() any {
53+
return &Buffer{Bs: make([]byte, max), Owner: p}
54+
}
55+
return p
4656
}
4757

48-
func (p *TieredPool) Get(minSize int) []byte {
49-
if minSize > max {
50-
panic("requested buffer size exceeds maximum allowed")
51-
}
58+
func (p *TieredPool) Get(minSize int) *Buffer {
5259
switch {
5360
case minSize <= min:
54-
return p.minPool.Get().([]byte)[:minSize]
61+
buf := p.minPool.Get().(*Buffer)
62+
buf.Bs = buf.Bs[:minSize]
63+
return buf
5564
case minSize <= mid:
56-
return p.midPool.Get().([]byte)[:minSize]
65+
buf := p.midPool.Get().(*Buffer)
66+
buf.Bs = buf.Bs[:minSize]
67+
return buf
5768
case minSize <= max:
58-
return p.maxPool.Get().([]byte)[:minSize]
69+
buf := p.maxPool.Get().(*Buffer)
70+
buf.Bs = buf.Bs[:minSize]
71+
return buf
5972
default:
60-
return make([]byte, minSize)
73+
return &Buffer{Bs: make([]byte, minSize)} // GC-managed
6174
}
6275
}
6376

64-
func (p *TieredPool) Put(b []byte) {
65-
if b == nil {
66-
return
67-
}
68-
switch cap(b) {
77+
func (p *TieredPool) Put(b *Buffer) {
78+
b.Owner = p
79+
switch cap(b.Bs) {
6980
case min:
81+
b.Bs = b.Bs[:cap(b.Bs)]
7082
p.minPool.Put(b)
7183
case mid:
84+
b.Bs = b.Bs[:cap(b.Bs)]
7285
p.midPool.Put(b)
7386
case max:
87+
b.Bs = b.Bs[:cap(b.Bs)]
7488
p.maxPool.Put(b)
7589
default:
76-
// Drop to GC
90+
// GC-managed
7791
}
7892
}
7993

80-
func Grow(b []byte, minSize int, pool Pool) []byte {
81-
if pool == nil {
82-
return b
83-
}
84-
if cap(b) >= minSize {
85-
return b[:minSize]
94+
func ReleaseAll(ms []*Buffer) {
95+
for i := range ms {
96+
if ms[i] != nil {
97+
ms[i].Release()
98+
ms[i] = nil
99+
}
86100
}
87-
pool.Put(b)
88-
return pool.Get(minSize)
89101
}

conn/bind_std.go

Lines changed: 43 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@ import (
2222
)
2323

2424
var (
25-
_ Bind = (*StdNetBind)(nil)
26-
_ BufferPoolBind = (*StdNetBind)(nil)
27-
_ Endpoint = (*StdNetEndpoint)(nil)
25+
_ Bind = (*StdNetBind)(nil)
26+
_ BufferBind = (*StdNetBind)(nil)
27+
_ Endpoint = (*StdNetEndpoint)(nil)
2828
)
2929

3030
// StdNetBind implements Bind for all platforms. While Windows has its own Bind
@@ -77,10 +77,12 @@ func NewStdNetBind() Bind {
7777
}
7878
}
7979

80-
// PutBindBuffer implements BufferPoolBind, returning a buffer to the
81-
// internal pool.
82-
func (s *StdNetBind) PutBindBuffer(buf []byte) {
83-
s.bufPool.Put(buf)
80+
func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
81+
panic("use BufferOpen")
82+
}
83+
84+
func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint, offset int) error {
85+
panic("use BufferSend")
8486
}
8587

8688
type StdNetEndpoint struct {
@@ -151,7 +153,8 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
151153
// it's at least non-nil.
152154
var errEADDRINUSE error = errors.New("")
153155

154-
func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
156+
// BufferOpen implements BufferBind. It is the primary socket setup path.
157+
func (s *StdNetBind) BufferOpen(uport uint16) ([]ReceiveBufferFunc, uint16, error) {
155158
s.mu.Lock()
156159
defer s.mu.Unlock()
157160

@@ -186,7 +189,7 @@ again:
186189
v4conn.Close()
187190
return nil, 0, err
188191
}
189-
var fns []ReceiveFunc
192+
var fns []ReceiveBufferFunc
190193
if v4conn != nil {
191194
s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
192195
if runtime.GOOS == "linux" {
@@ -240,8 +243,7 @@ func (s *StdNetBind) receiveIP(
240243
br batchReader,
241244
conn *net.UDPConn,
242245
rxOffload bool,
243-
bufs [][]byte,
244-
sizes []int,
246+
bufs []*buffer.Buffer,
245247
eps []Endpoint,
246248
) (n int, err error) {
247249
msgs := s.getMessages()
@@ -251,40 +253,27 @@ func (s *StdNetBind) receiveIP(
251253
if rxOffload {
252254
const readBatch = 2
253255
readAt := len(*msgs) - readBatch
254-
// Only the read-at slots need full-size buffers for the
255-
// coalesced read. Destination slots (0..readAt-1) are left
256-
// as-is for splitCoalescedMessages to right-size.
257256
for i := readAt; i < readAt+readBatch; i++ {
258257
if bufs[i] == nil {
259258
bufs[i] = s.bufPool.Get(buffer.MaxMessageSize)
260259
}
261-
(*msgs)[i].Buffers[0] = bufs[i]
260+
(*msgs)[i].Buffers[0] = bufs[i].Bs
262261
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
263262
}
264-
// Wire destination slot buffers into msgs (may be nil or
265-
// segment-sized leftovers from a previous call).
266-
for i := 0; i < readAt; i++ {
267-
(*msgs)[i].Buffers[0] = bufs[i]
268-
}
269263
numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
270264
if err != nil {
271265
return 0, err
272266
}
273-
numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize, s.bufPool)
267+
numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize, bufs, s.bufPool)
274268
if err != nil {
275269
return 0, err
276270
}
277-
// Sync bufs from msgs after split — splitCoalescedMessages may
278-
// have allocated or replaced buffers for destination slots.
279-
for i := 0; i < numMsgs; i++ {
280-
bufs[i] = (*msgs)[i].Buffers[0]
281-
}
282271
} else {
283272
for i := range bufs {
284273
if bufs[i] == nil {
285274
bufs[i] = s.bufPool.Get(buffer.MaxMessageSize)
286275
}
287-
(*msgs)[i].Buffers[0] = bufs[i]
276+
(*msgs)[i].Buffers[0] = bufs[i].Bs
288277
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
289278
}
290279
numMsgs, err = br.ReadBatch(*msgs, 0)
@@ -297,7 +286,7 @@ func (s *StdNetBind) receiveIP(
297286
bufs[0] = s.bufPool.Get(buffer.MaxMessageSize)
298287
}
299288
msg := &(*msgs)[0]
300-
msg.Buffers[0] = bufs[0]
289+
msg.Buffers[0] = bufs[0].Bs
301290
msg.OOB = msg.OOB[:cap(msg.OOB)]
302291
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
303292
if err != nil {
@@ -307,10 +296,10 @@ func (s *StdNetBind) receiveIP(
307296
}
308297
for i := 0; i < numMsgs; i++ {
309298
msg := &(*msgs)[i]
310-
sizes[i] = msg.N
311-
if sizes[i] == 0 {
299+
if msg.N == 0 {
312300
continue
313301
}
302+
bufs[i].Bs = bufs[i].Bs[:msg.N]
314303
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
315304
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
316305
getSrcFromControl(msg.OOB[:msg.NN], ep)
@@ -319,15 +308,15 @@ func (s *StdNetBind) receiveIP(
319308
return numMsgs, nil
320309
}
321310

322-
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
323-
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
324-
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
311+
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveBufferFunc {
312+
return func(bufs []*buffer.Buffer, eps []Endpoint) (n int, err error) {
313+
return s.receiveIP(pc, conn, rxOffload, bufs, eps)
325314
}
326315
}
327316

328-
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
329-
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
330-
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
317+
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveBufferFunc {
318+
return func(bufs []*buffer.Buffer, eps []Endpoint) (n int, err error) {
319+
return s.receiveIP(pc, conn, rxOffload, bufs, eps)
331320
}
332321
}
333322

@@ -380,7 +369,9 @@ func (e ErrUDPGSODisabled) Unwrap() error {
380369
return e.RetryErr
381370
}
382371

383-
func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint, offset int) error {
372+
// BufferSend implements BufferBind. It takes ownership of bufs and releases
373+
// them after the send completes. This is the primary send path.
374+
func (s *StdNetBind) BufferSend(bufs []*buffer.Buffer, endpoint Endpoint, offset int) error {
384375
s.mu.Lock()
385376
blackhole := s.blackhole4
386377
conn := s.ipv4
@@ -397,9 +388,11 @@ func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint, offset int) error {
397388
s.mu.Unlock()
398389

399390
if blackhole {
391+
buffer.ReleaseAll(bufs)
400392
return nil
401393
}
402394
if conn == nil {
395+
buffer.ReleaseAll(bufs)
403396
return syscall.EAFNOSUPPORT
404397
}
405398

@@ -423,7 +416,11 @@ func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint, offset int) error {
423416
)
424417
retry:
425418
if offload {
426-
n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, offset, *msgs, setGSOSize)
419+
rawBufs := make([][]byte, len(bufs))
420+
for i, b := range bufs {
421+
rawBufs[i] = b.Bs
422+
}
423+
n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), rawBufs, offset, *msgs, setGSOSize)
427424
err = s.send(conn, br, (*msgs)[:n])
428425
if err != nil && offload && errShouldDisableUDPGSO(err) {
429426
offload = false
@@ -440,11 +437,12 @@ retry:
440437
} else {
441438
for i := range bufs {
442439
(*msgs)[i].Addr = ua
443-
(*msgs)[i].Buffers[0] = bufs[i][offset:]
440+
(*msgs)[i].Buffers[0] = bufs[i].Bs[offset:]
444441
setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
445442
}
446443
err = s.send(conn, br, (*msgs)[:len(bufs)])
447444
}
445+
buffer.ReleaseAll(bufs)
448446
if retried {
449447
return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
450448
}
@@ -542,7 +540,7 @@ func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, offs
542540

543541
type getGSOFunc func(control []byte) (int, error)
544542

545-
func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc, pool buffer.Pool) (n int, err error) {
543+
func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc, bufs []*buffer.Buffer, pool buffer.Pool) (n int, err error) {
546544
for i := firstMsgAt; i < len(msgs); i++ {
547545
msg := &msgs[i]
548546
if msg.N == 0 {
@@ -567,7 +565,11 @@ func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFu
567565
return n, errors.New("splitting coalesced packet resulted in overflow")
568566
}
569567
segLen := end - start
570-
msgs[n].Buffers[0] = buffer.Grow(msgs[n].Buffers[0], segLen, pool)
568+
if bufs[n] == nil {
569+
bufs[n] = pool.Get(segLen)
570+
}
571+
msgs[n].Buffers[0] = bufs[n].Bs
572+
msgs[n].OOB = msgs[n].OOB[:cap(msgs[n].OOB)]
571573
copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
572574
msgs[n].N = copied
573575
msgs[n].Addr = msg.Addr

0 commit comments

Comments
 (0)