Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions device/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/tailscale/wireguard-go/ratelimiter"
"github.com/tailscale/wireguard-go/rwcancel"
"github.com/tailscale/wireguard-go/tun"
"github.com/tailscale/wireguard-go/waitpool"
)

type Device struct {
Expand Down Expand Up @@ -71,11 +72,11 @@ type Device struct {
cookieChecker CookieChecker

pool struct {
inboundElementsContainer *WaitPool
outboundElementsContainer *WaitPool
messageBuffers *WaitPool
inboundElements *WaitPool
outboundElements *WaitPool
inboundElementsContainer *waitpool.WaitPool
outboundElementsContainer *waitpool.WaitPool
messageBuffers *waitpool.WaitPool
inboundElements *waitpool.WaitPool
outboundElements *waitpool.WaitPool
}

queue struct {
Expand Down
49 changes: 7 additions & 42 deletions device/pools.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,61 +7,26 @@ package device

import (
"sync"
)

type WaitPool struct {
pool sync.Pool
cond sync.Cond
lock sync.Mutex
count uint32 // Get calls not yet Put back
max uint32
}

func NewWaitPool(max uint32, new func() any) *WaitPool {
p := &WaitPool{pool: sync.Pool{New: new}, max: max}
p.cond = sync.Cond{L: &p.lock}
return p
}

func (p *WaitPool) Get() any {
if p.max != 0 {
p.lock.Lock()
for p.count >= p.max {
p.cond.Wait()
}
p.count++
p.lock.Unlock()
}
return p.pool.Get()
}

func (p *WaitPool) Put(x any) {
p.pool.Put(x)
if p.max == 0 {
return
}
p.lock.Lock()
defer p.lock.Unlock()
p.count--
p.cond.Signal()
}
"github.com/tailscale/wireguard-go/waitpool"
)

func (device *Device) PopulatePools() {
device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
device.pool.inboundElementsContainer = waitpool.New(PreallocatedBuffersPerPool, func() any {
s := make([]*QueueInboundElement, 0, device.BatchSize())
return &QueueInboundElementsContainer{elems: s}
})
device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
device.pool.outboundElementsContainer = waitpool.New(PreallocatedBuffersPerPool, func() any {
s := make([]*QueueOutboundElement, 0, device.BatchSize())
return &QueueOutboundElementsContainer{elems: s}
})
device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
device.pool.messageBuffers = waitpool.New(PreallocatedBuffersPerPool, func() any {
return new([MaxMessageSize]byte)
})
device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
device.pool.inboundElements = waitpool.New(PreallocatedBuffersPerPool, func() any {
return new(QueueInboundElement)
})
device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
device.pool.outboundElements = waitpool.New(PreallocatedBuffersPerPool, func() any {
return new(QueueOutboundElement)
})
}
Expand Down
10 changes: 5 additions & 5 deletions device/queueconstants_ios.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ package device
// These are vars instead of consts, because heavier network extensions might want to reduce
// them further.
var (
QueueStagedSize = 128
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
PreallocatedBuffersPerPool uint32 = 1024
QueueStagedSize = 128
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
PreallocatedBuffersPerPool = 1024
)

const MaxSegmentSize = 1700
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/

package device
package waitpool

const raceEnabled = false
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/

package device
package waitpool

const raceEnabled = true
59 changes: 59 additions & 0 deletions waitpool/waitpool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/

// Package waitpool provides a sync.Pool wrapper that caps the number of
// concurrently checked-out elements, blocking Get when the cap is reached.
package waitpool

import (
"sync"
)

// WaitPool is a sync.Pool with an optional concurrency cap. When max > 0,
// Get blocks once max elements are checked out until a corresponding Put
// returns one. When max == 0 there is no cap and Get never blocks.
type WaitPool struct {
pool sync.Pool
cond sync.Cond
lock sync.Mutex
count int // Get calls not yet Put back
max int
}

// New returns a WaitPool with the given concurrency cap and constructor.
// A max of 0 (or negative) disables the cap.
func New(max int, newFn func() any) *WaitPool {
if max < 0 {
max = 0
}
p := &WaitPool{pool: sync.Pool{New: newFn}, max: max}
p.cond = sync.Cond{L: &p.lock}
return p
}

// Get returns an element from the pool, blocking if the concurrency cap is reached.
func (p *WaitPool) Get() any {
if p.max != 0 {
p.lock.Lock()
for p.count >= p.max {
p.cond.Wait()
}
p.count++
p.lock.Unlock()
}
return p.pool.Get()
}

// Put returns an element to the pool and unblocks one waiting Get if any.
func (p *WaitPool) Put(x any) {
p.pool.Put(x)
if p.max == 0 {
return
}
p.lock.Lock()
defer p.lock.Unlock()
p.count--
p.cond.Signal()
Comment thread
illotum marked this conversation as resolved.
}
16 changes: 8 additions & 8 deletions device/pools_test.go → waitpool/waitpool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
*/

package device
package waitpool

import (
"math/rand"
Expand All @@ -30,9 +30,9 @@ func TestWaitPool(t *testing.T) {
if workers-4 <= 0 {
t.Skip("Not enough cores")
}
p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
p := New(workers-4, func() any { return make([]byte, 16) })
wg.Add(workers)
var max atomic.Uint32
var max atomic.Int64
updateMax := func() {
p.lock.Lock()
count := p.count
Expand All @@ -42,10 +42,10 @@ func TestWaitPool(t *testing.T) {
}
for {
old := max.Load()
if count <= old {
if int64(count) <= old {
break
}
if max.CompareAndSwap(old, count) {
if max.CompareAndSwap(old, int64(count)) {
break
}
}
Expand All @@ -65,7 +65,7 @@ func TestWaitPool(t *testing.T) {
}()
}
wg.Wait()
if max.Load() != p.max {
if max.Load() != int64(p.max) {
t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max.Load(), p.max)
}
}
Expand All @@ -78,7 +78,7 @@ func BenchmarkWaitPool(b *testing.B) {
if workers-4 <= 0 {
b.Skip("Not enough cores")
}
p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
p := New(workers-4, func() any { return make([]byte, 16) })
wg.Add(workers)
b.ResetTimer()
for i := 0; i < workers; i++ {
Expand All @@ -102,7 +102,7 @@ func BenchmarkWaitPoolEmpty(b *testing.B) {
if workers-4 <= 0 {
b.Skip("Not enough cores")
}
p := NewWaitPool(0, func() any { return make([]byte, 16) })
p := New(0, func() any { return make([]byte, 16) })
wg.Add(workers)
b.ResetTimer()
for i := 0; i < workers; i++ {
Expand Down
Loading