Skip to content

Commit 3ade42b

Browse files
committed
device: extract pool into standalone package
Also fixes a Put-before-Get underflow bug. Signed-off-by: Alex Valiushko <alexvaliushko@tailscale.com> Change-Id: Icde836d1bef2e95850d3ae51b3c069056a6a6964
1 parent 7f4b4ad commit 3ade42b

7 files changed

Lines changed: 83 additions & 58 deletions

File tree

device/device.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"github.com/tailscale/wireguard-go/ratelimiter"
1818
"github.com/tailscale/wireguard-go/rwcancel"
1919
"github.com/tailscale/wireguard-go/tun"
20+
"github.com/tailscale/wireguard-go/waitpool"
2021
)
2122

2223
type Device struct {
@@ -71,11 +72,11 @@ type Device struct {
7172
cookieChecker CookieChecker
7273

7374
pool struct {
74-
inboundElementsContainer *WaitPool
75-
outboundElementsContainer *WaitPool
76-
messageBuffers *WaitPool
77-
inboundElements *WaitPool
78-
outboundElements *WaitPool
75+
inboundElementsContainer *waitpool.WaitPool
76+
outboundElementsContainer *waitpool.WaitPool
77+
messageBuffers *waitpool.WaitPool
78+
inboundElements *waitpool.WaitPool
79+
outboundElements *waitpool.WaitPool
7980
}
8081

8182
queue struct {

device/pools.go

Lines changed: 7 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -7,61 +7,26 @@ package device
77

88
import (
99
"sync"
10-
)
11-
12-
type WaitPool struct {
13-
pool sync.Pool
14-
cond sync.Cond
15-
lock sync.Mutex
16-
count uint32 // Get calls not yet Put back
17-
max uint32
18-
}
1910

20-
func NewWaitPool(max uint32, new func() any) *WaitPool {
21-
p := &WaitPool{pool: sync.Pool{New: new}, max: max}
22-
p.cond = sync.Cond{L: &p.lock}
23-
return p
24-
}
25-
26-
func (p *WaitPool) Get() any {
27-
if p.max != 0 {
28-
p.lock.Lock()
29-
for p.count >= p.max {
30-
p.cond.Wait()
31-
}
32-
p.count++
33-
p.lock.Unlock()
34-
}
35-
return p.pool.Get()
36-
}
37-
38-
func (p *WaitPool) Put(x any) {
39-
p.pool.Put(x)
40-
if p.max == 0 {
41-
return
42-
}
43-
p.lock.Lock()
44-
defer p.lock.Unlock()
45-
p.count--
46-
p.cond.Signal()
47-
}
11+
"github.com/tailscale/wireguard-go/waitpool"
12+
)
4813

4914
func (device *Device) PopulatePools() {
50-
device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
15+
device.pool.inboundElementsContainer = waitpool.New(PreallocatedBuffersPerPool, func() any {
5116
s := make([]*QueueInboundElement, 0, device.BatchSize())
5217
return &QueueInboundElementsContainer{elems: s}
5318
})
54-
device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
19+
device.pool.outboundElementsContainer = waitpool.New(PreallocatedBuffersPerPool, func() any {
5520
s := make([]*QueueOutboundElement, 0, device.BatchSize())
5621
return &QueueOutboundElementsContainer{elems: s}
5722
})
58-
device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
23+
device.pool.messageBuffers = waitpool.New(PreallocatedBuffersPerPool, func() any {
5924
return new([MaxMessageSize]byte)
6025
})
61-
device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
26+
device.pool.inboundElements = waitpool.New(PreallocatedBuffersPerPool, func() any {
6227
return new(QueueInboundElement)
6328
})
64-
device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
29+
device.pool.outboundElements = waitpool.New(PreallocatedBuffersPerPool, func() any {
6530
return new(QueueOutboundElement)
6631
})
6732
}

device/queueconstants_ios.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ var (
1515
QueueOutboundSize = 1024
1616
QueueInboundSize = 1024
1717
QueueHandshakeSize = 1024
18-
PreallocatedBuffersPerPool uint32 = 1024
18+
PreallocatedBuffersPerPool = 1024
1919
)
2020

2121
const MaxSegmentSize = 1700
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@
55
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
66
*/
77

8-
package device
8+
package waitpool
99

1010
const raceEnabled = false
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,6 @@
55
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
66
*/
77

8-
package device
8+
package waitpool
99

1010
const raceEnabled = true

waitpool/waitpool.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/* SPDX-License-Identifier: MIT
2+
*
3+
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
4+
*/
5+
6+
// Package waitpool provides a sync.Pool wrapper that caps the number of
7+
// concurrently checked-out elements, blocking Get when the cap is reached.
8+
package waitpool
9+
10+
import (
11+
"sync"
12+
)
13+
14+
// WaitPool is a sync.Pool with an optional concurrency cap. When max > 0,
// Get blocks once max elements are checked out until a corresponding Put
// returns one. When max == 0 there is no cap and Get never blocks.
type WaitPool struct {
	pool  sync.Pool
	cond  sync.Cond
	lock  sync.Mutex
	count int // Get calls not yet balanced by a Put; never negative
	max   int // cap on checked-out elements; 0 means unlimited
}

// New returns a WaitPool with the given concurrency cap and constructor.
// A max of 0 (or negative) disables the cap.
func New(max int, newFn func() any) *WaitPool {
	if max < 0 {
		max = 0
	}
	p := &WaitPool{pool: sync.Pool{New: newFn}, max: max}
	// The condition variable shares the pool's mutex so Wait atomically
	// releases the lock while a capped Get blocks.
	p.cond = sync.Cond{L: &p.lock}
	return p
}

// Get returns an element from the pool, blocking while the number of
// checked-out elements is at the cap.
func (p *WaitPool) Get() any {
	if p.max != 0 {
		p.lock.Lock()
		for p.count >= p.max {
			p.cond.Wait()
		}
		p.count++
		p.lock.Unlock()
	}
	return p.pool.Get()
}

// Put returns an element to the pool and wakes one blocked Get, if any.
// A Put without a matching Get (e.g. pre-seeding the pool with elements
// that were never checked out) must not drive count below zero, which
// would silently raise the effective cap by one per unmatched Put.
func (p *WaitPool) Put(x any) {
	p.pool.Put(x)
	if p.max == 0 {
		return
	}
	p.lock.Lock()
	defer p.lock.Unlock()
	if p.count > 0 { // clamp: guard against Put-before-Get underflow
		p.count--
	}
	p.cond.Signal()
}
Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
44
*/
55

6-
package device
6+
package waitpool
77

88
import (
99
"math/rand"
@@ -30,9 +30,9 @@ func TestWaitPool(t *testing.T) {
3030
if workers-4 <= 0 {
3131
t.Skip("Not enough cores")
3232
}
33-
p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
33+
p := New(workers-4, func() any { return make([]byte, 16) })
3434
wg.Add(workers)
35-
var max atomic.Uint32
35+
var max atomic.Int64
3636
updateMax := func() {
3737
p.lock.Lock()
3838
count := p.count
@@ -42,10 +42,10 @@ func TestWaitPool(t *testing.T) {
4242
}
4343
for {
4444
old := max.Load()
45-
if count <= old {
45+
if int64(count) <= old {
4646
break
4747
}
48-
if max.CompareAndSwap(old, count) {
48+
if max.CompareAndSwap(old, int64(count)) {
4949
break
5050
}
5151
}
@@ -65,7 +65,7 @@ func TestWaitPool(t *testing.T) {
6565
}()
6666
}
6767
wg.Wait()
68-
if max.Load() != p.max {
68+
if max.Load() != int64(p.max) {
6969
t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max.Load(), p.max)
7070
}
7171
}
@@ -78,7 +78,7 @@ func BenchmarkWaitPool(b *testing.B) {
7878
if workers-4 <= 0 {
7979
b.Skip("Not enough cores")
8080
}
81-
p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
81+
p := New(workers-4, func() any { return make([]byte, 16) })
8282
wg.Add(workers)
8383
b.ResetTimer()
8484
for i := 0; i < workers; i++ {
@@ -102,7 +102,7 @@ func BenchmarkWaitPoolEmpty(b *testing.B) {
102102
if workers-4 <= 0 {
103103
b.Skip("Not enough cores")
104104
}
105-
p := NewWaitPool(0, func() any { return make([]byte, 16) })
105+
p := New(0, func() any { return make([]byte, 16) })
106106
wg.Add(workers)
107107
b.ResetTimer()
108108
for i := 0; i < workers; i++ {

0 commit comments

Comments
 (0)