Skip to content

Commit 1e9e526

Browse files
committed
ipn/amnezia: copy if buf slices overwritten
1 parent e51448f commit 1e9e526

2 files changed

Lines changed: 21 additions & 17 deletions

File tree

intra/ipn/wg/amnezia.go

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -107,23 +107,19 @@ func (a *Amnezia) send(pktptr *[]byte) (ok bool) {
107107
return true
108108
}
109109

110-
func (a *Amnezia) recv(pktptr *[]byte) (ok bool) {
110+
func (a *Amnezia) recv(pkt []byte, upto int) (out []byte, ok bool) {
111111
if a == nil || !a.Set() {
112112
return
113113
}
114-
115-
var typ uint32
116-
pkt := *pktptr
117-
118-
recvLen := len(pkt)
119-
if recvLen < device.MinMessageSize {
114+
if upto < device.MinMessageSize {
120115
return
121116
}
122-
// h := uint16(device.MessageTransportOffsetReceiver)
123117

124-
pkt, typ = a.strip(pkt)
118+
var typ uint32
119+
// h := uint16(device.MessageTransportOffsetReceiver)
120+
pkt, typ = a.strip(pkt[:upto])
121+
strippedSz := len(pkt)
125122

126-
stripLen := len(pkt)
127123
switch typ {
128124
case device.MessageInitiationType, a.H1:
129125
typ = device.MessageInitiationType
@@ -137,12 +133,14 @@ func (a *Amnezia) recv(pktptr *[]byte) (ok bool) {
137133
case device.MessageTransportType, a.H4: // must be default?
138134
typ = device.MessageTransportType
139135
binary.LittleEndian.PutUint32(pkt, device.MessageTransportType)
136+
default:
137+
log.W("wg: %s: amnezia: recv: unexpected type %d", a.id, typ)
138+
// TODO: error?
140139
}
141140

142-
a.logIfNeeded("recv", typ, recvLen, stripLen)
141+
a.logIfNeeded("recv", typ, strippedSz, upto)
143142

144-
*pktptr = pkt
145-
return true
143+
return pkt, true
146144
}
147145

148146
func (a *Amnezia) instate(pkt []byte) ([]byte, uint32) {
@@ -207,7 +205,7 @@ func (a *Amnezia) strip(pkt []byte) ([]byte, uint32) {
207205
h := uint16(device.MessageTransportOffsetReceiver)
208206
// assume the correct msg type is in just the first byte:
209207
// github.com/WireGuard/wireguard-go/blob/12269c2761/device/noise-protocol.go#L56
210-
defaultType := uint8(pkt[0])
208+
defaultType := binary.LittleEndian.Uint32(pkt[:h])
211209

212210
var discard uint16 = 0
213211
var possibleType uint32 = 0
@@ -236,7 +234,7 @@ func (a *Amnezia) strip(pkt []byte) ([]byte, uint32) {
236234
log.W("wg: %s: amnezia: strip: mismatched msg type %d != %d", a.id, obsType, possibleType)
237235
} // else: nothing to discard
238236

239-
return pkt, uint32(defaultType)
237+
return pkt, defaultType
240238
}
241239

242240
func (a *Amnezia) logIfNeeded(dir string, typ uint32, n int, newn int) {

intra/ipn/wg/wgconn.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,16 +342,22 @@ func (s *StdNetBind) makeReceiveFn(uc net.PacketConn) conn.ReceiveFunc {
342342
recvOverwritten := false
343343

344344
numMsgs := 0
345-
b := bufs[0]
345+
b := bufs[0] // usually sized device.MaxMessageSize
346346

347347
extend(uc, wgtimeout)
348348
n, addr, err := uc.ReadFrom(b)
349349
if err == nil {
350-
recvOverwritten = s.amnezia.Load().recv(&b)
350+
b, recvOverwritten = s.amnezia.Load().recv(b, n)
351351
numMsgs++
352+
if recvOverwritten {
353+
n = len(b)
354+
}
352355
}
353356

354357
for i := range numMsgs {
358+
if recvOverwritten {
359+
copy(bufs[i], b)
360+
}
355361
sizes[i] = n
356362
eps[i] = s.asEndpoint(addr)
357363
}

0 commit comments

Comments
 (0)