@@ -107,23 +107,19 @@ func (a *Amnezia) send(pktptr *[]byte) (ok bool) {
107107 return true
108108}
109109
110- func (a * Amnezia ) recv (pktptr * []byte ) (ok bool ) {
110+ func (a * Amnezia ) recv (pkt []byte , upto int ) (out [] byte , ok bool ) {
111111 if a == nil || ! a .Set () {
112112 return
113113 }
114-
115- var typ uint32
116- pkt := * pktptr
117-
118- recvLen := len (pkt )
119- if recvLen < device .MinMessageSize {
114+ if upto < device .MinMessageSize {
120115 return
121116 }
122- // h := uint16(device.MessageTransportOffsetReceiver)
123117
124- pkt , typ = a .strip (pkt )
118+ var typ uint32
119+ // h := uint16(device.MessageTransportOffsetReceiver)
120+ pkt , typ = a .strip (pkt [:upto ])
121+ strippedSz := len (pkt )
125122
126- stripLen := len (pkt )
127123 switch typ {
128124 case device .MessageInitiationType , a .H1 :
129125 typ = device .MessageInitiationType
@@ -137,12 +133,14 @@ func (a *Amnezia) recv(pktptr *[]byte) (ok bool) {
137133 case device .MessageTransportType , a .H4 : // must be default?
138134 typ = device .MessageTransportType
139135 binary .LittleEndian .PutUint32 (pkt , device .MessageTransportType )
136+ default :
137+ log .W ("wg: %s: amnezia: recv: unexpected type %d" , a .id , typ )
138+ // TODO: error?
140139 }
141140
142- a .logIfNeeded ("recv" , typ , recvLen , stripLen )
141+ a .logIfNeeded ("recv" , typ , strippedSz , upto )
143142
144- * pktptr = pkt
145- return true
143+ return pkt , true
146144}
147145
148146func (a * Amnezia ) instate (pkt []byte ) ([]byte , uint32 ) {
@@ -207,7 +205,7 @@ func (a *Amnezia) strip(pkt []byte) ([]byte, uint32) {
207205 h := uint16 (device .MessageTransportOffsetReceiver )
208206 // assume the correct msg type is in just the first byte:
209207 // github.com/WireGuard/wireguard-go/blob/12269c2761/device/noise-protocol.go#L56
210- defaultType := uint8 (pkt [0 ])
208+ defaultType := binary . LittleEndian . Uint32 (pkt [: h ])
211209
212210 var discard uint16 = 0
213211 var possibleType uint32 = 0
@@ -236,7 +234,7 @@ func (a *Amnezia) strip(pkt []byte) ([]byte, uint32) {
236234 log .W ("wg: %s: amnezia: strip: mismatched msg type %d != %d" , a .id , obsType , possibleType )
237235 } // else: nothing to discard
238236
239- return pkt , uint32 ( defaultType )
237+ return pkt , defaultType
240238}
241239
242240func (a * Amnezia ) logIfNeeded (dir string , typ uint32 , n int , newn int ) {
0 commit comments