diff --git a/conn/bind_std.go b/conn/bind_std.go index c13891e67..39363448a 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -16,6 +16,7 @@ import ( "sync" "syscall" + "github.com/tailscale/wireguard-go/iobuf" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" ) @@ -233,13 +234,14 @@ func (s *StdNetBind) receiveIP( br batchReader, conn *net.UDPConn, rxOffload bool, - bufs [][]byte, - sizes []int, + bufs []iobuf.View, eps []Endpoint, ) (n int, err error) { msgs := s.getMessages() + // TODO: placeholder until bind implements right-sized buffers. + iobuf.EnsureAllocated(bufs) for i := range bufs { - (*msgs)[i].Buffers[0] = bufs[i] + (*msgs)[i].Buffers[0] = bufs[i].Bytes (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)] } defer s.putMessages(msgs) @@ -271,8 +273,8 @@ func (s *StdNetBind) receiveIP( } for i := 0; i < numMsgs; i++ { msg := &(*msgs)[i] - sizes[i] = msg.N - if sizes[i] == 0 { + bufs[i].Bytes = bufs[i].Bytes[:msg.N] + if len(bufs[i].Bytes) == 0 { continue } addrPort := msg.Addr.(*net.UDPAddr).AddrPort() @@ -284,14 +286,14 @@ func (s *StdNetBind) receiveIP( } func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { - return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { - return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) + return func(bufs []iobuf.View, eps []Endpoint) (n int, err error) { + return s.receiveIP(pc, conn, rxOffload, bufs, eps) } } func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { - return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { - return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) + return func(bufs []iobuf.View, eps []Endpoint) (n int, err error) { + return s.receiveIP(pc, conn, rxOffload, bufs, eps) } } diff --git a/conn/bind_std_test.go b/conn/bind_std_test.go index 254952f0a..88dcc6c2f 100644 --- a/conn/bind_std_test.go +++ b/conn/bind_std_test.go @@ -5,6 +5,7 
@@ import ( "net" "testing" + "github.com/tailscale/wireguard-go/iobuf" "golang.org/x/net/ipv6" ) @@ -15,15 +16,14 @@ func TestStdNetBindReceiveFuncAfterClose(t *testing.T) { t.Fatal(err) } bind.Close() - bufs := make([][]byte, 1) - bufs[0] = make([]byte, 1) - sizes := make([]int, 1) + bufs := make([]iobuf.View, 1) + bufs[0] = iobuf.View{Bytes: make([]byte, 1)} eps := make([]Endpoint, 1) for _, fn := range fns { // The ReceiveFuncs must not access conn-related fields on StdNetBind // unguarded. Close() nils the conn-related fields resulting in a panic // if they violate the mutex. - fn(bufs, sizes, eps) + fn(bufs, eps) } } diff --git a/conn/bind_windows.go b/conn/bind_windows.go index 737b475e1..ba2b45bd4 100644 --- a/conn/bind_windows.go +++ b/conn/bind_windows.go @@ -18,6 +18,7 @@ import ( "golang.org/x/sys/windows" "github.com/tailscale/wireguard-go/conn/winrio" + "github.com/tailscale/wireguard-go/iobuf" ) const ( @@ -416,20 +417,22 @@ retry: return n, &ep, nil } -func (bind *WinRingBind) receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) { +func (bind *WinRingBind) receiveIPv4(bufs []iobuf.View, eps []Endpoint) (int, error) { bind.mu.RLock() defer bind.mu.RUnlock() - n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen) - sizes[0] = n + iobuf.EnsureAllocated(bufs[:1]) + n, ep, err := bind.v4.Receive(bufs[0].Bytes, &bind.isOpen) + bufs[0].Bytes = bufs[0].Bytes[:n] eps[0] = ep return 1, err } -func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) { +func (bind *WinRingBind) receiveIPv6(bufs []iobuf.View, eps []Endpoint) (int, error) { bind.mu.RLock() defer bind.mu.RUnlock() - n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen) - sizes[0] = n + iobuf.EnsureAllocated(bufs[:1]) + n, ep, err := bind.v6.Receive(bufs[0].Bytes, &bind.isOpen) + bufs[0].Bytes = bufs[0].Bytes[:n] eps[0] = ep return 1, err } diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go index 741b776c4..e75749982 100644 --- 
a/conn/bindtest/bindtest.go +++ b/conn/bindtest/bindtest.go @@ -13,6 +13,7 @@ import ( "os" "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/iobuf" ) type ChannelBind struct { @@ -94,13 +95,14 @@ func (c *ChannelBind) BatchSize() int { return 1 } func (c *ChannelBind) SetMark(mark uint32) error { return nil } func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc { - return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { + return func(bufs []iobuf.View, eps []conn.Endpoint) (n int, err error) { select { case <-c.closeSignal: return 0, net.ErrClosed case rx := <-ch: - copied := copy(bufs[0], rx) - sizes[0] = copied + iobuf.EnsureAllocated(bufs[:1]) + n := copy(bufs[0].Bytes, rx) + bufs[0].Bytes = bufs[0].Bytes[:n] eps[0] = c.target6 return 1, nil } diff --git a/conn/conn.go b/conn/conn.go index f1781614d..a0a8ffa36 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -13,19 +13,20 @@ import ( "reflect" "runtime" "strings" + + "github.com/tailscale/wireguard-go/iobuf" ) const ( IdealBatchSize = 128 // maximum number of packets handled per read and write ) -// A ReceiveFunc receives at least one packet from the network and writes them -// into packets. On a successful read it returns the number of elements of -// sizes, packets, and endpoints that should be evaluated. Some elements of -// sizes may be zero, and callers should ignore them. Callers must pass a sizes -// and eps slice with a length greater than or equal to the length of packets. -// These lengths must not exceed the length of the associated Bind.BatchSize(). -type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error) +// A ReceiveFunc receives at least one packet from the network into bufs. +// On a successful read it returns the number of elements of bufs and eps +// that should be evaluated. Callers must pass an eps slice with a length +// greater than or equal to the length of bufs. 
These lengths must not +// exceed the length of the associated Bind.BatchSize(). +type ReceiveFunc func(bufs []iobuf.View, eps []Endpoint) (n int, err error) // A Bind listens on a port for both IPv6 and IPv4 UDP traffic. // diff --git a/conn/conn_test.go b/conn/conn_test.go index c6194ee0c..a39ac8fe5 100644 --- a/conn/conn_test.go +++ b/conn/conn_test.go @@ -7,11 +7,13 @@ package conn import ( "testing" + + "github.com/tailscale/wireguard-go/iobuf" ) func TestPrettyName(t *testing.T) { var ( - recvFunc ReceiveFunc = func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return } + recvFunc ReceiveFunc = func(bufs []iobuf.View, eps []Endpoint) (n int, err error) { return } ) const want = "TestPrettyName" diff --git a/device/channels.go b/device/channels.go index 18edc9ebb..de4d9fd49 100644 --- a/device/channels.go +++ b/device/channels.go @@ -8,6 +8,8 @@ package device import ( "runtime" "sync" + + "github.com/tailscale/wireguard-go/iobuf" ) // An outboundQueue is a channel of QueueOutboundElements awaiting encryption. 
@@ -90,7 +92,7 @@ func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue { } func (device *Device) needsInboundQueueFinalizer() bool { - return device.pool.messageBuffers.HasAccounting() || + return iobuf.HasAccounting() || device.pool.inboundElements.HasAccounting() || device.pool.inboundElementsContainer.HasAccounting() } @@ -101,7 +103,7 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) { case elemsContainer := <-q.c: elemsContainer.filling.Wait() for _, elem := range elemsContainer.elems { - device.PutMessageBuffer(elem.buffer) + elem.buffer.Release() device.PutInboundElement(elem) } device.PutInboundElementsContainer(elemsContainer) @@ -131,7 +133,7 @@ func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue { } func (device *Device) needsOutboundQueueFinalizer() bool { - return device.pool.messageBuffers.HasAccounting() || + return iobuf.HasAccounting() || device.pool.outboundElements.HasAccounting() || device.pool.outboundElementsContainer.HasAccounting() } @@ -142,7 +144,7 @@ func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) { case elemsContainer := <-q.c: elemsContainer.filling.Wait() for _, elem := range elemsContainer.elems { - device.PutMessageBuffer(elem.buffer) + elem.buffer.Release() device.PutOutboundElement(elem) } device.PutOutboundElementsContainer(elemsContainer) diff --git a/device/channels_test.go b/device/channels_test.go index a35a434a3..316cef5b6 100644 --- a/device/channels_test.go +++ b/device/channels_test.go @@ -8,6 +8,7 @@ package device import ( "testing" + "github.com/tailscale/wireguard-go/iobuf" "github.com/tailscale/wireguard-go/waitpool" ) @@ -15,8 +16,12 @@ func TestAutodrainingQueueFinalizerNeedTracksPoolAccounting(t *testing.T) { unbounded := func() *waitpool.WaitPool { return waitpool.New(0, func() any { return nil }) } bounded := func() *waitpool.WaitPool { return waitpool.New(1, func() any { return nil }) } + // Force the default raw pool 
unbounded for the bulk of the test. + origPool := iobuf.DefaultRawPool + iobuf.DefaultRawPool = iobuf.NewRawPool(0) + t.Cleanup(func() { iobuf.DefaultRawPool = origPool }) + device := &Device{} - device.pool.messageBuffers = unbounded() device.pool.inboundElements = unbounded() device.pool.inboundElementsContainer = unbounded() device.pool.outboundElements = unbounded() @@ -47,11 +52,11 @@ func TestAutodrainingQueueFinalizerNeedTracksPoolAccounting(t *testing.T) { } device.pool.outboundElementsContainer = unbounded() - device.pool.messageBuffers = bounded() + iobuf.DefaultRawPool = iobuf.NewRawPool(1) if !device.needsInboundQueueFinalizer() { - t.Fatal("bounded message buffer pool should need inbound queue finalizer") + t.Fatal("bounded raw buffer pool should need inbound queue finalizer") } if !device.needsOutboundQueueFinalizer() { - t.Fatal("bounded message buffer pool should need outbound queue finalizer") + t.Fatal("bounded raw buffer pool should need outbound queue finalizer") } } diff --git a/device/constants.go b/device/constants.go index 92c3bdea8..81e40ddfa 100644 --- a/device/constants.go +++ b/device/constants.go @@ -7,6 +7,8 @@ package device import ( "time" + + "github.com/tailscale/wireguard-go/iobuf" ) /* Specification constants */ @@ -27,9 +29,9 @@ const ( ) const ( - MinMessageSize = MessageKeepaliveSize // minimum size of transport message (keepalive) - MaxMessageSize = MaxSegmentSize // maximum size of transport message - MaxContentSize = MaxSegmentSize - MessageTransportSize - MessageEncapsulatingTransportSize // maximum size of transport message content + MinMessageSize = MessageKeepaliveSize // minimum size of transport message (keepalive) + MaxMessageSize = iobuf.MaxSegmentSize // maximum size of transport message + MaxContentSize = iobuf.MaxSegmentSize - MessageTransportSize - MessageEncapsulatingTransportSize // maximum size of transport message content ) /* Implementation constants */ diff --git a/device/device.go b/device/device.go index 
e25feeae2..8c560fb0a 100644 --- a/device/device.go +++ b/device/device.go @@ -79,7 +79,6 @@ type Device struct { pool struct { inboundElementsContainer *waitpool.WaitPool outboundElementsContainer *waitpool.WaitPool - messageBuffers *waitpool.WaitPool inboundElements *waitpool.WaitPool outboundElements *waitpool.WaitPool } diff --git a/device/device_test.go b/device/device_test.go index e44342170..0c07a42b0 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -22,6 +22,7 @@ import ( "github.com/tailscale/wireguard-go/conn" "github.com/tailscale/wireguard-go/conn/bindtest" + "github.com/tailscale/wireguard-go/iobuf" "github.com/tailscale/wireguard-go/tun" "github.com/tailscale/wireguard-go/tun/tuntest" ) @@ -437,7 +438,7 @@ type fakeTUNDeviceSized struct { } func (t *fakeTUNDeviceSized) File() *os.File { return nil } -func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { +func (t *fakeTUNDeviceSized) Read(bufs []iobuf.View, offset int) (n int, err error) { return 0, nil } func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil } diff --git a/device/pools.go b/device/pools.go index f6ed97063..498771577 100644 --- a/device/pools.go +++ b/device/pools.go @@ -6,25 +6,23 @@ package device import ( + "github.com/tailscale/wireguard-go/iobuf" "github.com/tailscale/wireguard-go/waitpool" ) func (device *Device) PopulatePools() { - device.pool.inboundElementsContainer = waitpool.New(PreallocatedBuffersPerPool, func() any { + device.pool.inboundElementsContainer = waitpool.New(iobuf.MaxPooledBuffers, func() any { s := make([]*QueueInboundElement, 0, device.BatchSize()) return &QueueInboundElementsContainer{elems: s} }) - device.pool.outboundElementsContainer = waitpool.New(PreallocatedBuffersPerPool, func() any { + device.pool.outboundElementsContainer = waitpool.New(iobuf.MaxPooledBuffers, func() any { s := make([]*QueueOutboundElement, 0, device.BatchSize()) return 
&QueueOutboundElementsContainer{elems: s} }) - device.pool.messageBuffers = waitpool.New(PreallocatedBuffersPerPool, func() any { - return new([MaxMessageSize]byte) - }) - device.pool.inboundElements = waitpool.New(PreallocatedBuffersPerPool, func() any { + device.pool.inboundElements = waitpool.New(iobuf.MaxPooledBuffers, func() any { return new(QueueInboundElement) }) - device.pool.outboundElements = waitpool.New(PreallocatedBuffersPerPool, func() any { + device.pool.outboundElements = waitpool.New(iobuf.MaxPooledBuffers, func() any { return new(QueueOutboundElement) }) } @@ -55,14 +53,6 @@ func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsConta device.pool.outboundElementsContainer.Put(c) } -func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { - return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) -} - -func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { - device.pool.messageBuffers.Put(msg) -} - func (device *Device) GetInboundElement() *QueueInboundElement { return device.pool.inboundElements.Get().(*QueueInboundElement) } @@ -78,5 +68,6 @@ func (device *Device) GetOutboundElement() *QueueOutboundElement { func (device *Device) PutOutboundElement(elem *QueueOutboundElement) { elem.clearPointers() + elem.nonce = 0 device.pool.outboundElements.Put(elem) } diff --git a/device/queueconstants_android.go b/device/queueconstants_android.go index bab9625c4..8f564db55 100644 --- a/device/queueconstants_android.go +++ b/device/queueconstants_android.go @@ -10,10 +10,8 @@ import "github.com/tailscale/wireguard-go/conn" /* Reduce memory consumption for Android */ const ( - QueueStagedSize = conn.IdealBatchSize - QueueOutboundSize = 1024 - QueueInboundSize = 1024 - QueueHandshakeSize = 1024 - MaxSegmentSize = 2200 - PreallocatedBuffersPerPool = 4096 + QueueStagedSize = conn.IdealBatchSize + QueueOutboundSize = 1024 + QueueInboundSize = 1024 + QueueHandshakeSize = 1024 ) diff --git 
a/device/queueconstants_default.go b/device/queueconstants_default.go index 9749cb789..56b1ab995 100644 --- a/device/queueconstants_default.go +++ b/device/queueconstants_default.go @@ -10,10 +10,8 @@ package device import "github.com/tailscale/wireguard-go/conn" const ( - QueueStagedSize = conn.IdealBatchSize - QueueOutboundSize = 1024 - QueueInboundSize = 1024 - QueueHandshakeSize = 1024 - MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram - PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth + QueueStagedSize = conn.IdealBatchSize + QueueOutboundSize = 1024 + QueueInboundSize = 1024 + QueueHandshakeSize = 1024 ) diff --git a/device/queueconstants_ios.go b/device/queueconstants_ios.go index a967d71e0..258e1eed3 100644 --- a/device/queueconstants_ios.go +++ b/device/queueconstants_ios.go @@ -11,11 +11,8 @@ package device // These are vars instead of consts, because heavier network extensions might want to reduce // them further. var ( - QueueStagedSize = 128 - QueueOutboundSize = 1024 - QueueInboundSize = 1024 - QueueHandshakeSize = 1024 - PreallocatedBuffersPerPool = 1024 + QueueStagedSize = 128 + QueueOutboundSize = 1024 + QueueInboundSize = 1024 + QueueHandshakeSize = 1024 ) - -const MaxSegmentSize = 1700 diff --git a/device/queueconstants_windows.go b/device/queueconstants_windows.go index 1eee32ba1..169c439de 100644 --- a/device/queueconstants_windows.go +++ b/device/queueconstants_windows.go @@ -6,10 +6,8 @@ package device const ( - QueueStagedSize = 128 - QueueOutboundSize = 1024 - QueueInboundSize = 1024 - QueueHandshakeSize = 1024 - MaxSegmentSize = 2048 - 32 // largest possible UDP datagram - PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth + QueueStagedSize = 128 + QueueOutboundSize = 1024 + QueueInboundSize = 1024 + QueueHandshakeSize = 1024 ) diff --git a/device/receive.go b/device/receive.go index 9fd2aec70..b58b7a3d9 100644 --- a/device/receive.go +++ b/device/receive.go @@ -15,6 
+15,7 @@ import ( "time" "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/iobuf" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" @@ -24,11 +25,11 @@ type QueueHandshakeElement struct { msgType uint32 packet []byte endpoint conn.Endpoint - buffer *[MaxMessageSize]byte + buffer iobuf.View } type QueueInboundElement struct { - buffer *[MaxMessageSize]byte + buffer iobuf.View packet []byte counter uint64 keypair *Keypair @@ -50,7 +51,7 @@ type QueueInboundElementsContainer struct { // avoids accidentally keeping other objects around unnecessarily. // It also reduces the possible collateral damage from use-after-free bugs. func (elem *QueueInboundElement) clearPointers() { - elem.buffer = nil + elem.buffer = iobuf.View{} elem.packet = nil elem.keypair = nil elem.endpoint = nil @@ -90,31 +91,18 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive // receive datagrams until conn is closed var ( - bufsArrs = make([]*[MaxMessageSize]byte, maxBatchSize) - bufs = make([][]byte, maxBatchSize) + bufs = make([]iobuf.View, maxBatchSize) // nil entries; recv allocates err error - sizes = make([]int, maxBatchSize) count int endpoints = make([]conn.Endpoint, maxBatchSize) deathSpiral int elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize) ) - for i := range bufsArrs { - bufsArrs[i] = device.GetMessageBuffer() - bufs[i] = bufsArrs[i][:] - } - - defer func() { - for i := 0; i < maxBatchSize; i++ { - if bufsArrs[i] != nil { - device.PutMessageBuffer(bufsArrs[i]) - } - } - }() + defer iobuf.ReleaseAll(bufs) for { - count, err = recv(bufs, sizes, endpoints) + count, err = recv(bufs, endpoints) if err != nil { if errors.Is(err, net.ErrClosed) { return @@ -133,14 +121,14 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive deathSpiral = 0 // handle each packet in the batch - for i, size := range sizes[:count] { - if size < MinMessageSize { 
+ for i := 0; i < count; i++ { + if len(bufs[i].Bytes) < MinMessageSize { continue } // check size of packet - packet := bufsArrs[i][:size] + packet := bufs[i].Bytes msgType := binary.LittleEndian.Uint32(packet[:4]) switch msgType { @@ -176,7 +164,7 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive peer := value.peer elem := device.GetInboundElement() elem.packet = packet - elem.buffer = bufsArrs[i] + elem.buffer = bufs[i].Claim() elem.keypair = keypair elem.endpoint = endpoints[i] elem.counter = 0 @@ -187,8 +175,6 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive elemsByPeer[peer] = elemsForPeer } elemsForPeer.elems = append(elemsForPeer.elems, elem) - bufsArrs[i] = device.GetMessageBuffer() - bufs[i] = bufsArrs[i][:] continue // otherwise it is a fixed size & handshake related packet @@ -213,18 +199,19 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive continue } + claimed := bufs[i].Claim() select { case device.queue.handshake.c <- QueueHandshakeElement{ msgType: msgType, - buffer: bufsArrs[i], + buffer: claimed, packet: packet, endpoint: endpoints[i], }: - bufsArrs[i] = device.GetMessageBuffer() - bufs[i] = bufsArrs[i][:] default: + claimed.Release() } } + iobuf.ReleaseAll(bufs[:count]) // release unclaimed for peer, elemsContainer := range elemsByPeer { if peer.isRunning.Load() { elemsContainer.filling.Add(1) @@ -232,7 +219,7 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive device.queue.decryption.c <- elemsContainer } else { for _, elem := range elemsContainer.elems { - device.PutMessageBuffer(elem.buffer) + elem.buffer.Release() device.PutInboundElement(elem) } device.PutInboundElementsContainer(elemsContainer) @@ -429,7 +416,7 @@ func (device *Device) RoutineHandshake(id int) { peer.SendKeepalive() } skip: - device.PutMessageBuffer(elem.buffer) + elem.buffer.Release() } } @@ -548,7 +535,7 @@ func (peer *Peer) 
processInboundContainer(elemsContainer *QueueInboundElementsCo continue } - scratch = append(scratch, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)]) + scratch = append(scratch, elem.buffer.Bytes[:MessageTransportOffsetContent+len(elem.packet)]) } peer.rxBytes.Add(rxBytesLen) @@ -568,7 +555,7 @@ func (peer *Peer) processInboundContainer(elemsContainer *QueueInboundElementsCo } } for _, elem := range elems { - device.PutMessageBuffer(elem.buffer) + elem.buffer.Release() device.PutInboundElement(elem) } } diff --git a/device/send.go b/device/send.go index 497c39930..dee82591c 100644 --- a/device/send.go +++ b/device/send.go @@ -16,6 +16,7 @@ import ( "time" "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/iobuf" "github.com/tailscale/wireguard-go/tun" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" @@ -47,8 +48,8 @@ import ( */ type QueueOutboundElement struct { - buffer *[MaxMessageSize]byte // slice holding the packet data - // packet is always a slice of "buffer". The starting offset in buffer + buffer iobuf.View + // packet is always a slice of buffer.Bytes. The starting offset in buffer.Bytes // is either: // a) MessageEncapsulatingTransportSize+MessageTransportHeaderSize (plaintext) // b) 0 (post-encryption) @@ -68,20 +69,12 @@ type QueueOutboundElementsContainer struct { elems []*QueueOutboundElement } -func (device *Device) NewOutboundElement() *QueueOutboundElement { - elem := device.GetOutboundElement() - elem.buffer = device.GetMessageBuffer() - elem.nonce = 0 - // keypair and peer were cleared (if necessary) by clearPointers. - return elem -} - // clearPointers clears elem fields that contain pointers. // This makes the garbage collector's life easier and // avoids accidentally keeping other objects around unnecessarily. // It also reduces the possible collateral damage from use-after-free bugs. 
func (elem *QueueOutboundElement) clearPointers() { - elem.buffer = nil + elem.buffer = iobuf.View{} elem.packet = nil elem.keypair = nil elem.peer = nil @@ -91,14 +84,15 @@ func (elem *QueueOutboundElement) clearPointers() { */ func (peer *Peer) SendKeepalive() { if len(peer.queue.staged) == 0 && peer.isRunning.Load() { - elem := peer.device.NewOutboundElement() + elem := peer.device.GetOutboundElement() + elem.buffer = iobuf.View{Bytes: make([]byte, MessageEncapsulatingTransportSize+MessageTransportSize)} elemsContainer := peer.device.GetOutboundElementsContainer() elemsContainer.elems = append(elemsContainer.elems, elem) select { case peer.queue.staged <- elemsContainer: peer.device.log.Verbosef("%v - Sending keepalive packet", peer) default: - peer.device.PutMessageBuffer(elem.buffer) + elem.buffer.Release() peer.device.PutOutboundElement(elem) peer.device.PutOutboundElementsContainer(elemsContainer) } @@ -134,15 +128,16 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { return err } - buf := make([]byte, MessageEncapsulatingTransportSize+MessageInitiationSize) - packet := buf[MessageEncapsulatingTransportSize:] + buf := iobuf.View{Bytes: make([]byte, MessageEncapsulatingTransportSize+MessageInitiationSize)} + defer buf.Release() + packet := buf.Bytes[MessageEncapsulatingTransportSize:] _ = msg.marshal(packet) peer.cookieGenerator.AddMacs(packet) peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err = peer.SendBuffers([][]byte{buf}) + err = peer.SendBuffers([][]byte{buf.Bytes}) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err) } @@ -164,8 +159,9 @@ func (peer *Peer) SendHandshakeResponse() error { return err } - buf := make([]byte, MessageEncapsulatingTransportSize+MessageResponseSize) - packet := buf[MessageEncapsulatingTransportSize:] + buf := iobuf.View{Bytes: make([]byte, MessageEncapsulatingTransportSize+MessageResponseSize)} + defer buf.Release() + packet 
:= buf.Bytes[MessageEncapsulatingTransportSize:] _ = response.marshal(packet) peer.cookieGenerator.AddMacs(packet) @@ -180,7 +176,7 @@ func (peer *Peer) SendHandshakeResponse() error { peer.timersAnyAuthenticatedPacketSent() // TODO: allocation could be avoided - err = peer.SendBuffers([][]byte{buf}) + err = peer.SendBuffers([][]byte{buf.Bytes}) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err) } @@ -197,11 +193,12 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) return err } - buf := make([]byte, MessageEncapsulatingTransportSize+MessageCookieReplySize) - packet := buf[MessageEncapsulatingTransportSize:] + buf := iobuf.View{Bytes: make([]byte, MessageEncapsulatingTransportSize+MessageCookieReplySize)} + defer buf.Release() + packet := buf.Bytes[MessageEncapsulatingTransportSize:] _ = reply.marshal(packet) // TODO: allocation could be avoided - device.net.bind.Send([][]byte{buf}, initiatingElem.endpoint, MessageEncapsulatingTransportSize) + device.net.bind.Send([][]byte{buf.Bytes}, initiatingElem.endpoint, MessageEncapsulatingTransportSize) return nil } @@ -229,57 +226,41 @@ func (device *Device) RoutineReadFromTUN() { var ( batchSize = device.BatchSize() readErr error - elems = make([]*QueueOutboundElement, batchSize) - bufs = make([][]byte, batchSize) + bufs = make([]iobuf.View, batchSize) elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize) count = 0 - sizes = make([]int, batchSize) offset = MessageEncapsulatingTransportSize + MessageTransportHeaderSize ) - for i := range elems { - elems[i] = device.NewOutboundElement() - bufs[i] = elems[i].buffer[:] - } - - defer func() { - for _, elem := range elems { - if elem != nil { - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - } - } - }() + defer iobuf.ReleaseAll(bufs) for { - // read packets - count, readErr = device.tun.device.Read(bufs, sizes, offset) + count, readErr = 
device.tun.device.Read(bufs, offset) + for i := 0; i < count; i++ { - if sizes[i] < 1 { + packet := bufs[i].Bytes[offset:] + if len(packet) < 1 { continue } - elem := elems[i] - elem.packet = bufs[i][offset : offset+sizes[i]] - // lookup peer var peer *Peer - switch elem.packet[0] >> 4 { + switch packet[0] >> 4 { case 4: - if len(elem.packet) < ipv4.HeaderLen { + if len(packet) < ipv4.HeaderLen { continue } - src := netip.AddrFrom4([4]byte(elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len])) - dst := netip.AddrFrom4([4]byte(elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len])) - peer = device.allowedips.LookupFromPacket(src, dst, elem.packet) + src := netip.AddrFrom4([4]byte(packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len])) + dst := netip.AddrFrom4([4]byte(packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len])) + peer = device.allowedips.LookupFromPacket(src, dst, packet) case 6: - if len(elem.packet) < ipv6.HeaderLen { + if len(packet) < ipv6.HeaderLen { continue } - src := netip.AddrFrom16([16]byte(elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len])) - dst := netip.AddrFrom16([16]byte(elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len])) - peer = device.allowedips.LookupFromPacket(src, dst, elem.packet) + src := netip.AddrFrom16([16]byte(packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len])) + dst := netip.AddrFrom16([16]byte(packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len])) + peer = device.allowedips.LookupFromPacket(src, dst, packet) default: device.log.Verbosef("Received packet with unknown IP version") @@ -288,15 +269,19 @@ func (device *Device) RoutineReadFromTUN() { if peer == nil { continue } + + elem := device.GetOutboundElement() + elem.packet = packet + elem.buffer = bufs[i].Claim() + elemsForPeer, ok := elemsByPeer[peer] if !ok { elemsForPeer = device.GetOutboundElementsContainer() elemsByPeer[peer] = elemsForPeer } elemsForPeer.elems = append(elemsForPeer.elems, elem) - elems[i] = device.NewOutboundElement() - bufs[i] = elems[i].buffer[:] } 
+ iobuf.ReleaseAll(bufs[:count]) // release unclaimed for peer, elemsForPeer := range elemsByPeer { if peer.isRunning.Load() { @@ -304,7 +289,7 @@ func (device *Device) RoutineReadFromTUN() { peer.SendStagedPackets() } else { for _, elem := range elemsForPeer.elems { - device.PutMessageBuffer(elem.buffer) + elem.buffer.Release() device.PutOutboundElement(elem) } device.PutOutboundElementsContainer(elemsForPeer) @@ -341,7 +326,7 @@ func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) { select { case tooOld := <-peer.queue.staged: for _, elem := range tooOld.elems { - peer.device.PutMessageBuffer(elem.buffer) + elem.buffer.Release() peer.device.PutOutboundElement(elem) } peer.device.PutOutboundElementsContainer(tooOld) @@ -402,7 +387,7 @@ top: peer.device.queue.encryption.c <- elemsContainer } else { for _, elem := range elemsContainer.elems { - peer.device.PutMessageBuffer(elem.buffer) + elem.buffer.Release() peer.device.PutOutboundElement(elem) } peer.device.PutOutboundElementsContainer(elemsContainer) @@ -422,7 +407,7 @@ func (peer *Peer) FlushStagedPackets() { select { case elemsContainer := <-peer.queue.staged: for _, elem := range elemsContainer.elems { - peer.device.PutMessageBuffer(elem.buffer) + elem.buffer.Release() peer.device.PutOutboundElement(elem) } peer.device.PutOutboundElementsContainer(elemsContainer) @@ -462,7 +447,7 @@ func (device *Device) RoutineEncryption(id int) { for elemsContainer := range device.queue.encryption.c { for _, elem := range elemsContainer.elems { // populate header fields - header := elem.buffer[MessageEncapsulatingTransportSize : MessageEncapsulatingTransportSize+MessageTransportHeaderSize] + header := elem.buffer.Bytes[MessageEncapsulatingTransportSize : MessageEncapsulatingTransportSize+MessageTransportHeaderSize] fieldType := header[0:4] fieldReceiver := header[4:8] @@ -487,7 +472,8 @@ func (device *Device) RoutineEncryption(id int) { ) // re-slice packet to include encapsulating transport space - 
elem.packet = elem.buffer[:MessageEncapsulatingTransportSize+len(elem.packet)] + elem.buffer.Bytes = elem.buffer.Bytes[:MessageEncapsulatingTransportSize+len(elem.packet)] + elem.packet = elem.buffer.Bytes } elemsContainer.filling.Done() } @@ -544,7 +530,7 @@ func (peer *Peer) processOutboundContainer(elemsContainer *QueueOutboundElements // TODO: rework peer shutdown order to ensure // that we never accidentally keep timers alive longer than necessary. for _, elem := range elemsContainer.elems { - device.PutMessageBuffer(elem.buffer) + elem.buffer.Release() device.PutOutboundElement(elem) } return @@ -566,7 +552,7 @@ func (peer *Peer) processOutboundContainer(elemsContainer *QueueOutboundElements peer.timersDataSent() } for _, elem := range elemsContainer.elems { - device.PutMessageBuffer(elem.buffer) + elem.buffer.Release() device.PutOutboundElement(elem) } if err != nil { diff --git a/iobuf/constants.go b/iobuf/constants.go new file mode 100644 index 000000000..477a9c9bd --- /dev/null +++ b/iobuf/constants.go @@ -0,0 +1,10 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package iobuf + +const ( + MaxBufferSize = MaxSegmentSize // the largest buffer that I/O may attempt to read or write. +) diff --git a/iobuf/constants_android.go b/iobuf/constants_android.go new file mode 100644 index 000000000..bfb556f44 --- /dev/null +++ b/iobuf/constants_android.go @@ -0,0 +1,13 @@ +//go:build android + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package iobuf + +const ( + MaxSegmentSize = 2200 + MaxPooledBuffers = 4096 +) diff --git a/iobuf/constants_default.go b/iobuf/constants_default.go new file mode 100644 index 000000000..0ac3b6779 --- /dev/null +++ b/iobuf/constants_default.go @@ -0,0 +1,13 @@ +//go:build !android && !ios && !windows + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
+ */ + +package iobuf + +const ( + MaxSegmentSize = (1 << 16) - 1 + MaxPooledBuffers = 0 // Disable and allow for infinite memory growth +) diff --git a/iobuf/constants_ios.go b/iobuf/constants_ios.go new file mode 100644 index 000000000..2cb7b481c --- /dev/null +++ b/iobuf/constants_ios.go @@ -0,0 +1,14 @@ +//go:build ios + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package iobuf + +var ( + MaxPooledBuffers = 1024 // Var to allow further reduction. Recreate [DefaultRawPool] if changed. +) + +const MaxSegmentSize = 1700 diff --git a/iobuf/constants_windows.go b/iobuf/constants_windows.go new file mode 100644 index 000000000..f5f48c845 --- /dev/null +++ b/iobuf/constants_windows.go @@ -0,0 +1,13 @@ +//go:build windows + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package iobuf + +const ( + MaxSegmentSize = 2048 - 32 + MaxPooledBuffers = 0 // Disable and allow for infinite memory growth +) diff --git a/iobuf/raw.go b/iobuf/raw.go new file mode 100644 index 000000000..91a92399c --- /dev/null +++ b/iobuf/raw.go @@ -0,0 +1,68 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package iobuf + +import ( + "unsafe" + + "github.com/tailscale/wireguard-go/waitpool" +) + +var _ Recycler = (*RawPool)(nil) + +// Raw is the fundamental byte array. +type Raw [MaxBufferSize]byte + +// RawPool wraps [waitpool.WaitPool] of [Raw] buffers +// to configure their return via [RawPool.Recycle]. +type RawPool struct { + *waitpool.WaitPool +} + +func (p *RawPool) Get() *Raw { + return p.WaitPool.Get().(*Raw) +} + +// Recycle puts the [Raw] array referenced by goPtr back into the pool; the external pointer is ignored.
+func (p *RawPool) Recycle(goPtr unsafe.Pointer, _ uintptr) { + arr := (*Raw)(goPtr) + p.Put(arr) +} + +func NewRawPool(size int) *RawPool { + return &RawPool{waitpool.New(size, func() any { + return new(Raw) + })} +} + +// DefaultRawPool is used for package-level [Init] and [EnsureAllocated]. +var DefaultRawPool = NewRawPool(MaxPooledBuffers) + +// HasAccounting reports whether the default raw-buffer pool enforces a +// concurrency cap. Callers use it to decide whether allocation can block, +// and therefore whether queue finalizers are needed to drain blocked +// producers when an autodraining queue is GC'd. +func HasAccounting() bool { + return DefaultRawPool.WaitPool.HasAccounting() +} + +// EnsureAllocated fills zero-valued Views from the [DefaultRawPool]. +func EnsureAllocated(bufs []View) { + for i := range bufs { + if bufs[i].Bytes == nil { + Init(&bufs[i]) + } + } +} + +// Init initializes a [View] in-place with a fresh backing from the pool. +// Sets Bytes to the full backing array. +func Init(b *View) { + arr := DefaultRawPool.Get() + b.Recycler = DefaultRawPool + b.BackingGo = unsafe.Pointer(arr) + b.Bytes = arr[:] +} diff --git a/iobuf/view.go b/iobuf/view.go new file mode 100644 index 000000000..d670981a9 --- /dev/null +++ b/iobuf/view.go @@ -0,0 +1,66 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +// Package iobuf provides pooled packet buffers for the I/O pipeline. +// Each [View] carries one packet and a recycle function that returns +// its backing storage to the originating pool on [Release]. +package iobuf + +import "unsafe" + +// Recycler returns the backing of a [View] to its source for reuse. +// Implementations read whichever of [View.BackingGo] or [View.BackingExt] +// matches their allocation strategy. +type Recycler interface { + Recycle(unsafe.Pointer, uintptr) +} + +// RecycleFunc is a function adapter for [Recycler]. 
+type RecycleFunc func(unsafe.Pointer, uintptr) + +func (f RecycleFunc) Recycle(goPtr unsafe.Pointer, extPtr uintptr) { f(goPtr, extPtr) } + +// View is the packet envelope. Meant to be a value type, +// allocated once per goroutine and reused across read cycles. +// +// Exactly one of BackingGo / BackingExt is set for a managed View +// (Recycler != nil). Both are zero for unmanaged Views. +type View struct { + Recycler Recycler + + // BackingGo holds a Go-allocated backing object. Zero otherwise. + BackingGo unsafe.Pointer + + // BackingExt holds an opaque address into off-heap memory (e.g. a + // WinRio ring slot, an AF_XDP region). Zero otherwise. + BackingExt uintptr + + // Bytes holds the bounded packet data. Cut from the backing, + // it may be re-sliced by the caller. Do not append() on this slice. + // Nil for uninitialized Views. + Bytes []byte +} + +// Release returns the backing data to its source and zeros the View. +func (b *View) Release() { + if b.Recycler != nil { + b.Recycler.Recycle(b.BackingGo, b.BackingExt) + } + *b = View{} +} + +// Claim transfers ownership: returns a copy of the View and zeros the source. +func (b *View) Claim() View { + c := *b + *b = View{} + return c +} + +// ReleaseAll releases each View in the slice. +func ReleaseAll(bufs []View) { + for i := range bufs { + bufs[i].Release() + } +} diff --git a/iobuf/view_test.go b/iobuf/view_test.go new file mode 100644 index 000000000..ff136bd79 --- /dev/null +++ b/iobuf/view_test.go @@ -0,0 +1,124 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
+ */ + +package iobuf + +import ( + "testing" + "unsafe" + + "github.com/tailscale/wireguard-go/waitpool" +) + +func TestViewRelease(t *testing.T) { + var released bool + v := View{ + Recycler: RecycleFunc(func(unsafe.Pointer, uintptr) { released = true }), + BackingExt: 1, + Bytes: make([]byte, 1), + } + v.Release() + if !released { + t.Fatal("Recycler was not called on Release") + } + if v.Recycler != nil || v.BackingGo != nil || v.BackingExt != 0 || v.Bytes != nil { + t.Fatal("View not zeroed after Release") + } + // Double release is safe. + v.Release() + // Unmanaged Release is a no-op. + u := View{Bytes: make([]byte, 1)} + u.Release() // must not panic + if u.Bytes != nil { + t.Fatal("unmanaged View not zeroed after Release") + } +} + +func TestViewClaim(t *testing.T) { + var released bool + orig := View{ + Recycler: RecycleFunc(func(unsafe.Pointer, uintptr) { released = true }), + BackingExt: 1, + Bytes: make([]byte, 1), + } + claimed := orig.Claim() + // Source must be zeroed. + if orig.Recycler != nil || orig.BackingGo != nil || orig.BackingExt != 0 || orig.Bytes != nil { + t.Fatal("source not zeroed after Claim") + } + // Original Release is a no-op + orig.Release() + if released { + t.Fatal("Recycler called when releasing moved copy") + } + // Claimed copy must carry ownership. 
+ if claimed.Recycler == nil || claimed.BackingExt != 1 { + t.Fatal("claimed copy missing ownership fields") + } + claimed.Release() + if !released { + t.Fatal("Recycler not called when releasing claimed copy") + } +} + +func TestReleaseAll(t *testing.T) { + var count int + r := RecycleFunc(func(unsafe.Pointer, uintptr) { count++ }) + bufs := []View{ + {Recycler: r, BackingExt: 1, Bytes: make([]byte, 1)}, + {}, + {Recycler: r, BackingExt: 2, Bytes: make([]byte, 1)}, + } + ReleaseAll(bufs) + if count != 2 { + t.Fatalf("expected 2 recycle calls, got %d", count) + } + for i, b := range bufs { + if b.Recycler != nil || b.Bytes != nil { + t.Fatalf("bufs[%d] not zeroed", i) + } + } +} + +func TestRawPoolRoundTrip(t *testing.T) { + pool := &RawPool{waitpool.New(0, func() any { + return new(Raw) + })} + arr := pool.Get() + arr[0] = 0xAB + v := View{ + Recycler: pool, + BackingGo: unsafe.Pointer(arr), + Bytes: arr[:], + } + v.Release() + got := pool.Get() + if got[0] != 0xAB { + t.Fatal("pool did not return the same backing array") + } +} + +func TestDefaultPoolHelpers(t *testing.T) { + saved := DefaultRawPool + DefaultRawPool = NewRawPool(0) + defer func() { DefaultRawPool = saved }() + + bufs := make([]View, 3) + Init(&bufs[1]) + bufs[1].Bytes[0] = 0xAB + EnsureAllocated(bufs) // Init 0 and 2 + for i, v := range bufs { + if v.Recycler != DefaultRawPool || v.Bytes == nil || v.BackingGo == nil { + t.Fatalf("bufs[%d] not initialized", i) + } + } + if bufs[1].Bytes[0] != 0xAB { + t.Fatal("EnsureAllocated replaced an already-initialised View") + } + bufs[0].Release() + if bufs[0].Recycler != nil || bufs[0].BackingGo != nil || bufs[0].Bytes != nil { + t.Fatal("View not zeroed after Release") + } +} diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go index d8e70bb03..a0f125a4e 100644 --- a/tun/netstack/tun.go +++ b/tun/netstack/tun.go @@ -22,6 +22,7 @@ import ( "syscall" "time" + "github.com/tailscale/wireguard-go/iobuf" "github.com/tailscale/wireguard-go/tun"
"golang.org/x/net/dns/dnsmessage" @@ -119,17 +120,19 @@ func (tun *netTun) Events() <-chan tun.Event { return tun.events } -func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) { +func (tun *netTun) Read(bufs []iobuf.View, offset int) (int, error) { view, ok := <-tun.incomingPacket if !ok { return 0, os.ErrClosed } - - n, err := view.Read(buf[0][offset:]) + // TODO: If not for the offset, could use view.AsSlice() and wrap view.Release() in an [iobuf.Recycler]. + // TODO: Allocate view.Size() buffer. + iobuf.EnsureAllocated(bufs[:1]) + n, err := view.Read(bufs[0].Bytes[offset:]) if err != nil { return 0, err } - sizes[0] = n + bufs[0].Bytes = bufs[0].Bytes[:offset+n] return 1, nil } diff --git a/tun/offload.go b/tun/offload.go index 6db437c34..a2cb15e27 100644 --- a/tun/offload.go +++ b/tun/offload.go @@ -3,6 +3,8 @@ package tun import ( "encoding/binary" "fmt" + + "github.com/tailscale/wireguard-go/iobuf" ) // GSOType represents the type of segmentation offload. @@ -73,15 +75,15 @@ const ( ipProtoUDP = 17 ) -// GSOSplit splits packets from 'in' into outBufs[][outOffset:], writing -// the size of each element into sizes. It returns the number of buffers +// GSOSplit splits packets from 'in' into one or more entries in outBufs, writing +// each output packet to outBufs[i].Bytes starting at outOffset. It returns the number of buffers // populated, and/or an error. Callers may pass an 'in' slice that overlaps with // the first element of outBuffers, i.e. &in[0] may be equal to -// &outBufs[0][outOffset]. GSONone is a valid options.GSOType regardless of the +// &outBufs[0].Bytes[outOffset]. GSONone is a valid options.GSOType regardless of the // value of options.NeedsCsum. Length of each outBufs element must be greater // than or equal to the length of 'in', otherwise output may be silently // truncated.
-func GSOSplit(in []byte, options GSOOptions, outBufs [][]byte, sizes []int, outOffset int) (int, error) { +func GSOSplit(in []byte, options GSOOptions, outBufs []iobuf.View, outOffset int) (int, error) { cSumAt := int(options.CsumStart) + int(options.CsumOffset) if cSumAt+1 >= len(in) { return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in)) @@ -94,8 +96,8 @@ func GSOSplit(in []byte, options GSOOptions, outBufs [][]byte, sizes []int, outO // Handle the conditions where we are copying a single element to outBuffs. payloadLen := len(in) - int(options.HdrLen) if options.GSOType == GSONone || payloadLen < int(options.GSOSize) { - if len(in) > len(outBufs[0][outOffset:]) { - return 0, fmt.Errorf("length of packet (%d) exceeds output element length (%d)", len(in), len(outBufs[0][outOffset:])) + if len(in) > len(outBufs[0].Bytes[outOffset:]) { + return 0, fmt.Errorf("length of packet (%d) exceeds output element length (%d)", len(in), len(outBufs[0].Bytes[outOffset:])) } if options.NeedsCsum { // The initial value at the checksum offset should be summed with @@ -104,7 +106,8 @@ func GSOSplit(in []byte, options GSOOptions, outBufs [][]byte, sizes []int, outO in[cSumAt], in[cSumAt+1] = 0, 0 binary.BigEndian.PutUint16(in[cSumAt:], ^Checksum(in[options.CsumStart:], initial)) } - sizes[0] = copy(outBufs[0][outOffset:], in) + n := copy(outBufs[0].Bytes[outOffset:], in) + outBufs[0].Bytes = outBufs[0].Bytes[:outOffset+n] return 1, nil } @@ -164,8 +167,8 @@ func GSOSplit(in []byte, options GSOOptions, outBufs [][]byte, sizes []int, outO } segmentDataLen := nextSegmentEnd - nextSegmentDataAt totalLen := int(options.HdrLen) + segmentDataLen - sizes[i] = totalLen - out := outBufs[i][outOffset:] + outBufs[i].Bytes = outBufs[i].Bytes[:outOffset+totalLen] + out := outBufs[i].Bytes[outOffset:] copy(out, in[:iphLen]) if ipVersion == 4 { diff --git a/tun/offload_linux_test.go b/tun/offload_linux_test.go index b4c9aead0..8cecd803b 100644 --- 
a/tun/offload_linux_test.go +++ b/tun/offload_linux_test.go @@ -11,6 +11,7 @@ import ( "testing" "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/iobuf" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -235,13 +236,12 @@ func Test_handleVirtioRead(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - out := make([][]byte, conn.IdealBatchSize) - sizes := make([]int, conn.IdealBatchSize) + out := make([]iobuf.View, conn.IdealBatchSize) for i := range out { - out[i] = make([]byte, 65535) + out[i] = iobuf.View{Bytes: make([]byte, 65535)} } tt.hdr.encode(tt.pktIn) - n, err := handleVirtioRead(tt.pktIn, out, sizes, offset) + n, err := handleVirtioRead(tt.pktIn, out, offset) if err != nil { if tt.wantErr { return @@ -252,8 +252,8 @@ func Test_handleVirtioRead(t *testing.T) { t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens)) } for i := range tt.wantLens { - if tt.wantLens[i] != sizes[i] { - t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i]) + if size := len(out[i].Bytes) - offset; tt.wantLens[i] != size { + t.Fatalf("wantLens[%d]: %d != size: %d", i, tt.wantLens[i], size) } } }) diff --git a/tun/offload_test.go b/tun/offload_test.go index 82a37b9cc..33109b72c 100644 --- a/tun/offload_test.go +++ b/tun/offload_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/iobuf" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -67,11 +68,10 @@ func Fuzz_GSOSplit(f *testing.F) { }) header.UDP(gsoUDPv6[20:]).Encode(udpFields) - out := make([][]byte, conn.IdealBatchSize) + out := make([]iobuf.View, conn.IdealBatchSize) for i := range out { - out[i] = make([]byte, 65535) + out[i] = iobuf.View{Bytes: make([]byte, 65535)} } - sizes := make([]int, conn.IdealBatchSize) f.Add(gsoTCPv4, int(GSOTCPv4), uint16(40), uint16(20), uint16(16), uint16(100), false) f.Add(gsoUDPv4, 
int(GSOUDPL4), uint16(28), uint16(20), uint16(6), uint16(100), false) @@ -87,9 +87,9 @@ func Fuzz_GSOSplit(f *testing.F) { GSOSize: gsoSize, NeedsCsum: needsCsum, } - n, _ := GSOSplit(pkt, options, out, sizes, 0) - if n > len(sizes) { - t.Errorf("n (%d) > len(sizes): %d", n, len(sizes)) + n, _ := GSOSplit(pkt, options, out, 0) + if n > len(out) { + t.Errorf("n (%d) > len(out): %d", n, len(out)) } }) } diff --git a/tun/tun.go b/tun/tun.go index 719a60631..5b33b92df 100644 --- a/tun/tun.go +++ b/tun/tun.go @@ -7,6 +7,8 @@ package tun import ( "os" + + "github.com/tailscale/wireguard-go/iobuf" ) type Event int @@ -23,10 +25,11 @@ type Device interface { // Read one or more packets from the Device (without any additional headers). // On a successful read it returns the number of packets read, and sets - // packet lengths within the sizes slice. len(sizes) must be >= len(bufs). + // each buf's length to include the read data. + // Zero-valued entries in bufs are allocated by the implementation. // A nonzero offset can be used to instruct the Device on where to begin // reading into each element of the bufs slice. - Read(bufs [][]byte, sizes []int, offset int) (n int, err error) + Read(bufs []iobuf.View, offset int) (n int, err error) // Write one or more packets to the device (without any additional headers). // On a successful write it returns the number of packets written. A nonzero diff --git a/tun/tun_darwin.go b/tun/tun_darwin.go index c9a6c0bc4..f91b11c17 100644 --- a/tun/tun_darwin.go +++ b/tun/tun_darwin.go @@ -16,6 +16,7 @@ import ( "time" "unsafe" + "github.com/tailscale/wireguard-go/iobuf" "golang.org/x/sys/unix" ) @@ -217,7 +218,7 @@ func (tun *NativeTun) Events() <-chan Event { return tun.events } -func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { +func (tun *NativeTun) Read(bufs []iobuf.View, offset int) (n int, err error) { // TODO: the BSDs look very similar in Read() and Write(). 
They should be // collapsed, with platform-specific files containing the varying parts of // their implementations. @@ -225,12 +226,12 @@ func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) case err := <-tun.errors: return 0, err default: - buf := bufs[0][offset-4:] - n, err := tun.tunFile.Read(buf[:]) + iobuf.EnsureAllocated(bufs[:1]) + n, err := tun.tunFile.Read(bufs[0].Bytes[offset-4:]) if n < 4 { return 0, err } - sizes[0] = n - 4 + bufs[0].Bytes = bufs[0].Bytes[:offset+n-4] return 1, err } } diff --git a/tun/tun_freebsd.go b/tun/tun_freebsd.go index 7c65fd999..8d5355340 100644 --- a/tun/tun_freebsd.go +++ b/tun/tun_freebsd.go @@ -15,6 +15,7 @@ import ( "syscall" "unsafe" + "github.com/tailscale/wireguard-go/iobuf" "golang.org/x/sys/unix" ) @@ -333,17 +334,17 @@ func (tun *NativeTun) Events() <-chan Event { return tun.events } -func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { +func (tun *NativeTun) Read(bufs []iobuf.View, offset int) (n int, err error) { select { case err := <-tun.errors: return 0, err default: - buf := bufs[0][offset-4:] - n, err := tun.tunFile.Read(buf[:]) + iobuf.EnsureAllocated(bufs[:1]) + n, err := tun.tunFile.Read(bufs[0].Bytes[offset-4:]) if n < 4 { return 0, err } - sizes[0] = n - 4 + bufs[0].Bytes = bufs[0].Bytes[:offset+n-4] return 1, err } } diff --git a/tun/tun_linux.go b/tun/tun_linux.go index 8de679d25..6d9fc9ddf 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -18,6 +18,7 @@ import ( "unsafe" "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/iobuf" "github.com/tailscale/wireguard-go/rwcancel" "golang.org/x/sys/unix" ) @@ -410,9 +411,9 @@ func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { } // handleVirtioRead splits in into bufs, leaving offset bytes at the front of -// each buffer. It mutates sizes to reflect the size of each element of bufs, -// and returns the number of packets read. 
-func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) { +// each buffer. It sets each buffer's Bytes length to reflect the size of each +// element of bufs, and returns the number of packets read. +func handleVirtioRead(in []byte, bufs []iobuf.View, offset int) (int, error) { var hdr virtioNetHdr err := hdr.decode(in) if err != nil { @@ -444,17 +445,19 @@ func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, e options.HdrLen = options.CsumStart + tcpHLen } - return GSOSplit(in, options, bufs, sizes, offset) + return GSOSplit(in, options, bufs, offset) } -func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { +func (tun *NativeTun) Read(bufs []iobuf.View, offset int) (int, error) { tun.readOpMu.Lock() defer tun.readOpMu.Unlock() select { case err := <-tun.errors: return 0, err default: - readInto := bufs[0][offset:] + // TODO: placeholder until tun implements right-sized buffers. + iobuf.EnsureAllocated(bufs) + readInto := bufs[0].Bytes[offset:] if tun.vnetHdr { readInto = tun.readBuff[:] } @@ -466,9 +469,9 @@ func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) return 0, err } if tun.vnetHdr { - return handleVirtioRead(readInto[:n], bufs, sizes, offset) + return handleVirtioRead(readInto[:n], bufs, offset) } else { - sizes[0] = n + bufs[0].Bytes = bufs[0].Bytes[:n+offset] return 1, nil } } diff --git a/tun/tun_openbsd.go b/tun/tun_openbsd.go index ae571b90c..aa25529f4 100644 --- a/tun/tun_openbsd.go +++ b/tun/tun_openbsd.go @@ -15,6 +15,7 @@ import ( "syscall" "unsafe" + "github.com/tailscale/wireguard-go/iobuf" "golang.org/x/sys/unix" ) @@ -204,17 +205,17 @@ func (tun *NativeTun) Events() <-chan Event { return tun.events } -func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { +func (tun *NativeTun) Read(bufs []iobuf.View, offset int) (n int, err error) { select { case err := <-tun.errors: return 0, err default: - buf := 
bufs[0][offset-4:] - n, err := tun.tunFile.Read(buf[:]) + iobuf.EnsureAllocated(bufs[:1]) + n, err := tun.tunFile.Read(bufs[0].Bytes[offset-4:]) if n < 4 { return 0, err } - sizes[0] = n - 4 + bufs[0].Bytes = bufs[0].Bytes[:offset+n-4] return 1, err } } diff --git a/tun/tun_plan9.go b/tun/tun_plan9.go index 7b66eadf6..180b3ada4 100644 --- a/tun/tun_plan9.go +++ b/tun/tun_plan9.go @@ -12,6 +12,8 @@ import ( "strconv" "strings" "sync" + + "github.com/tailscale/wireguard-go/iobuf" ) type NativeTun struct { @@ -81,18 +83,19 @@ func (tun *NativeTun) Events() <-chan Event { return tun.events } -func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { +func (tun *NativeTun) Read(bufs []iobuf.View, offset int) (n int, err error) { select { case err := <-tun.errors: return 0, err default: - n, err := tun.dataFile.Read(bufs[0][offset:]) - if n == 1 && bufs[0][offset] == 0 { + iobuf.EnsureAllocated(bufs[:1]) + n, err := tun.dataFile.Read(bufs[0].Bytes[offset:]) + if n == 1 && bufs[0].Bytes[offset] == 0 { // EOF err = io.EOF n = 0 } - sizes[0] = n + bufs[0].Bytes = bufs[0].Bytes[:offset+n] return 1, err } } diff --git a/tun/tun_windows.go b/tun/tun_windows.go index 34f29805d..e9096f165 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -14,6 +14,7 @@ import ( "time" _ "unsafe" + "github.com/tailscale/wireguard-go/iobuf" "golang.org/x/sys/windows" "golang.zx2c4.com/wintun" ) @@ -144,7 +145,7 @@ func (tun *NativeTun) BatchSize() int { // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking. 
-func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { +func (tun *NativeTun) Read(bufs []iobuf.View, offset int) (n int, err error) { tun.running.Add(1) defer tun.running.Done() retry: @@ -161,8 +162,9 @@ retry: switch err { case nil: packetSize := len(packet) - copy(bufs[0][offset:], packet) - sizes[0] = packetSize + iobuf.EnsureAllocated(bufs[:1]) + n := copy(bufs[0].Bytes[offset:], packet) + bufs[0].Bytes = bufs[0].Bytes[:offset+n] tun.session.ReleaseReceivePacket(packet) tun.rate.update(uint64(packetSize)) return 1, nil diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go index e7507c26c..9e1c924e7 100644 --- a/tun/tuntest/tuntest.go +++ b/tun/tuntest/tuntest.go @@ -11,6 +11,7 @@ import ( "net/netip" "os" + "github.com/tailscale/wireguard-go/iobuf" "github.com/tailscale/wireguard-go/tun" ) @@ -110,13 +111,15 @@ type chTun struct { func (t *chTun) File() *os.File { return nil } -func (t *chTun) Read(packets [][]byte, sizes []int, offset int) (int, error) { +func (t *chTun) Read(bufs []iobuf.View, offset int) (int, error) { select { case <-t.c.closed: return 0, os.ErrClosed case msg := <-t.c.Outbound: - n := copy(packets[0][offset:], msg) - sizes[0] = n + // TODO: Allocate len(msg) buffer. + iobuf.EnsureAllocated(bufs[:1]) + n := copy(bufs[0].Bytes[offset:], msg) + bufs[0].Bytes = bufs[0].Bytes[:offset+n] return 1, nil } }